aarch64: Add more lowerings for the CLIF fma (#6150)

This commit adds new lowerings to the AArch64 backend for the
element-based `fmla` and `fmls` instructions. These instructions take
one of their multiplicands as an implicit broadcast of a single lane of
another register, which can help remove `shuffle` or `dup` instructions
that would otherwise be needed to perform that broadcast.
Author: Alex Crichton
Date: 2023-04-05 12:22:55 -05:00 (committed by GitHub)
Parent: bf741955f0
Commit: 967543eb43
8 changed files with 321 additions and 15 deletions
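As a rough illustration of the element-based form described in the commit message, here is a semantic sketch in plain Rust (hypothetical helper name, scalar code rather than real SIMD, and it ignores that the hardware op is fused with no intermediate rounding):

    // Sketch of `fmla vd.4s, vn.4s, vm.s[idx]`: every lane of the accumulator
    // picks up `n[i] * m[idx]`, so no separate `dup`/`shuffle` is needed to
    // broadcast lane `idx` of `m` first.
    fn fmla_by_element(d: &mut [f32; 4], n: [f32; 4], m: [f32; 4], idx: usize) {
        for i in 0..4 {
            d[i] += n[i] * m[idx];
        }
    }

    fn main() {
        let mut d = [1.0, 2.0, 3.0, 4.0];
        fmla_by_element(&mut d, [1.0; 4], [10.0, 20.0, 30.0, 40.0], 2);
        assert_eq!(d, [31.0, 32.0, 33.0, 34.0]);
    }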


@@ -651,6 +651,16 @@
(rm Reg)
(size VectorSize))
;; A vector ALU op modifying a source register.
(VecFmlaElem
(alu_op VecALUModOp)
(rd WritableReg)
(ri Reg)
(rn Reg)
(rm Reg)
(size VectorSize)
(idx u8))
;; Vector two register miscellaneous instruction.
(VecMisc
(op VecMisc2)
@@ -1850,7 +1860,7 @@
(_ Unit (emit (MInst.FpuRR op size dst src))))
dst))
;; Helper for emitting `MInst.VecRRR` instructions which use three registers,
;; Helper for emitting `MInst.VecRRRMod` instructions which use three registers,
;; one of which is both source and output.
(decl vec_rrr_mod (VecALUModOp Reg Reg Reg VectorSize) Reg)
(rule (vec_rrr_mod op src1 src2 src3 size)
@@ -1858,6 +1868,14 @@
(_1 Unit (emit (MInst.VecRRRMod op dst src1 src2 src3 size))))
dst))
;; Helper for emitting `MInst.VecFmlaElem` instructions which use three registers,
;; one of which is both source and output.
(decl vec_fmla_elem (VecALUModOp Reg Reg Reg VectorSize u8) Reg)
(rule (vec_fmla_elem op src1 src2 src3 size idx)
(let ((dst WritableReg (temp_writable_reg $I8X16))
(_1 Unit (emit (MInst.VecFmlaElem op dst src1 src2 src3 size idx))))
dst))
(decl fpu_rri (FPUOpRI Reg) Reg)
(rule (fpu_rri op src)
(let ((dst WritableReg (temp_writable_reg $F64))


@@ -2914,6 +2914,45 @@ impl MachInstEmit for Inst {
};
sink.put4(enc_vec_rrr(top11 | q << 9, rm, bit15_10, rn, rd));
}
&Inst::VecFmlaElem {
rd,
ri,
rn,
rm,
alu_op,
size,
idx,
} => {
let rd = allocs.next_writable(rd);
let ri = allocs.next(ri);
debug_assert_eq!(rd.to_reg(), ri);
let rn = allocs.next(rn);
let rm = allocs.next(rm);
let idx = u32::from(idx);
let (q, _size) = size.enc_size();
let o2 = match alu_op {
VecALUModOp::Fmla => 0b0,
VecALUModOp::Fmls => 0b1,
_ => unreachable!(),
};
let (h, l) = match size {
VectorSize::Size32x4 => {
assert!(idx < 4);
(idx >> 1, idx & 1)
}
VectorSize::Size64x2 => {
assert!(idx < 2);
(idx, 0)
}
_ => unreachable!(),
};
let top11 = 0b000_011111_00 | (q << 9) | (size.enc_float_size() << 1) | l;
let bit15_10 = 0b000100 | (o2 << 4) | (h << 1);
sink.put4(enc_vec_rrr(top11, rm, bit15_10, rn, rd));
}
&Inst::VecLoadReplicate {
rd,
rn,

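The `VecFmlaElem` emission arm above packs the 32-bit lane index into an H bit and an L bit. A minimal sketch of that split, with a hypothetical helper name and plain Rust rather than the real emitter, showing that it round-trips:

    fn split_idx_32(idx: u32) -> (u32, u32) {
        // Mirrors the Size32x4 case above: H is the high bit, L the low bit.
        assert!(idx < 4);
        (idx >> 1, idx & 1)
    }

    fn main() {
        for idx in 0..4 {
            let (h, l) = split_idx_32(idx);
            // Concatenating H:L reconstructs the original lane index.
            assert_eq!((h << 1) | l, idx);
        }
    }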

@@ -812,7 +812,7 @@ fn aarch64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut Operan
collector.reg_use(rn);
collector.reg_use(rm);
}
&Inst::VecRRRMod { rd, ri, rn, rm, .. } => {
&Inst::VecRRRMod { rd, ri, rn, rm, .. } | &Inst::VecFmlaElem { rd, ri, rn, rm, .. } => {
collector.reg_reuse_def(rd, 1); // `rd` == `ri`.
collector.reg_use(ri);
collector.reg_use(rn);
@@ -2171,6 +2171,26 @@ impl Inst {
let rm = pretty_print_vreg_vector(rm, size, allocs);
format!("{} {}, {}, {}, {}", op, rd, ri, rn, rm)
}
&Inst::VecFmlaElem {
rd,
ri,
rn,
rm,
alu_op,
size,
idx,
} => {
let (op, size) = match alu_op {
VecALUModOp::Fmla => ("fmla", size),
VecALUModOp::Fmls => ("fmls", size),
_ => unreachable!(),
};
let rd = pretty_print_vreg_vector(rd.to_reg(), size, allocs);
let ri = pretty_print_vreg_vector(ri, size, allocs);
let rn = pretty_print_vreg_vector(rn, size, allocs);
let rm = pretty_print_vreg_element(rm, idx.into(), size.lane_size(), allocs);
format!("{} {}, {}, {}, {}", op, rd, ri, rn, rm)
}
&Inst::VecRRRLong {
rd,
rn,

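For reference, the `VecFmlaElem` pretty-printing arm above lists the accumulator twice (once as `rd`, once as `ri`) and prints the second multiplicand as a single lane. A small sketch with made-up register names that mirrors the `format!` call:

    fn main() {
        let (op, rd, ri, rn, rm) = ("fmla", "v0.4s", "v0.4s", "v1.4s", "v2.s[1]");
        // Same shape as the `format!` above; `ri` repeats the destination
        // because the instruction both reads and writes it.
        println!("{} {}, {}, {}, {}", op, rd, ri, rn, rm);
    }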

@@ -513,17 +513,62 @@
;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty @ (multi_lane _ _) (fma x y z)))
(vec_rrr_mod (VecALUModOp.Fmla) z x y (vector_size ty)))
(rule (lower (has_type (ty_scalar_float ty) (fma x y z)))
(fpu_rrrr (FPUOp3.MAdd) (scalar_size ty) x y z))
(rule 1 (lower (has_type ty @ (multi_lane _ _) (fma (fneg x) y z)))
(vec_rrr_mod (VecALUModOp.Fmls) z x y (vector_size ty)))
;; Delegate vector-based lowerings to helpers below
(rule 1 (lower (has_type ty @ (multi_lane _ _) (fma x y z)))
(lower_fmla (VecALUModOp.Fmla) x y z (vector_size ty)))
(rule 2 (lower (has_type ty @ (multi_lane _ _) (fma x (fneg y) z)))
(vec_rrr_mod (VecALUModOp.Fmls) z x y (vector_size ty)))
;; Lowers a fused-multiply-add operation handling various forms of the
;; instruction to get maximal coverage of what's available on AArch64.
(decl lower_fmla (VecALUModOp Value Value Value VectorSize) Reg)
(rule 3 (lower (has_type (ty_scalar_float ty) (fma x y z)))
(fpu_rrrr (FPUOp3.MAdd) (scalar_size ty) x y z))
;; Base case, emit the op requested.
(rule (lower_fmla op x y z size)
(vec_rrr_mod op z x y size))
;; Special case: if one of the multiplicands is a splat then the element-based
;; fma can be used instead with 0 as the element index.
(rule 1 (lower_fmla op (splat x) y z size)
(vec_fmla_elem op z y x size 0))
(rule 2 (lower_fmla op x (splat y) z size)
(vec_fmla_elem op z x y size 0))
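A quick scalar check of why the splat rewrite is sound, in plain Rust with hypothetical function names (not the compiler's code): multiplying by a splatted value is the same as the by-element form reading lane 0 of a register holding that value.

    use core::array::from_fn;

    fn fma_splat(x: f32, y: [f32; 4], z: [f32; 4]) -> [f32; 4] {
        let splat = [x; 4];
        from_fn(|i| z[i] + splat[i] * y[i])
    }

    fn fma_by_element(x_reg: [f32; 4], y: [f32; 4], z: [f32; 4], idx: usize) -> [f32; 4] {
        from_fn(|i| z[i] + y[i] * x_reg[idx])
    }

    fn main() {
        let (y, z) = ([1.0, 2.0, 3.0, 4.0], [0.5; 4]);
        // Lane 0 of `x_reg` holds the splatted scalar; the other lanes are unused.
        assert_eq!(fma_splat(7.0, y, z), fma_by_element([7.0, 0.0, 0.0, 0.0], y, z, 0));
    }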
;; Special case: if one of the multiplicands is a shuffle that broadcasts a
;; single element of a vector, then the element-based fma can be used as in the
;; `splat` case above.
;;
;; Note that in Cranelift shuffle always has i8x16 inputs and outputs so
;; a `bitcast` is matched here explicitly since that's the main way a shuffle
;; output will be fed into this instruction.
(rule 3 (lower_fmla op (bitcast _ (shuffle x x (shuffle32_from_imm n n n n))) y z size @ (VectorSize.Size32x4))
(if-let $true (u64_lt n 4))
(vec_fmla_elem op z y x size n))
(rule 4 (lower_fmla op x (bitcast _ (shuffle y y (shuffle32_from_imm n n n n))) z size @ (VectorSize.Size32x4))
(if-let $true (u64_lt n 4))
(vec_fmla_elem op z x y size n))
(rule 3 (lower_fmla op (bitcast _ (shuffle x x (shuffle64_from_imm n n))) y z size @ (VectorSize.Size64x2))
(if-let $true (u64_lt n 2))
(vec_fmla_elem op z y x size n))
(rule 4 (lower_fmla op x (bitcast _ (shuffle y y (shuffle64_from_imm n n))) z size @ (VectorSize.Size64x2))
(if-let $true (u64_lt n 2))
(vec_fmla_elem op z x y size n))
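For context on the `shuffle32_from_imm n n n n` pattern matched above: since CLIF `shuffle` works on i8x16, broadcasting 32-bit lane `n` corresponds to a byte mask that repeats bytes `4n..4n+4`. A small sketch (hypothetical helper, not the actual extractor):

    use core::array::from_fn;

    // Byte-level shuffle mask that broadcasts 32-bit lane `n`, i.e. the shape
    // recognized by `shuffle32_from_imm n n n n` in the rules above.
    fn broadcast32_mask(n: u8) -> [u8; 16] {
        assert!(n < 4);
        from_fn(|i| n * 4 + (i % 4) as u8)
    }

    fn main() {
        // Broadcasting lane 2 selects bytes 8..=11 into every 32-bit lane.
        assert_eq!(
            broadcast32_mask(2),
            [8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11]
        );
    }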
;; Special case: if one of the multiplicands is `fneg` then peel that away,
;; reverse the operation being performed, and then recurse on `lower_fmla`
;; again to generate the actual instruction.
;;
;; Note that these are the highest priority cases for `lower_fmla` to peel
;; away as many `fneg` operations as possible.
(rule 5 (lower_fmla op (fneg x) y z size)
(lower_fmla (neg_fmla op) x y z size))
(rule 6 (lower_fmla op x (fneg y) z size)
(lower_fmla (neg_fmla op) x y z size))
(decl neg_fmla (VecALUModOp) VecALUModOp)
(rule (neg_fmla (VecALUModOp.Fmla)) (VecALUModOp.Fmls))
(rule (neg_fmla (VecALUModOp.Fmls)) (VecALUModOp.Fmla))
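A scalar sketch (plain Rust, with values chosen so the float comparisons are exact) of why peeling `fneg` and flipping the opcode is sound: negating either multiplicand turns the accumulate into a subtract, and peeling two negations flips the opcode back:

    fn fmla(z: f32, x: f32, y: f32) -> f32 { z + x * y }
    fn fmls(z: f32, x: f32, y: f32) -> f32 { z - x * y }

    fn main() {
        let (x, y, z) = (1.5f32, 2.0, 4.0);
        // fma(-x, y, z) and fma(x, -y, z) are both the `fmls` form.
        assert_eq!(fmla(z, -x, y), fmls(z, x, y));
        assert_eq!(fmla(z, x, -y), fmls(z, x, y));
        // Peeling two `fneg`s flips Fmla -> Fmls -> Fmla again.
        assert_eq!(fmla(z, -x, -y), fmla(z, x, y));
    }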
;;;; Rules for `fcopysign` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;


@@ -708,8 +708,6 @@
(decl u8_as_i32 (u8) i32)
(extern constructor u8_as_i32 u8_as_i32)
(convert u8 u64 u8_as_u64)
(decl convert_valueregs_reg (ValueRegs) Reg)
(rule (convert_valueregs_reg x)
(value_regs_get x 0))
@@ -1283,7 +1281,7 @@
(rule
(load_imm12 x)
(rv_addi (zero_reg) (imm12_const x)))
;; for load immediate
(decl imm_from_bits (u64) Imm12)
(extern constructor imm_from_bits imm_from_bits)
@@ -1509,7 +1507,7 @@
(_ Unit (emit (MInst.Cltz leading sum step tmp rs ty))))
sum))
;; Extends an integer if it is smaller than 64 bits.
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
;;; For values smaller than 64 bits, we need to extend them to 64 bits
@@ -2117,7 +2115,7 @@
(result VecWritableReg (vec_writable_clone dst))
(_ Unit (emit (MInst.Select dst ty c x y))))
(vec_writable_to_regs result)))
;; Parameters are "intcc compare_a compare_b rs1 rs2".
(decl gen_select_reg (IntCC Reg Reg Reg Reg) Reg)
(extern constructor gen_select_reg gen_select_reg)


@@ -82,6 +82,7 @@
(decl pure u8_as_u64 (u8) u64)
(extern constructor u8_as_u64 u8_as_u64)
(convert u8 u64 u8_as_u64)
(decl pure u16_as_u64 (u16) u64)
(extern constructor u16_as_u64 u16_as_u64)