diff --git a/cranelift/codegen/src/isa/x64/lower.rs b/cranelift/codegen/src/isa/x64/lower.rs
index a1969d5642..3f62b375a7 100644
--- a/cranelift/codegen/src/isa/x64/lower.rs
+++ b/cranelift/codegen/src/isa/x64/lower.rs
@@ -1511,7 +1511,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
         | Opcode::Isub
         | Opcode::SsubSat
         | Opcode::UsubSat
-        | Opcode::Imul
         | Opcode::AvgRound
         | Opcode::Band
         | Opcode::Bor
@@ -1553,112 +1552,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
                         types::I16X8 => SseOpcode::Psubusw,
                         _ => panic!("Unsupported type for packed usub_sat instruction: {}", ty),
                     },
-                    Opcode::Imul => match ty {
-                        types::I16X8 => SseOpcode::Pmullw,
-                        types::I32X4 => SseOpcode::Pmulld,
-                        types::I64X2 => {
-                            // Note for I64X2 we describe a lane A as being composed of a
-                            // 32-bit upper half "Ah" and a 32-bit lower half "Al".
-                            // The 32-bit long hand multiplication can then be written as:
-                            //    Ah Al
-                            // *  Bh Bl
-                            //    -----
-                            //    Al * Bl
-                            // + (Ah * Bl) << 32
-                            // + (Al * Bh) << 32
-                            //
-                            // So for each lane we will compute:
-                            // A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32
-                            //
-                            // Note, the algorithm will use pmuldq which operates directly on
-                            // the lower 32-bit (Al or Bl) of a lane and writes the result
-                            // to the full 64-bits of the lane of the destination. For this
-                            // reason we don't need shifts to isolate the lower 32-bits, however
-                            // we will need to use shifts to isolate the high 32-bits when doing
-                            // calculations, i.e. Ah == A >> 32
-                            //
-                            // The full sequence then is as follows:
-                            //    A' = A
-                            //    A' = A' >> 32
-                            //    A' = Ah' * Bl
-                            //    B' = B
-                            //    B' = B' >> 32
-                            //    B' = Bh' * Al
-                            //    B' = B' + A'
-                            //    B' = B' << 32
-                            //    A' = A
-                            //    A' = Al' * Bl
-                            //    A' = A' + B'
-                            //    dst = A'
-
-                            // Get inputs rhs=A and lhs=B and the dst register
-                            let lhs = put_input_in_reg(ctx, inputs[0]);
-                            let rhs = put_input_in_reg(ctx, inputs[1]);
-                            let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
-
-                            // A' = A
-                            let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
-                            ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
-
-                            // A' = A' >> 32
-                            // A' = Ah' * Bl
-                            ctx.emit(Inst::xmm_rmi_reg(
-                                SseOpcode::Psrlq,
-                                RegMemImm::imm(32),
-                                rhs_1,
-                            ));
-                            ctx.emit(Inst::xmm_rm_r(
-                                SseOpcode::Pmuludq,
-                                RegMem::reg(lhs.clone()),
-                                rhs_1,
-                            ));
-
-                            // B' = B
-                            let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
-                            ctx.emit(Inst::gen_move(lhs_1, lhs, ty));
-
-                            // B' = B' >> 32
-                            // B' = Bh' * Al
-                            ctx.emit(Inst::xmm_rmi_reg(
-                                SseOpcode::Psrlq,
-                                RegMemImm::imm(32),
-                                lhs_1,
-                            ));
-                            ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1));
-
-                            // B' = B' + A'
-                            // B' = B' << 32
-                            ctx.emit(Inst::xmm_rm_r(
-                                SseOpcode::Paddq,
-                                RegMem::reg(rhs_1.to_reg()),
-                                lhs_1,
-                            ));
-                            ctx.emit(Inst::xmm_rmi_reg(
-                                SseOpcode::Psllq,
-                                RegMemImm::imm(32),
-                                lhs_1,
-                            ));
-
-                            // A' = A
-                            // A' = Al' * Bl
-                            // A' = A' + B'
-                            // dst = A'
-                            ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
-                            ctx.emit(Inst::xmm_rm_r(
-                                SseOpcode::Pmuludq,
-                                RegMem::reg(lhs.clone()),
-                                rhs_1,
-                            ));
-                            ctx.emit(Inst::xmm_rm_r(
-                                SseOpcode::Paddq,
-                                RegMem::reg(lhs_1.to_reg()),
-                                rhs_1,
-                            ));
-                            ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty));
-                            return Ok(());
-                        }
-                        _ => panic!("Unsupported type for packed imul instruction: {}", ty),
-                    },
                     Opcode::AvgRound => match ty {
                         types::I8X16 => SseOpcode::Pavgb,
                         types::I16X8 => SseOpcode::Pavgw,
@@ -1692,8 +1585,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
             let alu_ops = match op {
                 Opcode::Iadd => (AluRmiROpcode::Add, AluRmiROpcode::Adc),
                 Opcode::Isub => (AluRmiROpcode::Sub, AluRmiROpcode::Sbb),
-                // multiply handled specially below
-                Opcode::Imul => (AluRmiROpcode::Mul, AluRmiROpcode::Mul),
                 Opcode::Band => (AluRmiROpcode::And, AluRmiROpcode::And),
                 Opcode::Bor => (AluRmiROpcode::Or, AluRmiROpcode::Or),
                 Opcode::Bxor => (AluRmiROpcode::Xor, AluRmiROpcode::Xor),
@@ -1706,84 +1597,22 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
             assert_eq!(rhs.len(), 2);
             assert_eq!(dst.len(), 2);

-            if op != Opcode::Imul {
-                // add, sub, and, or, xor: just do ops on lower then upper half. Carry-flag
-                // propagation is implicit (add/adc, sub/sbb).
-                ctx.emit(Inst::gen_move(dst.regs()[0], lhs.regs()[0], types::I64));
-                ctx.emit(Inst::gen_move(dst.regs()[1], lhs.regs()[1], types::I64));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    alu_ops.0,
-                    RegMemImm::reg(rhs.regs()[0]),
-                    dst.regs()[0],
-                ));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    alu_ops.1,
-                    RegMemImm::reg(rhs.regs()[1]),
-                    dst.regs()[1],
-                ));
-            } else {
-                // mul:
-                //   dst_lo = lhs_lo * rhs_lo
-                //   dst_hi = umulhi(lhs_lo, rhs_lo) + lhs_lo * rhs_hi + lhs_hi * rhs_lo
-                //
-                // so we emit:
-                //   mov dst_lo, lhs_lo
-                //   mul dst_lo, rhs_lo
-                //   mov dst_hi, lhs_lo
-                //   mul dst_hi, rhs_hi
-                //   mov tmp, lhs_hi
-                //   mul tmp, rhs_lo
-                //   add dst_hi, tmp
-                //   mov rax, lhs_lo
-                //   umulhi rhs_lo // implicit rax arg/dst
-                //   add dst_hi, rax
-                let tmp = ctx.alloc_tmp(types::I64).only_reg().unwrap();
-                ctx.emit(Inst::gen_move(dst.regs()[0], lhs.regs()[0], types::I64));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    AluRmiROpcode::Mul,
-                    RegMemImm::reg(rhs.regs()[0]),
-                    dst.regs()[0],
-                ));
-                ctx.emit(Inst::gen_move(dst.regs()[1], lhs.regs()[0], types::I64));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    AluRmiROpcode::Mul,
-                    RegMemImm::reg(rhs.regs()[1]),
-                    dst.regs()[1],
-                ));
-                ctx.emit(Inst::gen_move(tmp, lhs.regs()[1], types::I64));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    AluRmiROpcode::Mul,
-                    RegMemImm::reg(rhs.regs()[0]),
-                    tmp,
-                ));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    AluRmiROpcode::Add,
-                    RegMemImm::reg(tmp.to_reg()),
-                    dst.regs()[1],
-                ));
-                ctx.emit(Inst::gen_move(
-                    Writable::from_reg(regs::rax()),
-                    lhs.regs()[0],
-                    types::I64,
-                ));
-                ctx.emit(Inst::mul_hi(
-                    OperandSize::Size64,
-                    /* signed = */ false,
-                    RegMem::reg(rhs.regs()[0]),
-                ));
-                ctx.emit(Inst::alu_rmi_r(
-                    OperandSize::Size64,
-                    AluRmiROpcode::Add,
-                    RegMemImm::reg(regs::rdx()),
-                    dst.regs()[1],
-                ));
-            }
+            // For add, sub, and, or, xor: just do ops on lower then upper
+            // half. Carry-flag propagation is implicit (add/adc, sub/sbb).
+            ctx.emit(Inst::gen_move(dst.regs()[0], lhs.regs()[0], types::I64));
+            ctx.emit(Inst::gen_move(dst.regs()[1], lhs.regs()[1], types::I64));
+            ctx.emit(Inst::alu_rmi_r(
+                OperandSize::Size64,
+                alu_ops.0,
+                RegMemImm::reg(rhs.regs()[0]),
+                dst.regs()[0],
+            ));
+            ctx.emit(Inst::alu_rmi_r(
+                OperandSize::Size64,
+                alu_ops.1,
+                RegMemImm::reg(rhs.regs()[1]),
+                dst.regs()[1],
+            ));
         } else {
             let size = if ty == types::I64 {
                 OperandSize::Size64
@@ -1793,7 +1622,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
             let alu_op = match op {
                 Opcode::Iadd | Opcode::IaddIfcout => AluRmiROpcode::Add,
                 Opcode::Isub => AluRmiROpcode::Sub,
-                Opcode::Imul => AluRmiROpcode::Mul,
                 Opcode::Band => AluRmiROpcode::And,
                 Opcode::Bor => AluRmiROpcode::Or,
                 Opcode::Bxor => AluRmiROpcode::Xor,
@@ -1803,7 +1631,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
             let (lhs, rhs) = match op {
                 Opcode::Iadd
                 | Opcode::IaddIfcout
-                | Opcode::Imul
                 | Opcode::Band
                 | Opcode::Bor
                 | Opcode::Bxor => {
@@ -1833,6 +1660,218 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
             }
         }

+        Opcode::Imul => {
+            let ty = ty.unwrap();
+            if ty == types::I64X2 {
+                // For I64X2 multiplication we describe a lane A as being
+                // composed of a 32-bit upper half "Ah" and a 32-bit lower half
+                // "Al". The 32-bit long hand multiplication can then be written
+                // as:
+                //    Ah Al
+                // *  Bh Bl
+                //    -----
+                //    Al * Bl
+                // + (Ah * Bl) << 32
+                // + (Al * Bh) << 32
+                //
+                // So for each lane we will compute:
+                // A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32
+                //
+                // Note, the algorithm will use pmuludq which operates directly
+                // on the lower 32-bit (Al or Bl) of a lane and writes the
+                // result to the full 64-bits of the lane of the destination.
+                // For this reason we don't need shifts to isolate the lower
+                // 32-bits, however, we will need to use shifts to isolate the
+                // high 32-bits when doing calculations, i.e., Ah == A >> 32.
+                //
+                // The full sequence then is as follows:
+                //    A' = A
+                //    A' = A' >> 32
+                //    A' = Ah' * Bl
+                //    B' = B
+                //    B' = B' >> 32
+                //    B' = Bh' * Al
+                //    B' = B' + A'
+                //    B' = B' << 32
+                //    A' = A
+                //    A' = Al' * Bl
+                //    A' = A' + B'
+                //    dst = A'
+
+                // Get inputs rhs=A and lhs=B and the dst register
+                let lhs = put_input_in_reg(ctx, inputs[0]);
+                let rhs = put_input_in_reg(ctx, inputs[1]);
+                let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
+
+                // A' = A
+                let rhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
+                ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
+
+                // A' = A' >> 32
+                // A' = Ah' * Bl
+                ctx.emit(Inst::xmm_rmi_reg(
+                    SseOpcode::Psrlq,
+                    RegMemImm::imm(32),
+                    rhs_1,
+                ));
+                ctx.emit(Inst::xmm_rm_r(
+                    SseOpcode::Pmuludq,
+                    RegMem::reg(lhs.clone()),
+                    rhs_1,
+                ));
+
+                // B' = B
+                let lhs_1 = ctx.alloc_tmp(types::I64X2).only_reg().unwrap();
+                ctx.emit(Inst::gen_move(lhs_1, lhs, ty));
+
+                // B' = B' >> 32
+                // B' = Bh' * Al
+                ctx.emit(Inst::xmm_rmi_reg(
+                    SseOpcode::Psrlq,
+                    RegMemImm::imm(32),
+                    lhs_1,
+                ));
+                ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmuludq, RegMem::reg(rhs), lhs_1));
+
+                // B' = B' + A'
+                // B' = B' << 32
+                ctx.emit(Inst::xmm_rm_r(
+                    SseOpcode::Paddq,
+                    RegMem::reg(rhs_1.to_reg()),
+                    lhs_1,
+                ));
+                ctx.emit(Inst::xmm_rmi_reg(
+                    SseOpcode::Psllq,
+                    RegMemImm::imm(32),
+                    lhs_1,
+                ));
+
+                // A' = A
+                // A' = Al' * Bl
+                // A' = A' + B'
+                // dst = A'
+                ctx.emit(Inst::gen_move(rhs_1, rhs, ty));
+                ctx.emit(Inst::xmm_rm_r(
+                    SseOpcode::Pmuludq,
+                    RegMem::reg(lhs.clone()),
+                    rhs_1,
+                ));
+                ctx.emit(Inst::xmm_rm_r(
+                    SseOpcode::Paddq,
+                    RegMem::reg(lhs_1.to_reg()),
+                    rhs_1,
+                ));
+                ctx.emit(Inst::gen_move(dst, rhs_1.to_reg(), ty));
+            } else if ty.lane_count() > 1 {
+                // Emit single instruction lowerings for the remaining vector
+                // multiplications.
+                let sse_op = match ty {
+                    types::I16X8 => SseOpcode::Pmullw,
+                    types::I32X4 => SseOpcode::Pmulld,
+                    _ => panic!("Unsupported type for packed imul instruction: {}", ty),
+                };
+                let lhs = put_input_in_reg(ctx, inputs[0]);
+                let rhs = input_to_reg_mem(ctx, inputs[1]);
+                let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
+
+                // Move the `lhs` to the same register as `dst`.
+                ctx.emit(Inst::gen_move(dst, lhs, ty));
+                ctx.emit(Inst::xmm_rm_r(sse_op, rhs, dst));
+            } else if ty == types::I128 || ty == types::B128 {
+                // Handle 128-bit multiplications.
+                let lhs = put_input_in_regs(ctx, inputs[0]);
+                let rhs = put_input_in_regs(ctx, inputs[1]);
+                let dst = get_output_reg(ctx, outputs[0]);
+                assert_eq!(lhs.len(), 2);
+                assert_eq!(rhs.len(), 2);
+                assert_eq!(dst.len(), 2);
+
+                // mul:
+                //   dst_lo = lhs_lo * rhs_lo
+                //   dst_hi = umulhi(lhs_lo, rhs_lo) + lhs_lo * rhs_hi + lhs_hi * rhs_lo
+                //
+                // so we emit:
+                //   mov dst_lo, lhs_lo
+                //   mul dst_lo, rhs_lo
+                //   mov dst_hi, lhs_lo
+                //   mul dst_hi, rhs_hi
+                //   mov tmp, lhs_hi
+                //   mul tmp, rhs_lo
+                //   add dst_hi, tmp
+                //   mov rax, lhs_lo
+                //   umulhi rhs_lo // implicit rax arg/dst
+                //   add dst_hi, rax
+                let tmp = ctx.alloc_tmp(types::I64).only_reg().unwrap();
+                ctx.emit(Inst::gen_move(dst.regs()[0], lhs.regs()[0], types::I64));
+                ctx.emit(Inst::alu_rmi_r(
+                    OperandSize::Size64,
+                    AluRmiROpcode::Mul,
+                    RegMemImm::reg(rhs.regs()[0]),
+                    dst.regs()[0],
+                ));
+                ctx.emit(Inst::gen_move(dst.regs()[1], lhs.regs()[0], types::I64));
+                ctx.emit(Inst::alu_rmi_r(
+                    OperandSize::Size64,
+                    AluRmiROpcode::Mul,
+                    RegMemImm::reg(rhs.regs()[1]),
+                    dst.regs()[1],
+                ));
+                ctx.emit(Inst::gen_move(tmp, lhs.regs()[1], types::I64));
+                ctx.emit(Inst::alu_rmi_r(
+                    OperandSize::Size64,
+                    AluRmiROpcode::Mul,
+                    RegMemImm::reg(rhs.regs()[0]),
+                    tmp,
+                ));
+                ctx.emit(Inst::alu_rmi_r(
+                    OperandSize::Size64,
+                    AluRmiROpcode::Add,
+                    RegMemImm::reg(tmp.to_reg()),
+                    dst.regs()[1],
+                ));
+                ctx.emit(Inst::gen_move(
+                    Writable::from_reg(regs::rax()),
+                    lhs.regs()[0],
+                    types::I64,
+                ));
+                ctx.emit(Inst::mul_hi(
+                    OperandSize::Size64,
+                    /* signed = */ false,
+                    RegMem::reg(rhs.regs()[0]),
+                ));
+                ctx.emit(Inst::alu_rmi_r(
+                    OperandSize::Size64,
+                    AluRmiROpcode::Add,
+                    RegMemImm::reg(regs::rdx()),
+                    dst.regs()[1],
+                ));
+            } else {
+                let size = if ty == types::I64 {
+                    OperandSize::Size64
+                } else {
+                    OperandSize::Size32
+                };
+                let alu_op = AluRmiROpcode::Mul;
+
+                // For commutative operations, try to commute operands if one is
+                // an immediate or direct memory reference. Do so by converting
+                // LHS to RMI; if reg, then always convert RHS to RMI; else, use
+                // LHS as RMI and convert RHS to reg.
+                let lhs = input_to_reg_mem_imm(ctx, inputs[0]);
+                let (lhs, rhs) = if let RegMemImm::Reg { reg: lhs_reg } = lhs {
+                    let rhs = input_to_reg_mem_imm(ctx, inputs[1]);
+                    (lhs_reg, rhs)
+                } else {
+                    let rhs_reg = put_input_in_reg(ctx, inputs[1]);
+                    (rhs_reg, lhs)
+                };
+
+                let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
+                ctx.emit(Inst::mov_r_r(OperandSize::Size64, lhs, dst));
+                ctx.emit(Inst::alu_rmi_r(size, alu_op, rhs, dst));
+            }
+        }
+
         Opcode::BandNot => {
             let ty = ty.unwrap();
             debug_assert!(ty.is_vector() && ty.bytes() == 16);
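
The comment in the new Opcode::Imul arm describes each I64X2 lane as a 32-bit high half "Ah"/"Bh" and low half "Al"/"Bl". The following standalone Rust program is a scalar sketch of that decomposition only (the function name and test values are illustrative, not Cranelift code): it models what the pmuludq/psrlq/psllq/paddq sequence computes per lane and checks it against an ordinary wrapping multiply.

// Scalar model of one I64X2 lane. All arithmetic wraps modulo 2^64, matching
// paddq/psllq, and each 32x32->64 product models a pmuludq on the low halves.
fn i64x2_lane_mul(a: u64, b: u64) -> u64 {
    let al = a & 0xffff_ffff; // low 32 bits (what pmuludq reads directly)
    let ah = a >> 32;         // high 32 bits (isolated with a psrlq by 32)
    let bl = b & 0xffff_ffff;
    let bh = b >> 32;

    // cross = (Ah * Bl) + (Al * Bh), accumulated with paddq (mod 2^64).
    let cross = ah.wrapping_mul(bl).wrapping_add(al.wrapping_mul(bh));

    // A * B = (Al * Bl) + ((Ah * Bl) + (Al * Bh)) << 32   (mod 2^64)
    al.wrapping_mul(bl).wrapping_add(cross << 32)
}

fn main() {
    let cases = [
        (3u64, 5u64),
        (u64::MAX, 2),
        (0x1234_5678_9abc_def0, 0xfedc_ba98_7654_3210),
    ];
    for &(a, b) in &cases {
        assert_eq!(i64x2_lane_mul(a, b), a.wrapping_mul(b));
    }
}

Only the low 32 bits of the cross term survive the final shift left by 32, which is why any wrap-around in the paddq that forms it cannot affect the result.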
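
Similarly, the 128-bit branch relies on the identity dst_hi = umulhi(lhs_lo, rhs_lo) + lhs_lo * rhs_hi + lhs_hi * rhs_lo (mod 2^64), with dst_lo = lhs_lo * rhs_lo. Below is a minimal sketch of that identity as a separate standalone program, modelling an i128 value as a (lo, hi) pair of u64s and using u128 only as a reference oracle; the helper names umulhi and mul128 are illustrative, not Cranelift code.

// `umulhi` stands in for the widening multiply (RDX:RAX) that Inst::mul_hi emits.
fn umulhi(a: u64, b: u64) -> u64 {
    ((a as u128 * b as u128) >> 64) as u64
}

// Multiply two 128-bit values given as (lo, hi) pairs, keeping the low 128 bits.
fn mul128(lhs: (u64, u64), rhs: (u64, u64)) -> (u64, u64) {
    let (lhs_lo, lhs_hi) = lhs;
    let (rhs_lo, rhs_hi) = rhs;
    let dst_lo = lhs_lo.wrapping_mul(rhs_lo);
    let dst_hi = umulhi(lhs_lo, rhs_lo)
        .wrapping_add(lhs_lo.wrapping_mul(rhs_hi))
        .wrapping_add(lhs_hi.wrapping_mul(rhs_lo));
    (dst_lo, dst_hi)
}

fn main() {
    let a: u128 = 0x0123_4567_89ab_cdef_0fed_cba9_8765_4321;
    let b: u128 = 0x00ff_00ff_00ff_00ff_1234_5678_9abc_def0;
    let (lo, hi) = mul128((a as u64, (a >> 64) as u64), (b as u64, (b >> 64) as u64));
    let expect = a.wrapping_mul(b);
    assert_eq!((lo, hi), (expect as u64, (expect >> 64) as u64));
}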