From e676589b0c6e8228c421e18249d4635eb6c4bbe4 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Mon, 10 May 2021 16:25:03 -0700 Subject: [PATCH] x64: lower i64x2.imul to VPMULLQ when possible This adds the machinery to encode the VPMULLQ instruction which is available in AVX512VL and AVX512DQ. When these feature sets are available, we use this instruction instead of a lengthy 12-instruction sequence. --- cranelift/codegen/src/isa/x64/inst/args.rs | 4 + cranelift/codegen/src/isa/x64/inst/emit.rs | 27 +++ .../codegen/src/isa/x64/inst/emit_tests.rs | 7 + cranelift/codegen/src/isa/x64/inst/mod.rs | 57 +++++- cranelift/codegen/src/isa/x64/lower.rs | 191 +++++++++--------- 5 files changed, 195 insertions(+), 91 deletions(-) diff --git a/cranelift/codegen/src/isa/x64/inst/args.rs b/cranelift/codegen/src/isa/x64/inst/args.rs index b54f1b6126..6e0d507ab0 100644 --- a/cranelift/codegen/src/isa/x64/inst/args.rs +++ b/cranelift/codegen/src/isa/x64/inst/args.rs @@ -462,6 +462,7 @@ pub(crate) enum InstructionSet { BMI2, AVX512F, AVX512VL, + AVX512DQ, } /// Some SSE operations requiring 2 operands r/m and r. @@ -994,6 +995,7 @@ impl fmt::Display for SseOpcode { #[derive(Clone)] pub enum Avx512Opcode { Vpabsq, + Vpmullq, } impl Avx512Opcode { @@ -1001,6 +1003,7 @@ impl Avx512Opcode { pub(crate) fn available_from(&self) -> SmallVec<[InstructionSet; 2]> { match self { Avx512Opcode::Vpabsq => smallvec![InstructionSet::AVX512F, InstructionSet::AVX512VL], + Avx512Opcode::Vpmullq => smallvec![InstructionSet::AVX512VL, InstructionSet::AVX512DQ], } } } @@ -1009,6 +1012,7 @@ impl fmt::Debug for Avx512Opcode { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { let name = match self { Avx512Opcode::Vpabsq => "vpabsq", + Avx512Opcode::Vpmullq => "vpmullq", }; write!(fmt, "{}", name) } diff --git a/cranelift/codegen/src/isa/x64/inst/emit.rs b/cranelift/codegen/src/isa/x64/inst/emit.rs index 0bd74ecd8b..134d6eafa1 100644 --- a/cranelift/codegen/src/isa/x64/inst/emit.rs +++ b/cranelift/codegen/src/isa/x64/inst/emit.rs @@ -128,6 +128,7 @@ pub(crate) fn emit( InstructionSet::BMI2 => info.isa_flags.has_bmi2(), InstructionSet::AVX512F => info.isa_flags.has_avx512f(), InstructionSet::AVX512VL => info.isa_flags.has_avx512vl(), + InstructionSet::AVX512DQ => info.isa_flags.has_avx512dq(), } }; @@ -1409,6 +1410,7 @@ pub(crate) fn emit( Inst::XmmUnaryRmREvex { op, src, dst } => { let opcode = match op { Avx512Opcode::Vpabsq => 0x1f, + _ => unimplemented!("Opcode {:?} not implemented", op), }; match src { RegMem::Reg { reg: src } => EvexInstruction::new() @@ -1545,6 +1547,31 @@ pub(crate) fn emit( } } + Inst::XmmRmREvex { + op, + src1, + src2, + dst, + } => { + let opcode = match op { + Avx512Opcode::Vpmullq => 0x40, + _ => unimplemented!("Opcode {:?} not implemented", op), + }; + match src1 { + RegMem::Reg { reg: src } => EvexInstruction::new() + .length(EvexVectorLength::V128) + .prefix(LegacyPrefixes::_66) + .map(OpcodeMap::_0F38) + .w(true) + .opcode(opcode) + .reg(dst.to_reg().get_hw_encoding()) + .rm(src.get_hw_encoding()) + .vvvvv(src2.get_hw_encoding()) + .encode(sink), + _ => todo!(), + }; + } + Inst::XmmMinMaxSeq { size, is_min, diff --git a/cranelift/codegen/src/isa/x64/inst/emit_tests.rs b/cranelift/codegen/src/isa/x64/inst/emit_tests.rs index f03762b97b..1d0dd4aba5 100644 --- a/cranelift/codegen/src/isa/x64/inst/emit_tests.rs +++ b/cranelift/codegen/src/isa/x64/inst/emit_tests.rs @@ -3555,6 +3555,12 @@ fn test_x64_emit() { "pmullw %xmm14, %xmm1", )); + insns.push(( + Inst::xmm_rm_r_evex(Avx512Opcode::Vpmullq, RegMem::reg(xmm14), xmm10, w_xmm1), + "62D2AD0840CE", + "vpmullq %xmm14, %xmm10, %xmm1", + )); + insns.push(( Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(xmm8), w_xmm9), "66450FF4C8", @@ -4283,6 +4289,7 @@ fn test_x64_emit() { isa_flag_builder.enable("has_ssse3").unwrap(); isa_flag_builder.enable("has_sse41").unwrap(); isa_flag_builder.enable("has_avx512f").unwrap(); + isa_flag_builder.enable("has_avx512dq").unwrap(); let isa_flags = x64::settings::Flags::new(&flags, isa_flag_builder); let rru = regs::create_reg_universe_systemv(&flags); diff --git a/cranelift/codegen/src/isa/x64/inst/mod.rs b/cranelift/codegen/src/isa/x64/inst/mod.rs index fe89ac4c90..547d8413cb 100644 --- a/cranelift/codegen/src/isa/x64/inst/mod.rs +++ b/cranelift/codegen/src/isa/x64/inst/mod.rs @@ -212,6 +212,13 @@ pub enum Inst { dst: Writable, }, + XmmRmREvex { + op: Avx512Opcode, + src1: RegMem, + src2: Reg, + dst: Writable, + }, + /// XMM (scalar or vector) unary op: mov between XMM registers (32 64) (reg addr) reg, sqrt, /// etc. /// @@ -577,7 +584,7 @@ impl Inst { | Inst::XmmToGpr { op, .. } | Inst::XmmUnaryRmR { op, .. } => smallvec![op.available_from()], - Inst::XmmUnaryRmREvex { op, .. } => op.available_from(), + Inst::XmmUnaryRmREvex { op, .. } | Inst::XmmRmREvex { op, .. } => op.available_from(), } } } @@ -724,6 +731,23 @@ impl Inst { Inst::XmmRmR { op, src, dst } } + pub(crate) fn xmm_rm_r_evex( + op: Avx512Opcode, + src1: RegMem, + src2: Reg, + dst: Writable, + ) -> Self { + src1.assert_regclass_is(RegClass::V128); + debug_assert!(src2.get_class() == RegClass::V128); + debug_assert!(dst.to_reg().get_class() == RegClass::V128); + Inst::XmmRmREvex { + op, + src1, + src2, + dst, + } + } + pub(crate) fn xmm_uninit_value(dst: Writable) -> Self { debug_assert!(dst.to_reg().get_class() == RegClass::V128); Inst::XmmUninitializedValue { dst } @@ -1425,6 +1449,20 @@ impl PrettyPrint for Inst { show_ireg_sized(dst.to_reg(), mb_rru, 8), ), + Inst::XmmRmREvex { + op, + src1, + src2, + dst, + .. + } => format!( + "{} {}, {}, {}", + ljustify(op.to_string()), + src1.show_rru_sized(mb_rru, 8), + show_ireg_sized(*src2, mb_rru, 8), + show_ireg_sized(dst.to_reg(), mb_rru, 8), + ), + Inst::XmmMinMaxSeq { lhs, rhs_dst, @@ -1898,6 +1936,13 @@ fn x64_get_regs(inst: &Inst, collector: &mut RegUsageCollector) { collector.add_mod(*dst); } } + Inst::XmmRmREvex { + src1, src2, dst, .. + } => { + src1.get_regs_as_uses(collector); + collector.add_use(*src2); + collector.add_def(*dst); + } Inst::XmmRmRImm { op, src, dst, .. } => { if inst.produces_const() { // No need to account for src, since src == dst. @@ -2283,6 +2328,16 @@ fn x64_map_regs(inst: &mut Inst, mapper: &RUM) { map_mod(mapper, dst); } } + Inst::XmmRmREvex { + ref mut src1, + ref mut src2, + ref mut dst, + .. + } => { + src1.map_uses(mapper); + map_use(mapper, src2); + map_def(mapper, dst); + } Inst::XmmRmiReg { ref mut src, ref mut dst, diff --git a/cranelift/codegen/src/isa/x64/lower.rs b/cranelift/codegen/src/isa/x64/lower.rs index 3f62b375a7..9c77e879f2 100644 --- a/cranelift/codegen/src/isa/x64/lower.rs +++ b/cranelift/codegen/src/isa/x64/lower.rs @@ -1663,105 +1663,116 @@ fn lower_insn_to_regs>( Opcode::Imul => { let ty = ty.unwrap(); if ty == types::I64X2 { - // For I64X2 multiplication we describe a lane A as being - // composed of a 32-bit upper half "Ah" and a 32-bit lower half - // "Al". The 32-bit long hand multiplication can then be written - // as: - // Ah Al - // * Bh Bl - // ----- - // Al * Bl - // + (Ah * Bl) << 32 - // + (Al * Bh) << 32 - // - // So for each lane we will compute: - // A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32 - // - // Note, the algorithm will use pmuldq which operates directly - // on the lower 32-bit (Al or Bl) of a lane and writes the - // result to the full 64-bits of the lane of the destination. - // For this reason we don't need shifts to isolate the lower - // 32-bits, however, we will need to use shifts to isolate the - // high 32-bits when doing calculations, i.e., Ah == A >> 32. - // - // The full sequence then is as follows: - // A' = A - // A' = A' >> 32 - // A' = Ah' * Bl - // B' = B - // B' = B' >> 32 - // B' = Bh' * Al - // B' = B' + A' - // B' = B' << 32 - // A' = A - // A' = Al' * Bl - // A' = A' + B' - // dst = A' - - // Get inputs rhs=A and lhs=B and the dst register + // Eventually one of these should be `input_to_reg_mem` (TODO). let lhs = put_input_in_reg(ctx, inputs[0]); let rhs = put_input_in_reg(ctx, inputs[1]); let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap(); - // A' = A - let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap(); - ctx.emit(Inst::gen_move(rhs_1, rhs, ty)); + if isa_flags.use_avx512f_simd() || isa_flags.use_avx512vl_simd() { + // With the right AVX512 features (VL, DQ) this operation + // can lower to a single operation. + ctx.emit(Inst::xmm_rm_r_evex( + Avx512Opcode::Vpmullq, + RegMem::reg(rhs), + lhs, + dst, + )); + } else { + // Otherwise, for I64X2 multiplication we describe a lane A as being + // composed of a 32-bit upper half "Ah" and a 32-bit lower half + // "Al". The 32-bit long hand multiplication can then be written + // as: + // Ah Al + // * Bh Bl + // ----- + // Al * Bl + // + (Ah * Bl) << 32 + // + (Al * Bh) << 32 + // + // So for each lane we will compute: + // A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32 + // + // Note, the algorithm will use pmuldq which operates directly + // on the lower 32-bit (Al or Bl) of a lane and writes the + // result to the full 64-bits of the lane of the destination. + // For this reason we don't need shifts to isolate the lower + // 32-bits, however, we will need to use shifts to isolate the + // high 32-bits when doing calculations, i.e., Ah == A >> 32. + // + // The full sequence then is as follows: + // A' = A + // A' = A' >> 32 + // A' = Ah' * Bl + // B' = B + // B' = B' >> 32 + // B' = Bh' * Al + // B' = B' + A' + // B' = B' << 32 + // A' = A + // A' = Al' * Bl + // A' = A' + B' + // dst = A' - // A' = A' >> 32 - // A' = Ah' * Bl - ctx.emit(Inst::xmm_rmi_reg( - SseOpcode::Psrlq, - RegMemImm::imm(32), - rhs_1, - )); - ctx.emit(Inst::xmm_rm_r( - SseOpcode::Pmuludq, - RegMem::reg(lhs.clone()), - rhs_1, - )); + // A' = A + let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap(); + ctx.emit(Inst::gen_move(rhs_1, rhs, ty)); - // B' = B - let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap(); - ctx.emit(Inst::gen_move(lhs_1, lhs, ty)); + // A' = A' >> 32 + // A' = Ah' * Bl + ctx.emit(Inst::xmm_rmi_reg( + SseOpcode::Psrlq, + RegMemImm::imm(32), + rhs_1, + )); + ctx.emit(Inst::xmm_rm_r( + SseOpcode::Pmuludq, + RegMem::reg(lhs.clone()), + rhs_1, + )); - // B' = B' >> 32 - // B' = Bh' * Al - ctx.emit(Inst::xmm_rmi_reg( - SseOpcode::Psrlq, - RegMemImm::imm(32), - lhs_1, - )); - ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1)); + // B' = B + let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap(); + ctx.emit(Inst::gen_move(lhs_1, lhs, ty)); - // B' = B' + A' - // B' = B' << 32 - ctx.emit(Inst::xmm_rm_r( - SseOpcode::Paddq, - RegMem::reg(rhs_1.to_reg()), - lhs_1, - )); - ctx.emit(Inst::xmm_rmi_reg( - SseOpcode::Psllq, - RegMemImm::imm(32), - lhs_1, - )); + // B' = B' >> 32 + // B' = Bh' * Al + ctx.emit(Inst::xmm_rmi_reg( + SseOpcode::Psrlq, + RegMemImm::imm(32), + lhs_1, + )); + ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1)); - // A' = A - // A' = Al' * Bl - // A' = A' + B' - // dst = A' - ctx.emit(Inst::gen_move(rhs_1, rhs, ty)); - ctx.emit(Inst::xmm_rm_r( - SseOpcode::Pmuludq, - RegMem::reg(lhs.clone()), - rhs_1, - )); - ctx.emit(Inst::xmm_rm_r( - SseOpcode::Paddq, - RegMem::reg(lhs_1.to_reg()), - rhs_1, - )); - ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty)); + // B' = B' + A' + // B' = B' << 32 + ctx.emit(Inst::xmm_rm_r( + SseOpcode::Paddq, + RegMem::reg(rhs_1.to_reg()), + lhs_1, + )); + ctx.emit(Inst::xmm_rmi_reg( + SseOpcode::Psllq, + RegMemImm::imm(32), + lhs_1, + )); + + // A' = A + // A' = Al' * Bl + // A' = A' + B' + // dst = A' + ctx.emit(Inst::gen_move(rhs_1, rhs, ty)); + ctx.emit(Inst::xmm_rm_r( + SseOpcode::Pmuludq, + RegMem::reg(lhs.clone()), + rhs_1, + )); + ctx.emit(Inst::xmm_rm_r( + SseOpcode::Paddq, + RegMem::reg(lhs_1.to_reg()), + rhs_1, + )); + ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty)); + } } else if ty.lane_count() > 1 { // Emit single instruction lowerings for the remaining vector // multiplications.