cranelift: Implement scalar FMA on x86 (#4460)

x86 does not have dedicated instructions for scalar FMA, lower
to a libcall which seems to be what llvm does.
This commit is contained in:
Afonso Bordado
2022-08-03 18:29:10 +01:00
committed by GitHub
parent ff6082c0af
commit 709716bb8e
13 changed files with 167 additions and 50 deletions

View File

@@ -1,6 +1,7 @@
//! Naming well-known routines in the runtime library.
use crate::ir::{types, ExternalName, FuncRef, Function, Opcode, Type};
use crate::ir::{types, AbiParam, ExternalName, FuncRef, Function, Opcode, Signature, Type};
use crate::isa::CallConv;
use core::fmt;
use core::str::FromStr;
#[cfg(feature = "enable-serde")]
@@ -50,6 +51,10 @@ pub enum LibCall {
NearestF32,
/// nearest.f64
NearestF64,
/// fma.f32
FmaF32,
/// fma.f64
FmaF64,
/// libc.memcpy
Memcpy,
/// libc.memset
@@ -91,6 +96,8 @@ impl FromStr for LibCall {
"TruncF64" => Ok(Self::TruncF64),
"NearestF32" => Ok(Self::NearestF32),
"NearestF64" => Ok(Self::NearestF64),
"FmaF32" => Ok(Self::FmaF32),
"FmaF64" => Ok(Self::FmaF64),
"Memcpy" => Ok(Self::Memcpy),
"Memset" => Ok(Self::Memset),
"Memmove" => Ok(Self::Memmove),
@@ -124,6 +131,7 @@ impl LibCall {
Opcode::Floor => Self::FloorF32,
Opcode::Trunc => Self::TruncF32,
Opcode::Nearest => Self::NearestF32,
Opcode::Fma => Self::FmaF32,
_ => return None,
},
types::F64 => match opcode {
@@ -131,6 +139,7 @@ impl LibCall {
Opcode::Floor => Self::FloorF64,
Opcode::Trunc => Self::TruncF64,
Opcode::Nearest => Self::NearestF64,
Opcode::Fma => Self::FmaF64,
_ => return None,
},
_ => return None,
@@ -157,6 +166,8 @@ impl LibCall {
TruncF64,
NearestF32,
NearestF64,
FmaF32,
FmaF64,
Memcpy,
Memset,
Memmove,
@@ -164,6 +175,50 @@ impl LibCall {
ElfTlsGetAddr,
]
}
/// Get a [Signature] for the function targeted by this [LibCall].
pub fn signature(&self, call_conv: CallConv) -> Signature {
use types::*;
let mut sig = Signature::new(call_conv);
match self {
LibCall::UdivI64
| LibCall::SdivI64
| LibCall::UremI64
| LibCall::SremI64
| LibCall::IshlI64
| LibCall::UshrI64
| LibCall::SshrI64 => {
sig.params.push(AbiParam::new(I64));
sig.params.push(AbiParam::new(I64));
sig.returns.push(AbiParam::new(I64));
}
LibCall::CeilF32 | LibCall::FloorF32 | LibCall::TruncF32 | LibCall::NearestF32 => {
sig.params.push(AbiParam::new(F32));
sig.returns.push(AbiParam::new(F32));
}
LibCall::TruncF64 | LibCall::FloorF64 | LibCall::CeilF64 | LibCall::NearestF64 => {
sig.params.push(AbiParam::new(F64));
sig.returns.push(AbiParam::new(F64));
}
LibCall::FmaF32 | LibCall::FmaF64 => {
let ty = if *self == LibCall::FmaF32 { F32 } else { F64 };
sig.params.push(AbiParam::new(ty));
sig.params.push(AbiParam::new(ty));
sig.params.push(AbiParam::new(ty));
sig.returns.push(AbiParam::new(ty));
}
LibCall::Probestack
| LibCall::Memcpy
| LibCall::Memset
| LibCall::Memmove
| LibCall::Memcmp
| LibCall::ElfTlsGetAddr => unimplemented!(),
}
sig
}
}
/// Get a function reference for the probestack function in `func`.

View File

@@ -1551,7 +1551,7 @@ impl LowerBackend for AArch64Backend {
type MInst = Inst;
fn lower<C: LowerCtx<I = Inst>>(&self, ctx: &mut C, ir_inst: IRInst) -> CodegenResult<()> {
lower_inst::lower_insn_to_regs(ctx, ir_inst, &self.flags, &self.isa_flags)
lower_inst::lower_insn_to_regs(ctx, ir_inst, &self.triple, &self.flags, &self.isa_flags)
}
fn lower_branch_group<C: LowerCtx<I = Inst>>(

View File

@@ -30,6 +30,7 @@ use regalloc2::PReg;
use std::boxed::Box;
use std::convert::TryFrom;
use std::vec::Vec;
use target_lexicon::Triple;
type BoxCallInfo = Box<CallInfo>;
type BoxCallIndInfo = Box<CallIndInfo>;
@@ -40,6 +41,7 @@ type BoxExternalName = Box<ExternalName>;
/// The main entry point for lowering with ISLE.
pub(crate) fn lower<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
outputs: &[InsnOutput],
@@ -48,9 +50,15 @@ pub(crate) fn lower<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
generated_code::constructor_lower(cx, insn)
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
outputs,
inst,
|cx, insn| generated_code::constructor_lower(cx, insn),
)
}
pub struct ExtendedValue {

View File

@@ -16,11 +16,13 @@ use crate::{CodegenError, CodegenResult};
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::convert::TryFrom;
use target_lexicon::Triple;
/// Actually codegen an instruction's results into registers.
pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
ctx: &mut C,
insn: IRInst,
triple: &Triple,
flags: &Flags,
isa_flags: &aarch64_settings::Flags,
) -> CodegenResult<()> {
@@ -33,7 +35,7 @@ pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
None
};
if let Ok(()) = super::lower::isle::lower(ctx, flags, isa_flags, &outputs, insn) {
if let Ok(()) = super::lower::isle::lower(ctx, triple, flags, isa_flags, &outputs, insn) {
return Ok(());
}

View File

@@ -30,9 +30,14 @@ impl LowerBackend for S390xBackend {
None
};
if let Ok(()) =
super::lower::isle::lower(ctx, &self.flags, &self.isa_flags, &outputs, ir_inst)
{
if let Ok(()) = super::lower::isle::lower(
ctx,
&self.triple,
&self.flags,
&self.isa_flags,
&outputs,
ir_inst,
) {
return Ok(());
}
@@ -295,6 +300,7 @@ impl LowerBackend for S390xBackend {
// the second branch (if any) by emitting a two-way conditional branch.
if let Ok(()) = super::lower::isle::lower_branch(
ctx,
&self.triple,
&self.flags,
&self.isa_flags,
branches[0],

View File

@@ -26,6 +26,7 @@ use std::boxed::Box;
use std::cell::Cell;
use std::convert::TryFrom;
use std::vec::Vec;
use target_lexicon::Triple;
type BoxCallInfo = Box<CallInfo>;
type BoxCallIndInfo = Box<CallIndInfo>;
@@ -37,6 +38,7 @@ type VecMInstBuilder = Cell<Vec<MInst>>;
/// The main entry point for lowering with ISLE.
pub(crate) fn lower<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
outputs: &[InsnOutput],
@@ -45,14 +47,21 @@ pub(crate) fn lower<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
generated_code::constructor_lower(cx, insn)
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
outputs,
inst,
|cx, insn| generated_code::constructor_lower(cx, insn),
)
}
/// The main entry point for branch lowering with ISLE.
pub(crate) fn lower_branch<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
branch: Inst,
@@ -61,9 +70,15 @@ pub(crate) fn lower_branch<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, &[], branch, |cx, insn| {
generated_code::constructor_lower_branch(cx, insn, &targets.to_vec())
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
&[],
branch,
|cx, insn| generated_code::constructor_lower_branch(cx, insn, &targets.to_vec()),
)
}
impl<C> generated_code::Context for IsleContext<'_, C, Flags, IsaFlags, 6>

View File

@@ -3354,3 +3354,13 @@
(decl x64_rsp () Reg)
(rule (x64_rsp)
(mov_preg (preg_rsp)))
;;;; Helpers for Emitting LibCalls ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(type LibCall extern
(enum
FmaF32
FmaF64))
(decl libcall_3 (LibCall Reg Reg Reg) Reg)
(extern constructor libcall_3 libcall_3)

View File

@@ -2491,6 +2491,10 @@
;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type $F32 (fma x y z)))
(libcall_3 (LibCall.FmaF32) x y z))
(rule (lower (has_type $F64 (fma x y z)))
(libcall_3 (LibCall.FmaF64) x y z))
(rule (lower (has_type $F32X4 (fma x y z)))
(x64_vfmadd213ps x y z))
(rule (lower (has_type $F64X2 (fma x y z)))

View File

@@ -6,8 +6,7 @@ pub(super) mod isle;
use crate::data_value::DataValue;
use crate::ir::{
condcodes::{CondCode, FloatCC, IntCC},
types, AbiParam, ExternalName, Inst as IRInst, InstructionData, LibCall, Opcode, Signature,
Type,
types, ExternalName, Inst as IRInst, InstructionData, LibCall, Opcode, Type,
};
use crate::isa::x64::abi::*;
use crate::isa::x64::inst::args::*;
@@ -573,29 +572,13 @@ fn emit_fcmp<C: LowerCtx<I = Inst>>(
cond_result
}
fn make_libcall_sig<C: LowerCtx<I = Inst>>(
ctx: &mut C,
insn: IRInst,
call_conv: CallConv,
) -> Signature {
let mut sig = Signature::new(call_conv);
for i in 0..ctx.num_inputs(insn) {
sig.params.push(AbiParam::new(ctx.input_ty(insn, i)));
}
for i in 0..ctx.num_outputs(insn) {
sig.returns.push(AbiParam::new(ctx.output_ty(insn, i)));
}
sig
}
fn emit_vm_call<C: LowerCtx<I = Inst>>(
ctx: &mut C,
flags: &Flags,
triple: &Triple,
libcall: LibCall,
insn: IRInst,
inputs: SmallVec<[InsnInput; 4]>,
outputs: SmallVec<[InsnOutput; 2]>,
inputs: &[Reg],
outputs: &[Writable<Reg>],
) -> CodegenResult<()> {
let extname = ExternalName::LibCall(libcall);
@@ -607,7 +590,7 @@ fn emit_vm_call<C: LowerCtx<I = Inst>>(
// TODO avoid recreating signatures for every single Libcall function.
let call_conv = CallConv::for_libcall(flags, CallConv::triple_default(triple));
let sig = make_libcall_sig(ctx, insn, call_conv);
let sig = libcall.signature(call_conv);
let caller_conv = ctx.abi().call_conv();
let mut abi = X64ABICaller::from_func(&sig, &extname, dist, caller_conv, flags)?;
@@ -617,14 +600,12 @@ fn emit_vm_call<C: LowerCtx<I = Inst>>(
assert_eq!(inputs.len(), abi.num_args());
for (i, input) in inputs.iter().enumerate() {
let arg_reg = put_input_in_reg(ctx, *input);
abi.emit_copy_regs_to_arg(ctx, i, ValueRegs::one(arg_reg));
abi.emit_copy_regs_to_arg(ctx, i, ValueRegs::one(*input));
}
abi.emit_call(ctx);
for (i, output) in outputs.iter().enumerate() {
let retval_reg = get_output_reg(ctx, *output).only_reg().unwrap();
abi.emit_copy_retval_to_regs(ctx, i, ValueRegs::one(retval_reg));
abi.emit_copy_retval_to_regs(ctx, i, ValueRegs::one(*output));
}
abi.emit_stack_post_adjust(ctx);
@@ -810,7 +791,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
None
};
if let Ok(()) = isle::lower(ctx, flags, isa_flags, &outputs, insn) {
if let Ok(()) = isle::lower(ctx, triple, flags, isa_flags, &outputs, insn) {
return Ok(());
}
@@ -884,6 +865,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
| Opcode::FvpromoteLow
| Opcode::Fdemote
| Opcode::Fvdemote
| Opcode::Fma
| Opcode::Icmp
| Opcode::Fcmp
| Opcode::Load
@@ -1974,7 +1956,11 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
ty, op
),
};
emit_vm_call(ctx, flags, triple, libcall, insn, inputs, outputs)?;
let input = put_input_in_reg(ctx, inputs[0]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
emit_vm_call(ctx, flags, triple, libcall, &[input], &[dst])?;
}
}
@@ -2726,8 +2712,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
Opcode::Cls => unimplemented!("Cls not supported"),
Opcode::Fma => implemented_in_isle(ctx),
Opcode::BorNot | Opcode::BxorNot => {
unimplemented!("or-not / xor-not opcodes not implemented");
}

View File

@@ -10,6 +10,8 @@ use generated_code::{Context, MInst};
// Types that the generated ISLE code uses via `use super::*`.
use super::{is_int_or_ref_ty, is_mergeable_load, lower_to_amode};
use crate::ir::LibCall;
use crate::isa::x64::lower::emit_vm_call;
use crate::{
ir::{
condcodes::{FloatCC, IntCC},
@@ -35,6 +37,7 @@ use regalloc2::PReg;
use smallvec::SmallVec;
use std::boxed::Box;
use std::convert::TryFrom;
use target_lexicon::Triple;
type BoxCallInfo = Box<CallInfo>;
type BoxVecMachLabel = Box<SmallVec<[MachLabel; 4]>>;
@@ -48,6 +51,7 @@ pub struct SinkableLoad {
/// The main entry point for lowering with ISLE.
pub(crate) fn lower<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
outputs: &[InsnOutput],
@@ -56,9 +60,15 @@ pub(crate) fn lower<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
generated_code::constructor_lower(cx, insn)
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
outputs,
inst,
|cx, insn| generated_code::constructor_lower(cx, insn),
)
}
impl<C> Context for IsleContext<'_, C, Flags, IsaFlags, 6>
@@ -647,6 +657,24 @@ where
fn preg_rsp(&mut self) -> PReg {
regs::rsp().to_real_reg().unwrap().into()
}
fn libcall_3(&mut self, libcall: &LibCall, a: Reg, b: Reg, c: Reg) -> Reg {
let call_conv = self.lower_ctx.abi().call_conv();
let ret_ty = libcall.signature(call_conv).returns[0].value_type;
let output_reg = self.lower_ctx.alloc_tmp(ret_ty).only_reg().unwrap();
emit_vm_call(
self.lower_ctx,
self.flags,
self.triple,
libcall.clone(),
&[a, b, c],
&[output_reg],
)
.expect("Failed to emit LibCall");
output_reg.to_reg()
}
}
impl<C> IsleContext<'_, C, Flags, IsaFlags, 6>

View File

@@ -4,6 +4,7 @@ use alloc::boxed::Box;
use alloc::vec::Vec;
use smallvec::SmallVec;
use std::cell::Cell;
use target_lexicon::Triple;
pub use super::MachLabel;
pub use crate::ir::{
@@ -899,6 +900,7 @@ where
[(C::I, bool); N]: smallvec::Array,
{
pub lower_ctx: &'a mut C,
pub triple: &'a Triple,
pub flags: &'a F,
pub isa_flags: &'a I,
}
@@ -910,6 +912,7 @@ where
/// lowering.
pub(crate) fn lower_common<C, F, I, IF, const N: usize>(
lower_ctx: &mut C,
triple: &Triple,
flags: &F,
isa_flags: &I,
outputs: &[InsnOutput],
@@ -925,6 +928,7 @@ where
// internal heap allocations.
let mut isle_ctx = IsleContext {
lower_ctx,
triple,
flags,
isa_flags,
};

View File

@@ -169,9 +169,8 @@ const OPCODE_SIGNATURES: &'static [(
(Opcode::Fcopysign, &[F32, F32], &[F32], insert_opcode),
(Opcode::Fcopysign, &[F64, F64], &[F64], insert_opcode),
// Fma
// TODO: Missing on X86, see https://github.com/bytecodealliance/wasmtime/pull/4460
// (Opcode::Fma, &[F32, F32, F32], &[F32], insert_opcode),
// (Opcode::Fma, &[F64, F64, F64], &[F64], insert_opcode),
(Opcode::Fma, &[F32, F32, F32], &[F32], insert_opcode),
(Opcode::Fma, &[F64, F64, F64], &[F64], insert_opcode),
// Fabs
(Opcode::Fabs, &[F32], &[F32], insert_opcode),
(Opcode::Fabs, &[F64], &[F64], insert_opcode),

View File

@@ -70,6 +70,8 @@ pub fn default_libcall_names() -> Box<dyn Fn(ir::LibCall) -> String + Send + Syn
ir::LibCall::TruncF64 => "trunc".to_owned(),
ir::LibCall::NearestF32 => "nearbyintf".to_owned(),
ir::LibCall::NearestF64 => "nearbyint".to_owned(),
ir::LibCall::FmaF32 => "fmaf".to_owned(),
ir::LibCall::FmaF64 => "fma".to_owned(),
ir::LibCall::Memcpy => "memcpy".to_owned(),
ir::LibCall::Memset => "memset".to_owned(),
ir::LibCall::Memmove => "memmove".to_owned(),