x64: lower i64x2.imul to VPMULLQ when possible

This adds the machinery to encode the VPMULLQ instruction which is
available in AVX512VL and AVX512DQ. When these feature sets are
available, we use this instruction instead of a lengthy 12-instruction
sequence.
This commit is contained in:
Andrew Brown
2021-05-10 16:25:03 -07:00
parent 5929a5e6ee
commit e676589b0c
5 changed files with 195 additions and 91 deletions

View File

@@ -1663,105 +1663,116 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
Opcode::Imul => {
let ty = ty.unwrap();
if ty == types::I64X2 {
// For I64X2 multiplication we describe a lane A as being
// composed of a 32-bit upper half "Ah" and a 32-bit lower half
// "Al". The 32-bit long hand multiplication can then be written
// as:
// Ah Al
// * Bh Bl
// -----
// Al * Bl
// + (Ah * Bl) << 32
// + (Al * Bh) << 32
//
// So for each lane we will compute:
// A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32
//
// Note, the algorithm will use pmuldq which operates directly
// on the lower 32-bit (Al or Bl) of a lane and writes the
// result to the full 64-bits of the lane of the destination.
// For this reason we don't need shifts to isolate the lower
// 32-bits, however, we will need to use shifts to isolate the
// high 32-bits when doing calculations, i.e., Ah == A >> 32.
//
// The full sequence then is as follows:
// A' = A
// A' = A' >> 32
// A' = Ah' * Bl
// B' = B
// B' = B' >> 32
// B' = Bh' * Al
// B' = B' + A'
// B' = B' << 32
// A' = A
// A' = Al' * Bl
// A' = A' + B'
// dst = A'
// Get inputs rhs=A and lhs=B and the dst register
// Eventually one of these should be `input_to_reg_mem` (TODO).
let lhs = put_input_in_reg(ctx, inputs[0]);
let rhs = put_input_in_reg(ctx, inputs[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
// A' = A
let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
if isa_flags.use_avx512f_simd() || isa_flags.use_avx512vl_simd() {
// With the right AVX512 features (VL, DQ) this operation
// can lower to a single operation.
ctx.emit(Inst::xmm_rm_r_evex(
Avx512Opcode::Vpmullq,
RegMem::reg(rhs),
lhs,
dst,
));
} else {
// Otherwise, for I64X2 multiplication we describe a lane A as being
// composed of a 32-bit upper half "Ah" and a 32-bit lower half
// "Al". The 32-bit long hand multiplication can then be written
// as:
// Ah Al
// * Bh Bl
// -----
// Al * Bl
// + (Ah * Bl) << 32
// + (Al * Bh) << 32
//
// So for each lane we will compute:
// A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32
//
// Note, the algorithm will use pmuldq which operates directly
// on the lower 32-bit (Al or Bl) of a lane and writes the
// result to the full 64-bits of the lane of the destination.
// For this reason we don't need shifts to isolate the lower
// 32-bits, however, we will need to use shifts to isolate the
// high 32-bits when doing calculations, i.e., Ah == A >> 32.
//
// The full sequence then is as follows:
// A' = A
// A' = A' >> 32
// A' = Ah' * Bl
// B' = B
// B' = B' >> 32
// B' = Bh' * Al
// B' = B' + A'
// B' = B' << 32
// A' = A
// A' = Al' * Bl
// A' = A' + B'
// dst = A'
// A' = A' >> 32
// A' = Ah' * Bl
ctx.emit(Inst::xmm_rmi_reg(
SseOpcode::Psrlq,
RegMemImm::imm(32),
rhs_1,
));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmuludq,
RegMem::reg(lhs.clone()),
rhs_1,
));
// A' = A
let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
// B' = B
let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
ctx.emit(Inst::gen_move(lhs_1, lhs, ty));
// A' = A' >> 32
// A' = Ah' * Bl
ctx.emit(Inst::xmm_rmi_reg(
SseOpcode::Psrlq,
RegMemImm::imm(32),
rhs_1,
));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmuludq,
RegMem::reg(lhs.clone()),
rhs_1,
));
// B' = B' >> 32
// B' = Bh' * Al
ctx.emit(Inst::xmm_rmi_reg(
SseOpcode::Psrlq,
RegMemImm::imm(32),
lhs_1,
));
ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1));
// B' = B
let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
ctx.emit(Inst::gen_move(lhs_1, lhs, ty));
// B' = B' + A'
// B' = B' << 32
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Paddq,
RegMem::reg(rhs_1.to_reg()),
lhs_1,
));
ctx.emit(Inst::xmm_rmi_reg(
SseOpcode::Psllq,
RegMemImm::imm(32),
lhs_1,
));
// B' = B' >> 32
// B' = Bh' * Al
ctx.emit(Inst::xmm_rmi_reg(
SseOpcode::Psrlq,
RegMemImm::imm(32),
lhs_1,
));
ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1));
// A' = A
// A' = Al' * Bl
// A' = A' + B'
// dst = A'
ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmuludq,
RegMem::reg(lhs.clone()),
rhs_1,
));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Paddq,
RegMem::reg(lhs_1.to_reg()),
rhs_1,
));
ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty));
// B' = B' + A'
// B' = B' << 32
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Paddq,
RegMem::reg(rhs_1.to_reg()),
lhs_1,
));
ctx.emit(Inst::xmm_rmi_reg(
SseOpcode::Psllq,
RegMemImm::imm(32),
lhs_1,
));
// A' = A
// A' = Al' * Bl
// A' = A' + B'
// dst = A'
ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmuludq,
RegMem::reg(lhs.clone()),
rhs_1,
));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Paddq,
RegMem::reg(lhs_1.to_reg()),
rhs_1,
));
ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty));
}
} else if ty.lane_count() > 1 {
// Emit single instruction lowerings for the remaining vector
// multiplications.