x64: Implement SIMD fma (#4474)
* x64: Add VEX Instruction Encoder This uses a similar builder pattern to the EVEX Encoder. Does not yet support memory accesses. * x64: Add FMA Flag * x64: Implement SIMD `fma` * x64: Use 4 register Vex Inst * x64: Reorder VEX pretty print args
This commit is contained in:
@@ -794,6 +794,7 @@ pub(crate) enum InstructionSet {
|
||||
BMI1,
|
||||
#[allow(dead_code)] // never constructed (yet).
|
||||
BMI2,
|
||||
FMA,
|
||||
AVX512BITALG,
|
||||
AVX512DQ,
|
||||
AVX512F,
|
||||
@@ -1386,6 +1387,38 @@ impl fmt::Display for SseOpcode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum AvxOpcode {
|
||||
Vfmadd213ps,
|
||||
Vfmadd213pd,
|
||||
}
|
||||
|
||||
impl AvxOpcode {
|
||||
/// Which `InstructionSet`s support the opcode?
|
||||
pub(crate) fn available_from(&self) -> SmallVec<[InstructionSet; 2]> {
|
||||
match self {
|
||||
AvxOpcode::Vfmadd213ps => smallvec![InstructionSet::FMA],
|
||||
AvxOpcode::Vfmadd213pd => smallvec![InstructionSet::FMA],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for AvxOpcode {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
let name = match self {
|
||||
AvxOpcode::Vfmadd213ps => "vfmadd213ps",
|
||||
AvxOpcode::Vfmadd213pd => "vfmadd213pd",
|
||||
};
|
||||
write!(fmt, "{}", name)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AvxOpcode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
fmt::Debug::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum Avx512Opcode {
|
||||
Vcvtudq2ps,
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::isa::x64::encoding::rex::{
|
||||
low8_will_sign_extend_to_32, low8_will_sign_extend_to_64, reg_enc, LegacyPrefixes, OpcodeMap,
|
||||
RexFlags,
|
||||
};
|
||||
use crate::isa::x64::encoding::vex::{VexInstruction, VexVectorLength};
|
||||
use crate::isa::x64::inst::args::*;
|
||||
use crate::isa::x64::inst::*;
|
||||
use crate::machinst::{inst_common, MachBuffer, MachInstEmit, MachLabel, Reg, Writable};
|
||||
@@ -119,6 +120,7 @@ pub(crate) fn emit(
|
||||
InstructionSet::Lzcnt => info.isa_flags.use_lzcnt(),
|
||||
InstructionSet::BMI1 => info.isa_flags.use_bmi1(),
|
||||
InstructionSet::BMI2 => info.isa_flags.has_bmi2(),
|
||||
InstructionSet::FMA => info.isa_flags.has_fma(),
|
||||
InstructionSet::AVX512BITALG => info.isa_flags.has_avx512bitalg(),
|
||||
InstructionSet::AVX512DQ => info.isa_flags.has_avx512dq(),
|
||||
InstructionSet::AVX512F => info.isa_flags.has_avx512f(),
|
||||
@@ -1689,6 +1691,39 @@ pub(crate) fn emit(
|
||||
}
|
||||
}
|
||||
|
||||
Inst::XmmRmRVex {
|
||||
op,
|
||||
src1,
|
||||
src2,
|
||||
src3,
|
||||
dst,
|
||||
} => {
|
||||
let src1 = allocs.next(src1.to_reg());
|
||||
let dst = allocs.next(dst.to_reg().to_reg());
|
||||
debug_assert_eq!(src1, dst);
|
||||
let src2 = allocs.next(src2.to_reg());
|
||||
let src3 = src3.clone().to_reg_mem().with_allocs(allocs);
|
||||
|
||||
let (w, opcode) = match op {
|
||||
AvxOpcode::Vfmadd213ps => (false, 0xA8),
|
||||
AvxOpcode::Vfmadd213pd => (true, 0xA8),
|
||||
};
|
||||
|
||||
match src3 {
|
||||
RegMem::Reg { reg: src } => VexInstruction::new()
|
||||
.length(VexVectorLength::V128)
|
||||
.prefix(LegacyPrefixes::_66)
|
||||
.map(OpcodeMap::_0F38)
|
||||
.w(w)
|
||||
.opcode(opcode)
|
||||
.reg(dst.to_real_reg().unwrap().hw_enc())
|
||||
.rm(src.to_real_reg().unwrap().hw_enc())
|
||||
.vvvv(src2.to_real_reg().unwrap().hw_enc())
|
||||
.encode(sink),
|
||||
_ => todo!(),
|
||||
};
|
||||
}
|
||||
|
||||
Inst::XmmRmREvex {
|
||||
op,
|
||||
src1,
|
||||
|
||||
@@ -3701,6 +3701,21 @@ fn test_x64_emit() {
|
||||
"jmp *321(%r10,%rdx,4)",
|
||||
));
|
||||
|
||||
// ========================================================
|
||||
// XMM FMA
|
||||
|
||||
insns.push((
|
||||
Inst::xmm_rm_r_vex(AvxOpcode::Vfmadd213ps, RegMem::reg(xmm2), xmm1, w_xmm0),
|
||||
"C4E271A8C2",
|
||||
"vfmadd213ps %xmm0, %xmm1, %xmm2, %xmm0",
|
||||
));
|
||||
|
||||
insns.push((
|
||||
Inst::xmm_rm_r_vex(AvxOpcode::Vfmadd213pd, RegMem::reg(xmm5), xmm4, w_xmm3),
|
||||
"C4E2D9A8DD",
|
||||
"vfmadd213pd %xmm3, %xmm4, %xmm5, %xmm3",
|
||||
));
|
||||
|
||||
// ========================================================
|
||||
// XMM_CMP_RM_R
|
||||
|
||||
@@ -4866,6 +4881,7 @@ fn test_x64_emit() {
|
||||
let mut isa_flag_builder = x64::settings::builder();
|
||||
isa_flag_builder.enable("has_ssse3").unwrap();
|
||||
isa_flag_builder.enable("has_sse41").unwrap();
|
||||
isa_flag_builder.enable("has_fma").unwrap();
|
||||
isa_flag_builder.enable("has_avx512bitalg").unwrap();
|
||||
isa_flag_builder.enable("has_avx512dq").unwrap();
|
||||
isa_flag_builder.enable("has_avx512f").unwrap();
|
||||
|
||||
@@ -129,6 +129,8 @@ impl Inst {
|
||||
| Inst::XmmUnaryRmR { op, .. } => smallvec![op.available_from()],
|
||||
|
||||
Inst::XmmUnaryRmREvex { op, .. } | Inst::XmmRmREvex { op, .. } => op.available_from(),
|
||||
|
||||
Inst::XmmRmRVex { op, .. } => op.available_from(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -324,6 +326,20 @@ impl Inst {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn xmm_rm_r_vex(op: AvxOpcode, src3: RegMem, src2: Reg, dst: Writable<Reg>) -> Self {
|
||||
src3.assert_regclass_is(RegClass::Float);
|
||||
debug_assert!(src2.class() == RegClass::Float);
|
||||
debug_assert!(dst.to_reg().class() == RegClass::Float);
|
||||
Inst::XmmRmRVex {
|
||||
op,
|
||||
src3: XmmMem::new(src3).unwrap(),
|
||||
src2: Xmm::new(src2).unwrap(),
|
||||
src1: Xmm::new(dst.to_reg()).unwrap(),
|
||||
dst: WritableXmm::from_writable_reg(dst).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn xmm_rm_r_evex(
|
||||
op: Avx512Opcode,
|
||||
src1: RegMem,
|
||||
@@ -1136,6 +1152,29 @@ impl PrettyPrint for Inst {
|
||||
format!("{} {}, {}, {}", ljustify(op.to_string()), src1, src2, dst)
|
||||
}
|
||||
|
||||
Inst::XmmRmRVex {
|
||||
op,
|
||||
src1,
|
||||
src2,
|
||||
src3,
|
||||
dst,
|
||||
..
|
||||
} => {
|
||||
let src1 = pretty_print_reg(src1.to_reg(), 8, allocs);
|
||||
let dst = pretty_print_reg(dst.to_reg().to_reg(), 8, allocs);
|
||||
let src2 = pretty_print_reg(src2.to_reg(), 8, allocs);
|
||||
let src3 = src3.pretty_print(8, allocs);
|
||||
|
||||
format!(
|
||||
"{} {}, {}, {}, {}",
|
||||
ljustify(op.to_string()),
|
||||
src1,
|
||||
src2,
|
||||
src3,
|
||||
dst
|
||||
)
|
||||
}
|
||||
|
||||
Inst::XmmRmREvex {
|
||||
op,
|
||||
src1,
|
||||
@@ -1840,6 +1879,24 @@ fn x64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandCol
|
||||
}
|
||||
}
|
||||
}
|
||||
Inst::XmmRmRVex {
|
||||
op,
|
||||
src1,
|
||||
src2,
|
||||
src3,
|
||||
dst,
|
||||
..
|
||||
} => {
|
||||
// Vfmadd uses and defs the dst reg, that is not the case with all
|
||||
// AVX's ops, if you're adding a new op, make sure to correctly define
|
||||
// register uses.
|
||||
assert!(*op == AvxOpcode::Vfmadd213ps || *op == AvxOpcode::Vfmadd213pd);
|
||||
|
||||
collector.reg_use(src1.to_reg());
|
||||
collector.reg_reuse_def(dst.to_writable_reg(), 0);
|
||||
collector.reg_use(src2.to_reg());
|
||||
src3.get_operands(collector);
|
||||
}
|
||||
Inst::XmmRmREvex {
|
||||
op,
|
||||
src1,
|
||||
|
||||
Reference in New Issue
Block a user