diff --git a/cranelift/codegen/src/isa/aarch64/lower_inst.rs b/cranelift/codegen/src/isa/aarch64/lower_inst.rs index dbabc9d58c..5e2fed9064 100644 --- a/cranelift/codegen/src/isa/aarch64/lower_inst.rs +++ b/cranelift/codegen/src/isa/aarch64/lower_inst.rs @@ -74,23 +74,51 @@ pub(crate) fn lower_insn_to_regs>( } Opcode::Iadd => { let rd = get_output_reg(ctx, outputs[0]); - let rn = put_input_in_reg(ctx, inputs[0], NarrowValueMode::None); let ty = ty.unwrap(); if !ty.is_vector() { - let (rm, negated) = put_input_in_rse_imm12_maybe_negated( - ctx, - inputs[1], - ty_bits(ty), - NarrowValueMode::None, - ); - let alu_op = if !negated { - choose_32_64(ty, ALUOp::Add32, ALUOp::Add64) + let mul_insn = + if let Some(mul_insn) = maybe_input_insn(ctx, inputs[1], Opcode::Imul) { + Some((mul_insn, 0)) + } else if let Some(mul_insn) = maybe_input_insn(ctx, inputs[0], Opcode::Imul) { + Some((mul_insn, 1)) + } else { + None + }; + // If possible combine mul + add into madd. + if let Some((insn, addend_idx)) = mul_insn { + let alu_op = choose_32_64(ty, ALUOp3::MAdd32, ALUOp3::MAdd64); + let rn_input = InsnInput { insn, input: 0 }; + let rm_input = InsnInput { insn, input: 1 }; + + let rn = put_input_in_reg(ctx, rn_input, NarrowValueMode::None); + let rm = put_input_in_reg(ctx, rm_input, NarrowValueMode::None); + let ra = put_input_in_reg(ctx, inputs[addend_idx], NarrowValueMode::None); + + ctx.emit(Inst::AluRRRR { + alu_op, + rd, + rn, + rm, + ra, + }); } else { - choose_32_64(ty, ALUOp::Sub32, ALUOp::Sub64) - }; - ctx.emit(alu_inst_imm12(alu_op, rd, rn, rm)); + let rn = put_input_in_reg(ctx, inputs[0], NarrowValueMode::None); + let (rm, negated) = put_input_in_rse_imm12_maybe_negated( + ctx, + inputs[1], + ty_bits(ty), + NarrowValueMode::None, + ); + let alu_op = if !negated { + choose_32_64(ty, ALUOp::Add32, ALUOp::Add64) + } else { + choose_32_64(ty, ALUOp::Sub32, ALUOp::Sub64) + }; + ctx.emit(alu_inst_imm12(alu_op, rd, rn, rm)); + } } else { let rm = put_input_in_reg(ctx, inputs[1], NarrowValueMode::None); + let rn = put_input_in_reg(ctx, inputs[0], NarrowValueMode::None); ctx.emit(Inst::VecRRR { rd, rn,