aarch64: Specialize constant vector shifts (#5976)
* aarch64: Specialize constant vector shifts

  This commit adds special lowering rules for vector shifts by constant
  amounts, using dedicated immediate-form instructions. This cuts down
  considerably on the code generated for constant shift values.

* Fix codegen for zero-width right shifts
* Special-case zero left shifts as well
* Remove the left-shift special case
This commit is contained in:
@@ -2633,6 +2633,18 @@
|
||||
;; Helper for generating a vector `ushl` (register-amount shift left)
;; instruction: each lane of the first operand is shifted by the amount held
;; in the corresponding lane of the second operand.
;; NOTE(review): the right-shift lowerings below feed a negated amount into
;; this op, so negative lane amounts presumably shift right -- confirm
;; against the AArch64 USHL definition.
(decl ushl (Reg Reg VectorSize) Reg)
|
||||
(rule (ushl x y size) (vec_rrr (VecALUOp.Ushl) x y size))
|
||||
|
||||
;; Helpers for generating `ushl` instructions.
|
||||
;; Shift each lane of the vector register left by the constant amount given
;; as the `u8` operand, emitting the immediate form (`VecShiftImmOp.Shl`).
;; Unlike the right-shift helpers below, a zero amount needs no special
;; casing here (the `ishl` lowering uses this helper directly, with no
;; zero-shift passthrough rule).
(decl ushl_vec_imm (Reg u8 VectorSize) Reg)
|
||||
(rule (ushl_vec_imm x amt size) (vec_shift_imm (VecShiftImmOp.Shl) amt x size))
|
||||
|
||||
;; Helpers for generating `ushr` instructions.
|
||||
;; Shift each lane of the vector register right (logical/unsigned) by the
;; constant amount given as the `u8` operand, emitting the immediate form
;; (`VecShiftImmOp.Ushr`). The amount must be non-zero: per the note in the
;; `ushr` lowering, a 0-width right shift can't be emitted, so callers
;; special-case 0 and pass the input through unchanged.
(decl ushr_vec_imm (Reg u8 VectorSize) Reg)
|
||||
(rule (ushr_vec_imm x amt size) (vec_shift_imm (VecShiftImmOp.Ushr) amt x size))
|
||||
|
||||
;; Helpers for generating `sshr` instructions.
|
||||
;; Shift each lane of the vector register right (arithmetic/signed) by the
;; constant amount given as the `u8` operand, emitting the immediate form
;; (`VecShiftImmOp.Sshr`). The amount must be non-zero: per the note in the
;; `sshr` lowering, a 0-width right shift can't be emitted, so callers
;; special-case 0 and pass the input through unchanged.
(decl sshr_vec_imm (Reg u8 VectorSize) Reg)
|
||||
(rule (sshr_vec_imm x amt size) (vec_shift_imm (VecShiftImmOp.Sshr) amt x size))
|
||||
|
||||
;; Helpers for generating `rotr` instructions.
|
||||
|
||||
(decl a64_rotr (Type Reg Reg) Reg)
|
||||
@@ -3321,7 +3333,7 @@
|
||||
dst))
|
||||
(rule (fcopy_sign x y ty @ (multi_lane _ _))
|
||||
(let ((dst WritableReg (temp_writable_reg $I8X16))
|
||||
(tmp Reg (vec_shift_imm (VecShiftImmOp.Ushr) (max_shift (lane_type ty)) y (vector_size ty)))
|
||||
(tmp Reg (ushr_vec_imm y (max_shift (lane_type ty)) (vector_size ty)))
|
||||
(_ Unit (emit (MInst.VecShiftImmMod (VecShiftImmModOp.Sli) dst x tmp (vector_size ty) (max_shift (lane_type ty))))))
|
||||
dst))
|
||||
|
||||
|
||||
@@ -352,10 +352,8 @@
|
||||
(let ((one Reg (splat_const 1 (VectorSize.Size64x2)))
|
||||
(c Reg (orr_vec x y (VectorSize.Size64x2)))
|
||||
(c Reg (and_vec c one (VectorSize.Size64x2)))
|
||||
(x Reg (vec_shift_imm (VecShiftImmOp.Ushr) 1 x
|
||||
(VectorSize.Size64x2)))
|
||||
(y Reg (vec_shift_imm (VecShiftImmOp.Ushr) 1 y
|
||||
(VectorSize.Size64x2)))
|
||||
(x Reg (ushr_vec_imm x 1 (VectorSize.Size64x2)))
|
||||
(y Reg (ushr_vec_imm y 1 (VectorSize.Size64x2)))
|
||||
(sum Reg (add_vec x y (VectorSize.Size64x2))))
|
||||
(add_vec c sum (VectorSize.Size64x2))))
|
||||
|
||||
@@ -1291,11 +1289,16 @@
|
||||
(csel (Cond.Ne) lo_lshift maybe_hi)))))
|
||||
|
||||
;; Shift for vector types.
|
||||
(rule -2 (lower (has_type (ty_vec128 ty) (ishl x y)))
|
||||
(rule -3 (lower (has_type (ty_vec128 ty) (ishl x y)))
|
||||
(let ((size VectorSize (vector_size ty))
|
||||
(masked_shift_amt Reg (and_imm $I32 y (shift_mask ty)))
|
||||
(shift Reg (vec_dup masked_shift_amt size)))
|
||||
(sshl x shift size)))
|
||||
(rule -2 (lower (has_type (ty_vec128 ty) (ishl x (iconst (u64_from_imm64 n)))))
|
||||
(ushl_vec_imm x (shift_masked_imm ty n) (vector_size ty)))
|
||||
|
||||
;; Mask the constant shift amount to the lane bit-width of the given type,
;; yielding the effective in-range shift (the Rust implementation computes
;; `(imm as u8) & ((ty.lane_bits() - 1) as u8)`, so e.g. a shift of 65 on a
;; 64-bit lane becomes 1).
(decl pure shift_masked_imm (Type u64) u8)
|
||||
(extern constructor shift_masked_imm shift_masked_imm)
|
||||
|
||||
;; Helper function to emit a shift operation with the opcode specified and
|
||||
;; the output type specified. The `Reg` provided is shifted by the `Value`
|
||||
@@ -1351,11 +1354,20 @@
|
||||
(lower_ushr128 x (value_regs_get y 0)))
|
||||
|
||||
;; Vector shifts.
|
||||
(rule -2 (lower (has_type (ty_vec128 ty) (ushr x y)))
|
||||
;;
|
||||
;; Note that for constant shifts a 0-width shift can't be emitted so it's
|
||||
;; special cased to pass through the input as-is since a 0-shift doesn't modify
|
||||
;; the input anyway.
|
||||
(rule -4 (lower (has_type (ty_vec128 ty) (ushr x y)))
|
||||
(let ((size VectorSize (vector_size ty))
|
||||
(masked_shift_amt Reg (and_imm $I32 y (shift_mask ty)))
|
||||
(shift Reg (vec_dup (sub $I64 (zero_reg) masked_shift_amt) size)))
|
||||
(ushl x shift size)))
|
||||
(rule -3 (lower (has_type (ty_vec128 ty) (ushr x (iconst (u64_from_imm64 n)))))
|
||||
(ushr_vec_imm x (shift_masked_imm ty n) (vector_size ty)))
|
||||
(rule -2 (lower (has_type (ty_vec128 ty) (ushr x (iconst (u64_from_imm64 n)))))
|
||||
(if-let 0 (shift_masked_imm ty n))
|
||||
x)
|
||||
|
||||
;; lsr lo_rshift, src_lo, amt
|
||||
;; lsr hi_rshift, src_hi, amt
|
||||
@@ -1387,7 +1399,7 @@
|
||||
;;;; Rules for `sshr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
|
||||
|
||||
;; Shift for i8/i16/i32.
|
||||
(rule -2 (lower (has_type (fits_in_32 ty) (sshr x y)))
|
||||
(rule -4 (lower (has_type (fits_in_32 ty) (sshr x y)))
|
||||
(do_shift (ALUOp.Asr) ty (put_in_reg_sext32 x) y))
|
||||
|
||||
;; Shift for i64.
|
||||
@@ -1400,12 +1412,20 @@
|
||||
|
||||
;; Vector shifts.
|
||||
;;
|
||||
;; Note that right shifts are implemented with a negative left shift.
|
||||
(rule -1 (lower (has_type (ty_vec128 ty) (sshr x y)))
|
||||
;; Note that right shifts are implemented with a negative left shift. Also note
|
||||
;; that for constant shifts a 0-width shift can't be emitted so it's special
|
||||
;; cased to pass through the input as-is since a 0-shift doesn't modify the
|
||||
;; input anyway.
|
||||
(rule -3 (lower (has_type (ty_vec128 ty) (sshr x y)))
|
||||
(let ((size VectorSize (vector_size ty))
|
||||
(masked_shift_amt Reg (and_imm $I32 y (shift_mask ty)))
|
||||
(shift Reg (vec_dup (sub $I64 (zero_reg) masked_shift_amt) size)))
|
||||
(sshl x shift size)))
|
||||
(rule -2 (lower (has_type (ty_vec128 ty) (sshr x (iconst (u64_from_imm64 n)))))
|
||||
(sshr_vec_imm x (shift_masked_imm ty n) (vector_size ty)))
|
||||
(rule -1 (lower (has_type (ty_vec128 ty) (sshr x (iconst (u64_from_imm64 n)))))
|
||||
(if-let 0 (shift_masked_imm ty n))
|
||||
x)
|
||||
|
||||
;; lsr lo_rshift, src_lo, amt
|
||||
;; asr hi_rshift, src_hi, amt
|
||||
@@ -2452,7 +2472,7 @@
|
||||
(let (
|
||||
;; Replicate the MSB of each of the 16 byte lanes across
|
||||
;; the whole lane (sshr is an arithmetic right shift).
|
||||
(shifted Reg (vec_shift_imm (VecShiftImmOp.Sshr) 7 vec (VectorSize.Size8x16)))
|
||||
(shifted Reg (sshr_vec_imm vec 7 (VectorSize.Size8x16)))
|
||||
;; Bitwise-and with a mask
|
||||
;; `0x80402010_08040201_80402010_08040201` to get the bit
|
||||
;; in the proper location for each group of 8 lanes.
|
||||
@@ -2476,7 +2496,7 @@
|
||||
(let (
|
||||
;; Replicate the MSB of each of the 8 16-bit lanes across
|
||||
;; the whole lane (sshr is an arithmetic right shift).
|
||||
(shifted Reg (vec_shift_imm (VecShiftImmOp.Sshr) 15 vec (VectorSize.Size16x8)))
|
||||
(shifted Reg (sshr_vec_imm vec 15 (VectorSize.Size16x8)))
|
||||
;; Bitwise-and with a mask
|
||||
;; `0x0080_0040_0020_0010_0008_0004_0002_0001` to get the
|
||||
;; bit in the proper location for each group of 4 lanes.
|
||||
@@ -2489,7 +2509,7 @@
|
||||
(let (
|
||||
;; Replicate the MSB of each of the 4 32-bit lanes across
|
||||
;; the whole lane (sshr is an arithmetic right shift).
|
||||
(shifted Reg (vec_shift_imm (VecShiftImmOp.Sshr) 31 vec (VectorSize.Size32x4)))
|
||||
(shifted Reg (sshr_vec_imm vec 31 (VectorSize.Size32x4)))
|
||||
;; Bitwise-and with a mask
|
||||
;; `0x00000008_00000004_00000002_00000001` to get the bit
|
||||
;; in the proper location for each group of 4 lanes.
|
||||
|
||||
@@ -806,4 +806,8 @@ impl Context for IsleContext<'_, '_, MInst, AArch64Backend> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Reduce a constant shift amount to the lane width of `ty`.
///
/// The amount is masked with `lane_bits - 1`, so (for the power-of-two
/// lane widths used here) a shift of 65 on a 64-bit lane yields 1 --
/// mirroring the masking the register-based vector-shift lowerings
/// apply via `shift_mask`.
fn shift_masked_imm(&mut self, ty: Type, imm: u64) -> u8 {
    let lane_mask = (ty.lane_bits() - 1) as u8;
    (imm as u8) & lane_mask
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user