diff --git a/cranelift/codegen/meta/src/isa/x86/settings.rs b/cranelift/codegen/meta/src/isa/x86/settings.rs index 67071558d9..824683bbf6 100644 --- a/cranelift/codegen/meta/src/isa/x86/settings.rs +++ b/cranelift/codegen/meta/src/isa/x86/settings.rs @@ -58,6 +58,12 @@ pub(crate) fn define(shared: &SettingGroup) -> SettingGroup { "AVX512VL: CPUID.07H:EBX.AVX512VL[bit 31]", false, ); + let has_avx512vbmi = settings.add_bool( + "has_avx512vbmi", + "Has support for AVX512VMBI.", + "AVX512VBMI: CPUID.07H:ECX.AVX512VBMI[bit 1]", + false, + ); let has_avx512f = settings.add_bool( "has_avx512f", "Has support for AVX512F.", @@ -126,6 +132,10 @@ pub(crate) fn define(shared: &SettingGroup) -> SettingGroup { "use_avx512vl_simd", predicate!(shared_enable_simd && has_avx512vl), ); + settings.add_predicate( + "use_avx512vbmi_simd", + predicate!(shared_enable_simd && has_avx512vbmi), + ); settings.add_predicate( "use_avx512f_simd", predicate!(shared_enable_simd && has_avx512f), diff --git a/cranelift/codegen/src/isa/x64/inst/args.rs b/cranelift/codegen/src/isa/x64/inst/args.rs index 44e359d22e..010dd87633 100644 --- a/cranelift/codegen/src/isa/x64/inst/args.rs +++ b/cranelift/codegen/src/isa/x64/inst/args.rs @@ -463,6 +463,7 @@ pub(crate) enum InstructionSet { AVX512BITALG, AVX512DQ, AVX512F, + AVX512VBMI, AVX512VL, } @@ -999,10 +1000,11 @@ impl fmt::Display for SseOpcode { } } -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub enum Avx512Opcode { Vcvtudq2ps, Vpabsq, + Vpermi2b, Vpmullq, Vpopcntb, } @@ -1015,6 +1017,9 @@ impl Avx512Opcode { smallvec![InstructionSet::AVX512F, InstructionSet::AVX512VL] } Avx512Opcode::Vpabsq => smallvec![InstructionSet::AVX512F, InstructionSet::AVX512VL], + Avx512Opcode::Vpermi2b => { + smallvec![InstructionSet::AVX512VL, InstructionSet::AVX512VBMI] + } Avx512Opcode::Vpmullq => smallvec![InstructionSet::AVX512VL, InstructionSet::AVX512DQ], Avx512Opcode::Vpopcntb => { smallvec![InstructionSet::AVX512VL, InstructionSet::AVX512BITALG] @@ -1028,6 +1033,7 @@ impl fmt::Debug for Avx512Opcode { let name = match self { Avx512Opcode::Vcvtudq2ps => "vcvtudq2ps", Avx512Opcode::Vpabsq => "vpabsq", + Avx512Opcode::Vpermi2b => "vpermi2b", Avx512Opcode::Vpmullq => "vpmullq", Avx512Opcode::Vpopcntb => "vpopcntb", }; diff --git a/cranelift/codegen/src/isa/x64/inst/emit.rs b/cranelift/codegen/src/isa/x64/inst/emit.rs index b5c6c43c26..f3cd42f12e 100644 --- a/cranelift/codegen/src/isa/x64/inst/emit.rs +++ b/cranelift/codegen/src/isa/x64/inst/emit.rs @@ -127,8 +127,9 @@ pub(crate) fn emit( InstructionSet::BMI1 => info.isa_flags.use_bmi1(), InstructionSet::BMI2 => info.isa_flags.has_bmi2(), InstructionSet::AVX512BITALG => info.isa_flags.has_avx512bitalg(), - InstructionSet::AVX512F => info.isa_flags.has_avx512f(), InstructionSet::AVX512DQ => info.isa_flags.has_avx512dq(), + InstructionSet::AVX512F => info.isa_flags.has_avx512f(), + InstructionSet::AVX512VBMI => info.isa_flags.has_avx512vbmi(), InstructionSet::AVX512VL => info.isa_flags.has_avx512vl(), } }; @@ -1558,8 +1559,9 @@ pub(crate) fn emit( src2, dst, } => { - let opcode = match op { - Avx512Opcode::Vpmullq => 0x40, + let (w, opcode) = match op { + Avx512Opcode::Vpermi2b => (false, 0x75), + Avx512Opcode::Vpmullq => (true, 0x40), _ => unimplemented!("Opcode {:?} not implemented", op), }; match src1 { @@ -1567,7 +1569,7 @@ pub(crate) fn emit( .length(EvexVectorLength::V128) .prefix(LegacyPrefixes::_66) .map(OpcodeMap::_0F38) - .w(true) + .w(w) .opcode(opcode) .reg(dst.to_reg().get_hw_encoding()) .rm(src.get_hw_encoding()) diff --git a/cranelift/codegen/src/isa/x64/inst/emit_tests.rs b/cranelift/codegen/src/isa/x64/inst/emit_tests.rs index d08216612c..11acc3107e 100644 --- a/cranelift/codegen/src/isa/x64/inst/emit_tests.rs +++ b/cranelift/codegen/src/isa/x64/inst/emit_tests.rs @@ -3573,6 +3573,18 @@ fn test_x64_emit() { "vpmullq %xmm14, %xmm10, %xmm1", )); + insns.push(( + Inst::xmm_rm_r_evex(Avx512Opcode::Vpermi2b, RegMem::reg(xmm14), xmm10, w_xmm1), + "62D22D0875CE", + "vpermi2b %xmm14, %xmm10, %xmm1", + )); + + insns.push(( + Inst::xmm_rm_r_evex(Avx512Opcode::Vpermi2b, RegMem::reg(xmm1), xmm0, w_xmm2), + "62F27D0875D1", + "vpermi2b %xmm1, %xmm0, %xmm2", + )); + insns.push(( Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(xmm8), w_xmm9), "66450FF4C8", @@ -4315,6 +4327,7 @@ fn test_x64_emit() { isa_flag_builder.enable("has_avx512f").unwrap(); isa_flag_builder.enable("has_avx512dq").unwrap(); isa_flag_builder.enable("has_avx512vl").unwrap(); + isa_flag_builder.enable("has_avx512vbmi").unwrap(); let isa_flags = x64::settings::Flags::new(&flags, isa_flag_builder); let rru = regs::create_reg_universe_systemv(&flags); diff --git a/cranelift/codegen/src/isa/x64/inst/mod.rs b/cranelift/codegen/src/isa/x64/inst/mod.rs index b253e2d696..cb5b27dfbc 100644 --- a/cranelift/codegen/src/isa/x64/inst/mod.rs +++ b/cranelift/codegen/src/isa/x64/inst/mod.rs @@ -1944,11 +1944,18 @@ fn x64_get_regs(inst: &Inst, collector: &mut RegUsageCollector) { } } Inst::XmmRmREvex { - src1, src2, dst, .. + op, + src1, + src2, + dst, + .. } => { src1.get_regs_as_uses(collector); collector.add_use(*src2); - collector.add_def(*dst); + match *op { + Avx512Opcode::Vpermi2b => collector.add_mod(*dst), + _ => collector.add_def(*dst), + } } Inst::XmmRmRImm { op, src, dst, .. } => { if inst.produces_const() { @@ -2336,6 +2343,7 @@ fn x64_map_regs(inst: &mut Inst, mapper: &RUM) { } } Inst::XmmRmREvex { + op, ref mut src1, ref mut src2, ref mut dst, @@ -2343,7 +2351,10 @@ fn x64_map_regs(inst: &mut Inst, mapper: &RUM) { } => { src1.map_uses(mapper); map_use(mapper, src2); - map_def(mapper, dst); + match *op { + Avx512Opcode::Vpermi2b => map_mod(mapper, dst), + _ => map_def(mapper, dst), + } } Inst::XmmRmiReg { ref mut src, diff --git a/cranelift/codegen/src/isa/x64/lower.rs b/cranelift/codegen/src/isa/x64/lower.rs index b87f243344..5e6b4670ab 100644 --- a/cranelift/codegen/src/isa/x64/lower.rs +++ b/cranelift/codegen/src/isa/x64/lower.rs @@ -5551,35 +5551,55 @@ fn lower_insn_to_regs>( // `src` so we disregard this register). ctx.emit(Inst::xmm_rm_r(SseOpcode::Pshufb, RegMem::from(tmp), dst)); } else { - // If `lhs` and `rhs` are different, we must shuffle each separately and then OR - // them together. This is necessary due to PSHUFB semantics. As in the case above, - // we build the `constructed_mask` for each case statically. + if isa_flags.use_avx512vl_simd() && isa_flags.use_avx512vbmi_simd() { + assert!( + mask.iter().all(|b| *b < 32), + "shuffle mask values must be between 0 and 31" + ); - // PSHUFB the `lhs` argument into `tmp0`, placing zeroes for unused lanes. - let tmp0 = ctx.alloc_tmp(lhs_ty).only_reg().unwrap(); - ctx.emit(Inst::gen_move(tmp0, lhs, lhs_ty)); - let constructed_mask = mask.iter().cloned().map(zero_unknown_lane_index).collect(); - let constant = ctx.use_constant(VCodeConstantData::Generated(constructed_mask)); - let tmp1 = ctx.alloc_tmp(types::I8X16).only_reg().unwrap(); - ctx.emit(Inst::xmm_load_const(constant, tmp1, ty)); - ctx.emit(Inst::xmm_rm_r(SseOpcode::Pshufb, RegMem::from(tmp1), tmp0)); + // Load the mask into the destination register. + let constant = ctx.use_constant(VCodeConstantData::Generated(mask.into())); + ctx.emit(Inst::xmm_load_const(constant, dst, ty)); - // PSHUFB the second argument, placing zeroes for unused lanes. - let constructed_mask = mask - .iter() - .map(|b| b.wrapping_sub(16)) - .map(zero_unknown_lane_index) - .collect(); - let constant = ctx.use_constant(VCodeConstantData::Generated(constructed_mask)); - let tmp2 = ctx.alloc_tmp(types::I8X16).only_reg().unwrap(); - ctx.emit(Inst::xmm_load_const(constant, tmp2, ty)); - ctx.emit(Inst::xmm_rm_r(SseOpcode::Pshufb, RegMem::from(tmp2), dst)); + // VPERMI2B has the exact semantics of Wasm's shuffle: + // permute the bytes in `src1` and `src2` using byte indexes + // in `dst` and store the byte results in `dst`. + ctx.emit(Inst::xmm_rm_r_evex( + Avx512Opcode::Vpermi2b, + RegMem::reg(rhs), + lhs, + dst, + )); + } else { + // If `lhs` and `rhs` are different, we must shuffle each separately and then OR + // them together. This is necessary due to PSHUFB semantics. As in the case above, + // we build the `constructed_mask` for each case statically. - // OR the shuffled registers (the mechanism and lane-size for OR-ing the registers - // is not important). - ctx.emit(Inst::xmm_rm_r(SseOpcode::Orps, RegMem::from(tmp0), dst)); + // PSHUFB the `lhs` argument into `tmp0`, placing zeroes for unused lanes. + let tmp0 = ctx.alloc_tmp(lhs_ty).only_reg().unwrap(); + ctx.emit(Inst::gen_move(tmp0, lhs, lhs_ty)); + let constructed_mask = + mask.iter().cloned().map(zero_unknown_lane_index).collect(); + let constant = ctx.use_constant(VCodeConstantData::Generated(constructed_mask)); + let tmp1 = ctx.alloc_tmp(types::I8X16).only_reg().unwrap(); + ctx.emit(Inst::xmm_load_const(constant, tmp1, ty)); + ctx.emit(Inst::xmm_rm_r(SseOpcode::Pshufb, RegMem::from(tmp1), tmp0)); - // TODO when AVX512 is enabled we should replace this sequence with a single VPERMB + // PSHUFB the second argument, placing zeroes for unused lanes. + let constructed_mask = mask + .iter() + .map(|b| b.wrapping_sub(16)) + .map(zero_unknown_lane_index) + .collect(); + let constant = ctx.use_constant(VCodeConstantData::Generated(constructed_mask)); + let tmp2 = ctx.alloc_tmp(types::I8X16).only_reg().unwrap(); + ctx.emit(Inst::xmm_load_const(constant, tmp2, ty)); + ctx.emit(Inst::xmm_rm_r(SseOpcode::Pshufb, RegMem::from(tmp2), dst)); + + // OR the shuffled registers (the mechanism and lane-size for OR-ing the registers + // is not important). + ctx.emit(Inst::xmm_rm_r(SseOpcode::Orps, RegMem::from(tmp0), dst)); + } } } diff --git a/cranelift/native/src/lib.rs b/cranelift/native/src/lib.rs index 82b3e98eec..80af5f13a2 100644 --- a/cranelift/native/src/lib.rs +++ b/cranelift/native/src/lib.rs @@ -97,11 +97,14 @@ pub fn builder_with_options( if std::is_x86_feature_detected!("avx512dq") { isa_builder.enable("has_avx512dq").unwrap(); } + if std::is_x86_feature_detected!("avx512f") { + isa_builder.enable("has_avx512f").unwrap(); + } if std::is_x86_feature_detected!("avx512vl") { isa_builder.enable("has_avx512vl").unwrap(); } - if std::is_x86_feature_detected!("avx512f") { - isa_builder.enable("has_avx512f").unwrap(); + if std::is_x86_feature_detected!("avx512vbmi") { + isa_builder.enable("has_avx512vbmi").unwrap(); } if std::is_x86_feature_detected!("lzcnt") { isa_builder.enable("has_lzcnt").unwrap();