x64: Add native lowering for scalar fma (#4539)

Use `vfmadd213{ss,sd}` for these lowerings.
This commit is contained in:
Afonso Bordado
2022-08-11 23:48:16 +01:00
committed by GitHub
parent 755cd4311e
commit 3ea1813173
10 changed files with 124 additions and 6 deletions

View File

@@ -1095,7 +1095,9 @@
(extern extractor cc_nz_or_z cc_nz_or_z)
(type AvxOpcode extern
(enum Vfmadd213ps
(enum Vfmadd213ss
Vfmadd213sd
Vfmadd213ps
Vfmadd213pd))
(type Avx512Opcode extern
@@ -1389,6 +1391,9 @@
(decl use_popcnt () Type)
(extern extractor use_popcnt use_popcnt)
(decl use_fma () Type)
(extern extractor use_fma use_fma)
;;;; Helpers for Merging and Sinking Immediates/Loads ;;;;;;;;;;;;;;;;;;;;;;;;;
;; Extract a constant `Imm8Reg.Imm8` from a value operand.
@@ -2935,6 +2940,16 @@
dst))))
dst))
;; Helper for creating `vfmadd213ss` instructions.
;;
;; Emits a scalar f32 fused multiply-add via `xmm_rmr_vex`. This helper is
;; used by the `fma` lowering, so the computed value is `x * y + z`; `z` is
;; an `XmmMem`, allowing a memory operand to be sunk into the instruction.
(decl x64_vfmadd213ss (Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd213ss x y z)
(xmm_rmr_vex (AvxOpcode.Vfmadd213ss) x y z))
;; Helper for creating `vfmadd213sd` instructions.
;;
;; Emits a scalar f64 fused multiply-add via `xmm_rmr_vex`. This helper is
;; used by the `fma` lowering, so the computed value is `x * y + z`; `z` is
;; an `XmmMem`, allowing a memory operand to be sunk into the instruction.
(decl x64_vfmadd213sd (Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd213sd x y z)
(xmm_rmr_vex (AvxOpcode.Vfmadd213sd) x y z))
;; Helper for creating `vfmadd213ps` instructions.
(decl x64_vfmadd213ps (Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd213ps x y z)

View File

@@ -1383,6 +1383,8 @@ impl fmt::Display for SseOpcode {
/// Opcodes emitted with a VEX (AVX) encoding.
///
/// NOTE(review): currently this holds only the `vfmadd213*` FMA family; the
/// ISA extension each opcode requires is reported by `available_from`.
#[derive(Clone, PartialEq)]
pub enum AvxOpcode {
/// Scalar f32 fused multiply-add (213 operand ordering).
Vfmadd213ss,
/// Scalar f64 fused multiply-add (213 operand ordering).
Vfmadd213sd,
/// Packed f32x4 fused multiply-add (213 operand ordering).
Vfmadd213ps,
/// Packed f64x2 fused multiply-add (213 operand ordering).
Vfmadd213pd,
}
@@ -1391,8 +1393,10 @@ impl AvxOpcode {
/// Which `InstructionSet`s support the opcode?
pub(crate) fn available_from(&self) -> SmallVec<[InstructionSet; 2]> {
match self {
AvxOpcode::Vfmadd213ps => smallvec![InstructionSet::FMA],
AvxOpcode::Vfmadd213pd => smallvec![InstructionSet::FMA],
AvxOpcode::Vfmadd213ss
| AvxOpcode::Vfmadd213sd
| AvxOpcode::Vfmadd213ps
| AvxOpcode::Vfmadd213pd => smallvec![InstructionSet::FMA],
}
}
}
@@ -1400,6 +1404,8 @@ impl AvxOpcode {
impl fmt::Debug for AvxOpcode {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
let name = match self {
AvxOpcode::Vfmadd213ss => "vfmadd213ss",
AvxOpcode::Vfmadd213sd => "vfmadd213sd",
AvxOpcode::Vfmadd213ps => "vfmadd213ps",
AvxOpcode::Vfmadd213pd => "vfmadd213pd",
};

View File

@@ -1742,6 +1742,8 @@ pub(crate) fn emit(
let src3 = src3.clone().to_reg_mem().with_allocs(allocs);
let (w, opcode) = match op {
AvxOpcode::Vfmadd213ss => (false, 0xA9),
AvxOpcode::Vfmadd213sd => (true, 0xA9),
AvxOpcode::Vfmadd213ps => (false, 0xA8),
AvxOpcode::Vfmadd213pd => (true, 0xA8),
};

View File

@@ -3531,6 +3531,18 @@ fn test_x64_emit() {
// ========================================================
// XMM FMA
insns.push((
Inst::xmm_rm_r_vex(AvxOpcode::Vfmadd213ss, RegMem::reg(xmm2), xmm1, w_xmm0),
"C4E271A9C2",
"vfmadd213ss %xmm0, %xmm1, %xmm2, %xmm0",
));
insns.push((
Inst::xmm_rm_r_vex(AvxOpcode::Vfmadd213sd, RegMem::reg(xmm5), xmm4, w_xmm3),
"C4E2D9A9DD",
"vfmadd213sd %xmm3, %xmm4, %xmm5, %xmm3",
));
insns.push((
Inst::xmm_rm_r_vex(AvxOpcode::Vfmadd213ps, RegMem::reg(xmm2), xmm1, w_xmm0),
"C4E271A8C2",

View File

@@ -1847,7 +1847,12 @@ fn x64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandCol
// Vfmadd uses and defs the dst reg, that is not the case with all
// AVX's ops, if you're adding a new op, make sure to correctly define
// register uses.
assert!(*op == AvxOpcode::Vfmadd213ps || *op == AvxOpcode::Vfmadd213pd);
assert!(
*op == AvxOpcode::Vfmadd213ss
|| *op == AvxOpcode::Vfmadd213sd
|| *op == AvxOpcode::Vfmadd213ps
|| *op == AvxOpcode::Vfmadd213pd
);
collector.reg_use(src1.to_reg());
collector.reg_reuse_def(dst.to_writable_reg(), 0);

View File

@@ -2504,9 +2504,13 @@
(libcall_3 (LibCall.FmaF32) x y z))
(rule (lower (has_type $F64 (fma x y z)))
(libcall_3 (LibCall.FmaF64) x y z))
(rule (lower (has_type $F32X4 (fma x y z)))
(rule 1 (lower (has_type (and (use_fma) $F32) (fma x y z)))
(x64_vfmadd213ss x y z))
(rule 1 (lower (has_type (and (use_fma) $F64) (fma x y z)))
(x64_vfmadd213sd x y z))
(rule (lower (has_type (and (use_fma) $F32X4) (fma x y z)))
(x64_vfmadd213ps x y z))
(rule (lower (has_type $F64X2 (fma x y z)))
(rule (lower (has_type (and (use_fma) $F64X2) (fma x y z)))
(x64_vfmadd213pd x y z))
;; Rules for `load*` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

View File

@@ -248,6 +248,15 @@ where
}
}
#[inline]
fn use_fma(&mut self, _: Type) -> Option<()> {
// ISLE extractor: succeeds (matches) exactly when the target's FMA
// feature flag is enabled; the `Type` argument is ignored.
self.isa_flags.use_fma().then(|| ())
}
#[inline]
fn imm8_from_value(&mut self, val: Value) -> Option<Imm8Reg> {
let inst = self.lower_ctx.dfg().value_def(val).inst()?;

View File

@@ -0,0 +1,33 @@
test compile precise-output
target x86_64 has_avx=false has_fma=false
function %fma_f32(f32, f32, f32) -> f32 {
block0(v0: f32, v1: f32, v2: f32):
v3 = fma v0, v1, v2
return v3
}
; pushq %rbp
; movq %rsp, %rbp
; block0:
; load_ext_name %FmaF32+0, %rax
; call *%rax
; movq %rbp, %rsp
; popq %rbp
; ret
function %fma_f64(f64, f64, f64) -> f64 {
block0(v0: f64, v1: f64, v2: f64):
v3 = fma v0, v1, v2
return v3
}
; pushq %rbp
; movq %rsp, %rbp
; block0:
; load_ext_name %FmaF64+0, %rax
; call *%rax
; movq %rbp, %rsp
; popq %rbp
; ret

View File

@@ -0,0 +1,31 @@
test compile precise-output
target x86_64 has_avx=true has_fma=true
function %fma_f32(f32, f32, f32) -> f32 {
block0(v0: f32, v1: f32, v2: f32):
v3 = fma v0, v1, v2
return v3
}
; pushq %rbp
; movq %rsp, %rbp
; block0:
; vfmadd213ss %xmm0, %xmm1, %xmm2, %xmm0
; movq %rbp, %rsp
; popq %rbp
; ret
function %fma_f64(f64, f64, f64) -> f64 {
block0(v0: f64, v1: f64, v2: f64):
v3 = fma v0, v1, v2
return v3
}
; pushq %rbp
; movq %rsp, %rbp
; block0:
; vfmadd213sd %xmm0, %xmm1, %xmm2, %xmm0
; movq %rbp, %rsp
; popq %rbp
; ret

View File

@@ -2,6 +2,7 @@ test interpret
test run
target aarch64
target s390x
target x86_64 has_avx has_fma
target x86_64 has_avx=false has_fma=false
function %fma_f32(f32, f32, f32) -> f32 {