riscv64: Improve signed and zero extend codegen (#5844)
* riscv64: Remove unused code
* riscv64: Group extend rules
* riscv64: Remove more unused rules
* riscv64: Cleanup existing extension rules
* riscv64: Move the existing Extend rules to ISLE
* riscv64: Use `sext.w` when extending
* riscv64: Remove duplicate extend tests
* riscv64: Use `zbb` instructions when extending values
* riscv64: Use `zbkb` extensions when zero extending
* riscv64: Enable additional tests for extend i128
* riscv64: Fix formatting for `Inst::Extend`
* riscv64: Reverse register for pack
* riscv64: Misc Cleanups
* riscv64: Cleanup extend rules
This commit is contained in:
@@ -569,6 +569,11 @@
|
||||
(Clmul)
|
||||
(Clmulh)
|
||||
(Clmulr)
|
||||
|
||||
;; Zbkb: Bit-manipulation for Cryptography
|
||||
(Pack)
|
||||
(Packw)
|
||||
(Packh)
|
||||
))
|
||||
|
||||
|
||||
@@ -858,22 +863,6 @@
|
||||
(_ Unit (emit (MInst.AluRRImm12 op dst src (imm12_zero)))))
|
||||
dst))
|
||||
|
||||
;; Extend an integer to 64 bits if needed.
|
||||
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
|
||||
;;; for I8, I16, and I32 ...
|
||||
(rule -1
|
||||
(ext_int_if_need signed val ty)
|
||||
(gen_extend val signed (ty_bits ty) 64))
|
||||
;;; otherwise this is a I64 or I128
|
||||
;;; no need to extend.
|
||||
(rule
|
||||
(ext_int_if_need _ r $I64)
|
||||
r)
|
||||
(rule
|
||||
(ext_int_if_need _ r $I128)
|
||||
r)
|
||||
|
||||
|
||||
;; Helper to get the negation of an Imm12
|
||||
(decl neg_imm12 (Imm12) Imm12)
|
||||
(extern constructor neg_imm12 neg_imm12)
|
||||
@@ -1031,50 +1020,116 @@
|
||||
;; add low and high together.
|
||||
(result Reg (alu_add high low)))
|
||||
(value_regs result (load_u64_constant 0))))
|
||||
|
||||
(decl gen_extend (Reg bool u8 u8) Reg)
|
||||
(rule
|
||||
(gen_extend r is_signed from_bits to_bits)
|
||||
(let
|
||||
((tmp WritableReg (temp_writable_reg $I16))
|
||||
(_ Unit (emit (MInst.Extend tmp r is_signed from_bits to_bits))))
|
||||
tmp))
|
||||
|
||||
;; val is_signed from_bits to_bits
|
||||
(decl lower_extend (Reg bool u8 u8) ValueRegs)
|
||||
(rule -1
|
||||
(lower_extend r is_signed from_bits to_bits)
|
||||
(gen_extend r is_signed from_bits to_bits))
|
||||
|
||||
;;;; for I128 signed extend.
|
||||
(rule 1
|
||||
(lower_extend r $true 64 128)
|
||||
(let
|
||||
((tmp Reg (alu_rrr (AluOPRRR.Slt) r (zero_reg)))
|
||||
(high Reg (gen_extend tmp $true 1 64)))
|
||||
(value_regs (gen_move2 r $I64 $I64) high)))
|
||||
|
||||
(rule
|
||||
(lower_extend r $true from_bits 128)
|
||||
(let
|
||||
((tmp Reg (gen_extend r $true from_bits 64))
|
||||
(tmp2 Reg (alu_rrr (AluOPRRR.Slt) tmp (zero_reg)))
|
||||
(high Reg (gen_extend tmp2 $true 1 64)))
|
||||
(value_regs (gen_move2 tmp $I64 $I64) high)))
|
||||
|
||||
;; Extends an integer if it is smaller than 64 bits.
|
||||
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
|
||||
;;; For values smaller than 64 bits, we need to extend them to 64 bits
|
||||
(rule 0 (ext_int_if_need $true val (fits_in_32 (ty_int ty)))
|
||||
(sext val ty $I64))
|
||||
(rule 0 (ext_int_if_need $false val (fits_in_32 (ty_int ty)))
|
||||
(zext val ty $I64))
|
||||
;; If the value already fills one machine register or more (I64 or I128), we don't need to do anything
|
||||
(rule 1 (ext_int_if_need _ r $I64) r)
|
||||
(rule 2 (ext_int_if_need _ r $I128) r)
|
||||
|
||||
|
||||
;;;; for I128 unsigned extend.
|
||||
(rule 1
|
||||
(lower_extend r $false 64 128)
|
||||
(value_regs (gen_move2 r $I64 $I64) (load_u64_constant 0)))
|
||||
;; Performs a zero extension of the given value
|
||||
(decl zext (ValueRegs Type Type) ValueRegs)
|
||||
(rule (zext val from_ty to_ty) (extend val (ExtendOp.Zero) from_ty to_ty))
|
||||
|
||||
;; Performs a signed extension of the given value
|
||||
(decl sext (ValueRegs Type Type) ValueRegs)
|
||||
(rule (sext val from_ty to_ty) (extend val (ExtendOp.Signed) from_ty to_ty))
|
||||
|
||||
(type ExtendOp
|
||||
(enum
|
||||
(Zero)
|
||||
(Signed)))
|
||||
|
||||
;; Performs either a sign or zero extension of the given value
|
||||
(decl extend (ValueRegs ExtendOp Type Type) ValueRegs)
|
||||
|
||||
;;; Generic Rules Extending to I64
|
||||
(decl pure extend_shift_op (ExtendOp) AluOPRRI)
|
||||
(rule (extend_shift_op (ExtendOp.Zero)) (AluOPRRI.Srli))
|
||||
(rule (extend_shift_op (ExtendOp.Signed)) (AluOPRRI.Srai))
|
||||
|
||||
;; In the most generic case, we shift left and then shift right.
|
||||
;; The type of right shift is determined by the extend op.
|
||||
(rule 0 (extend val extend_op (fits_in_32 from_ty) (fits_in_64 to_ty))
|
||||
(let ((val Reg (value_regs_get val 0))
|
||||
(shift Imm12 (imm_from_bits (u64_sub 64 (ty_bits from_ty))))
|
||||
(left Reg (alu_rr_imm12 (AluOPRRI.Slli) val shift))
|
||||
(shift_op AluOPRRI (extend_shift_op extend_op))
|
||||
(right Reg (alu_rr_imm12 shift_op left shift)))
|
||||
right))
|
||||
|
||||
;; If we are zero extending a U8 we can use an `andi` instruction.
|
||||
(rule 1 (extend val (ExtendOp.Zero) $I8 (fits_in_64 to_ty))
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rr_imm12 (AluOPRRI.Andi) val (imm12_const 255))))
|
||||
|
||||
;; When signed extending from 32 to 64 bits we can use a
|
||||
;; `addiw val 0`. Also known as a `sext.w`
|
||||
(rule 1 (extend val (ExtendOp.Signed) $I32 $I64)
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rr_imm12 (AluOPRRI.Addiw) val (imm12_const 0))))
|
||||
|
||||
|
||||
;; No point in trying to use `packh` here to zero extend 8 bit values
|
||||
;; since we can just use `andi` instead which is part of the base ISA.
|
||||
|
||||
;; If we have the `zbkb` extension `packw` can be used to zero extend 16 bit values
|
||||
(rule 1 (extend val (ExtendOp.Zero) $I16 (fits_in_64 _))
|
||||
(if-let $true (has_zbkb))
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rrr (AluOPRRR.Packw) val (zero_reg))))
|
||||
|
||||
;; If we have the `zbkb` extension `pack` can be used to zero extend 32 bit registers
|
||||
(rule 1 (extend val (ExtendOp.Zero) $I32 $I64)
|
||||
(if-let $true (has_zbkb))
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rrr (AluOPRRR.Pack) val (zero_reg))))
|
||||
|
||||
|
||||
;; If we have the `zbb` extension we can use the dedicated `sext.b` instruction.
|
||||
(rule 1 (extend val (ExtendOp.Signed) $I8 (fits_in_64 _))
|
||||
(if-let $true (has_zbb))
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rr_imm12 (AluOPRRI.Sextb) val (imm12_const 0))))
|
||||
|
||||
;; If we have the `zbb` extension we can use the dedicated `sext.h` instruction.
|
||||
(rule 1 (extend val (ExtendOp.Signed) $I16 (fits_in_64 _))
|
||||
(if-let $true (has_zbb))
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rr_imm12 (AluOPRRI.Sexth) val (imm12_const 0))))
|
||||
|
||||
;; If we have the `zbb` extension we can use the dedicated `zext.h` instruction.
|
||||
(rule 2 (extend val (ExtendOp.Zero) $I16 (fits_in_64 _))
|
||||
(if-let $true (has_zbb))
|
||||
(let ((val Reg (value_regs_get val 0)))
|
||||
(alu_rr_imm12 (AluOPRRI.Zexth) val (imm12_const 0))))
|
||||
|
||||
;;; Signed rules extending to I128
|
||||
;; Extend the bottom part, and extract the sign bit from the bottom as the top
|
||||
(rule 2 (extend val (ExtendOp.Signed) (fits_in_64 from_ty) $I128)
|
||||
(let ((val Reg (value_regs_get val 0))
|
||||
(low Reg (extend val (ExtendOp.Signed) from_ty $I64))
|
||||
(high Reg (alu_rr_imm12 (AluOPRRI.Srai) low (imm12_const 63))))
|
||||
(value_regs low high)))
|
||||
|
||||
;;; Unsigned rules extending to I128
|
||||
;; Extend the bottom register to I64 and then just zero out the top half.
|
||||
(rule 3 (extend val (ExtendOp.Zero) (fits_in_64 from_ty) $I128)
|
||||
(let ((val Reg (value_regs_get val 0))
|
||||
(low Reg (extend val (ExtendOp.Zero) from_ty $I64))
|
||||
(high Reg (load_u64_constant 0)))
|
||||
(value_regs low high)))
|
||||
|
||||
;; Catch all rule for ignoring extensions of the same type.
|
||||
(rule 4 (extend val _ ty ty) val)
|
||||
|
||||
(rule
|
||||
(lower_extend r $false from_bits 128)
|
||||
(value_regs (gen_extend r $false from_bits 64) (load_u64_constant 0)))
|
||||
|
||||
;; extract the sign bit of integer.
|
||||
(decl ext_sign_bit (Type Reg) Reg)
|
||||
(extern constructor ext_sign_bit ext_sign_bit)
|
||||
|
||||
(decl lower_b128_binary (AluOPRRR ValueRegs ValueRegs) ValueRegs)
|
||||
(rule
|
||||
@@ -1795,50 +1850,6 @@
|
||||
(rule (lower_icmp cc x y ty)
|
||||
(gen_icmp cc (ext_int_if_need $false x ty) (ext_int_if_need $false y ty) ty))
|
||||
|
||||
(decl lower_icmp_over_flow (ValueRegs ValueRegs Type) Reg)
|
||||
|
||||
;;; for I8 I16 I32
|
||||
(rule 1
|
||||
(lower_icmp_over_flow x y ty)
|
||||
(let
|
||||
((tmp Reg (alu_sub (ext_int_if_need $true x ty) (ext_int_if_need $true y ty)))
|
||||
(tmp2 WritableReg (temp_writable_reg $I64))
|
||||
(_ Unit (emit (MInst.Extend tmp2 tmp $true (ty_bits ty) 64))))
|
||||
(gen_icmp (IntCC.NotEqual) (writable_reg_to_reg tmp2) tmp $I64)))
|
||||
|
||||
;;; $I64
|
||||
(rule 3
|
||||
(lower_icmp_over_flow x y $I64)
|
||||
(let
|
||||
((y_sign Reg (alu_rrr (AluOPRRR.Sgt) y (zero_reg)))
|
||||
(sub_result Reg (alu_sub x y))
|
||||
(tmp Reg (alu_rrr (AluOPRRR.Slt) sub_result x)))
|
||||
(gen_icmp (IntCC.NotEqual) y_sign tmp $I64)))
|
||||
|
||||
;;; $I128
|
||||
(rule 2
|
||||
(lower_icmp_over_flow x y $I128)
|
||||
(let
|
||||
( ;; x sign bit.
|
||||
(xs Reg (alu_rr_imm12 (AluOPRRI.Srli) (value_regs_get x 1) (imm12_const 63)))
|
||||
;; y sign bit.
|
||||
(ys Reg (alu_rr_imm12 (AluOPRRI.Srli) (value_regs_get y 1) (imm12_const 63)))
|
||||
;;
|
||||
(sub_result ValueRegs (i128_sub x y))
|
||||
;; result sign bit.
|
||||
(rs Reg (alu_rr_imm12 (AluOPRRI.Srli) (value_regs_get sub_result 1) (imm12_const 63)))
|
||||
|
||||
;;; xs && !ys && !rs
|
||||
;;; x is negative, y is non-negative, and the result is non-negative.
|
||||
;;; must overflow
|
||||
(tmp1 Reg (alu_and xs (alu_and (gen_bit_not ys) (gen_bit_not rs))))
|
||||
;;; !xs && ys && rs
|
||||
;;; x is non-negative, y is negative, and the result is negative.
|
||||
;;; overflow
|
||||
(tmp2 Reg (alu_and (gen_bit_not xs) (alu_and ys rs)))
|
||||
;;;tmp3
|
||||
(tmp3 Reg (alu_rrr (AluOPRRR.Or) tmp1 tmp2)))
|
||||
(gen_extend tmp3 $true 1 64)))
|
||||
|
||||
(decl i128_sub (ValueRegs ValueRegs) ValueRegs)
|
||||
(rule
|
||||
|
||||
@@ -746,6 +746,9 @@ impl AluOPRRR {
|
||||
Self::Sh3add => "sh3add",
|
||||
Self::Sh3adduw => "sh3add.uw",
|
||||
Self::Xnor => "xnor",
|
||||
Self::Pack => "pack",
|
||||
Self::Packw => "packw",
|
||||
Self::Packh => "packh",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -785,6 +788,7 @@ impl AluOPRRR {
|
||||
AluOPRRR::Remw => 0b110,
|
||||
AluOPRRR::Remuw => 0b111,
|
||||
|
||||
// Zbb
|
||||
AluOPRRR::Adduw => 0b000,
|
||||
AluOPRRR::Andn => 0b111,
|
||||
AluOPRRR::Bclr => 0b001,
|
||||
@@ -810,6 +814,11 @@ impl AluOPRRR {
|
||||
AluOPRRR::Sh3add => 0b110,
|
||||
AluOPRRR::Sh3adduw => 0b110,
|
||||
AluOPRRR::Xnor => 0b100,
|
||||
|
||||
// Zbkb
|
||||
AluOPRRR::Pack => 0b100,
|
||||
AluOPRRR::Packw => 0b100,
|
||||
AluOPRRR::Packh => 0b111,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -826,11 +835,16 @@ impl AluOPRRR {
|
||||
| AluOPRRR::Srl
|
||||
| AluOPRRR::Sra
|
||||
| AluOPRRR::Or
|
||||
| AluOPRRR::And => 0b0110011,
|
||||
| AluOPRRR::And
|
||||
| AluOPRRR::Pack
|
||||
| AluOPRRR::Packh => 0b0110011,
|
||||
|
||||
AluOPRRR::Addw | AluOPRRR::Subw | AluOPRRR::Sllw | AluOPRRR::Srlw | AluOPRRR::Sraw => {
|
||||
0b0111011
|
||||
}
|
||||
AluOPRRR::Addw
|
||||
| AluOPRRR::Subw
|
||||
| AluOPRRR::Sllw
|
||||
| AluOPRRR::Srlw
|
||||
| AluOPRRR::Sraw
|
||||
| AluOPRRR::Packw => 0b0111011,
|
||||
|
||||
AluOPRRR::Mul
|
||||
| AluOPRRR::Mulh
|
||||
@@ -937,6 +951,11 @@ impl AluOPRRR {
|
||||
AluOPRRR::Sh3add => 0b0010000,
|
||||
AluOPRRR::Sh3adduw => 0b0010000,
|
||||
AluOPRRR::Xnor => 0b0100000,
|
||||
|
||||
// Zbkb
|
||||
AluOPRRR::Pack => 0b0000100,
|
||||
AluOPRRR::Packw => 0b0000100,
|
||||
AluOPRRR::Packh => 0b0000100,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -514,6 +514,38 @@ fn test_riscv64_binemit() {
|
||||
0x400545b3,
|
||||
));
|
||||
|
||||
// Zbkb
|
||||
insns.push(TestUnit::new(
|
||||
Inst::AluRRR {
|
||||
alu_op: AluOPRRR::Pack,
|
||||
rd: writable_a1(),
|
||||
rs1: a0(),
|
||||
rs2: zero_reg(),
|
||||
},
|
||||
"pack a1,a0,zero",
|
||||
0x080545b3,
|
||||
));
|
||||
insns.push(TestUnit::new(
|
||||
Inst::AluRRR {
|
||||
alu_op: AluOPRRR::Packw,
|
||||
rd: writable_a1(),
|
||||
rs1: a0(),
|
||||
rs2: zero_reg(),
|
||||
},
|
||||
"packw a1,a0,zero",
|
||||
0x080545bb,
|
||||
));
|
||||
insns.push(TestUnit::new(
|
||||
Inst::AluRRR {
|
||||
alu_op: AluOPRRR::Packh,
|
||||
rd: writable_a1(),
|
||||
rs1: a0(),
|
||||
rs2: zero_reg(),
|
||||
},
|
||||
"packh a1,a0,zero",
|
||||
0x080575b3,
|
||||
));
|
||||
|
||||
//
|
||||
insns.push(TestUnit::new(
|
||||
Inst::AluRRR {
|
||||
|
||||
@@ -845,16 +845,6 @@ impl Inst {
|
||||
x
|
||||
};
|
||||
|
||||
fn format_extend_op(signed: bool, from_bits: u8, _to_bits: u8) -> String {
|
||||
let type_name = match from_bits {
|
||||
1 => "b1",
|
||||
8 => "b",
|
||||
16 => "h",
|
||||
32 => "w",
|
||||
_ => unreachable!("from_bits:{:?}", from_bits),
|
||||
};
|
||||
format!("{}ext.{}", if signed { "s" } else { "u" }, type_name)
|
||||
}
|
||||
fn format_frm(rounding_mode: Option<FRM>) -> String {
|
||||
if let Some(r) = rounding_mode {
|
||||
format!(",{}", r.to_static_str(),)
|
||||
@@ -1341,15 +1331,23 @@ impl Inst {
|
||||
} => {
|
||||
let rs_s = format_reg(rs, allocs);
|
||||
let rd = format_reg(rd.to_reg(), allocs);
|
||||
// check if it is a load constant.
|
||||
if alu_op == AluOPRRI::Addi && rs == zero_reg() {
|
||||
format!("li {},{}", rd, imm12.as_i16())
|
||||
} else if alu_op == AluOPRRI::Xori && imm12.as_i16() == -1 {
|
||||
format!("not {},{}", rd, rs_s)
|
||||
} else {
|
||||
if alu_op.option_funct12().is_some() {
|
||||
|
||||
// Some of these special cases are better known as
|
||||
// their pseudo-instruction version, so prefer printing those.
|
||||
match (alu_op, rs, imm12) {
|
||||
(AluOPRRI::Addi, rs, _) if rs == zero_reg() => {
|
||||
return format!("li {},{}", rd, imm12.as_i16());
|
||||
}
|
||||
(AluOPRRI::Addiw, _, imm12) if imm12.as_i16() == 0 => {
|
||||
return format!("sext.w {},{}", rd, rs_s);
|
||||
}
|
||||
(AluOPRRI::Xori, _, imm12) if imm12.as_i16() == -1 => {
|
||||
return format!("not {},{}", rd, rs_s);
|
||||
}
|
||||
(alu_op, _, _) if alu_op.option_funct12().is_some() => {
|
||||
format!("{} {},{}", alu_op.op_name(), rd, rs_s)
|
||||
} else {
|
||||
}
|
||||
(alu_op, _, imm12) => {
|
||||
format!("{} {},{},{}", alu_op.op_name(), rd, rs_s, imm12.as_i16())
|
||||
}
|
||||
}
|
||||
@@ -1402,16 +1400,17 @@ impl Inst {
|
||||
rn,
|
||||
signed,
|
||||
from_bits,
|
||||
to_bits,
|
||||
..
|
||||
} => {
|
||||
let rn = format_reg(rn, allocs);
|
||||
let rm = format_reg(rd.to_reg(), allocs);
|
||||
format!(
|
||||
"{} {},{}",
|
||||
format_extend_op(signed, from_bits, to_bits),
|
||||
rm,
|
||||
rn
|
||||
)
|
||||
let rd = format_reg(rd.to_reg(), allocs);
|
||||
return if signed == false && from_bits == 8 {
|
||||
format!("andi {rd},{rn}")
|
||||
} else {
|
||||
let op = if signed { "srai" } else { "srli" };
|
||||
let shift_bits = (64 - from_bits) as i16;
|
||||
format!("slli {rd},{rn},{shift_bits}; {op} {rd},{rd},{shift_bits}")
|
||||
};
|
||||
}
|
||||
&MInst::AjustSp { amount } => {
|
||||
format!("{} sp,{:+}", "add", amount)
|
||||
|
||||
@@ -328,12 +328,12 @@
|
||||
(lower_clz_i128 x))
|
||||
|
||||
;;;; Rules for `uextend` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
|
||||
(rule (lower (has_type out (uextend x @ (value_type in))))
|
||||
(lower_extend x $false (ty_bits in) (ty_bits out)))
|
||||
(rule (lower (has_type out_ty (uextend val @ (value_type in_ty))))
|
||||
(zext val in_ty out_ty))
|
||||
|
||||
;;;; Rules for `sextend` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
|
||||
(rule (lower (has_type out (sextend x @ (value_type in))))
|
||||
(lower_extend x $true (ty_bits in) (ty_bits out)))
|
||||
(rule (lower (has_type out_ty (sextend val @ (value_type in_ty))))
|
||||
(sext val in_ty out_ty))
|
||||
|
||||
|
||||
;;;; Rules for `popcnt` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
|
||||
|
||||
@@ -229,17 +229,6 @@ impl generated_code::Context for IsleContext<'_, '_, MInst, Riscv64Backend> {
|
||||
x as i32
|
||||
}
|
||||
|
||||
fn ext_sign_bit(&mut self, ty: Type, r: Reg) -> Reg {
|
||||
assert!(ty.is_int());
|
||||
let rd = self.temp_writable_reg(I64);
|
||||
self.emit(&MInst::AluRRImm12 {
|
||||
alu_op: AluOPRRI::Bexti,
|
||||
rd,
|
||||
rs: r,
|
||||
imm12: Imm12::from_bits((ty.bits() - 1) as i16),
|
||||
});
|
||||
rd.to_reg()
|
||||
}
|
||||
fn imm12_const(&mut self, val: i32) -> Imm12 {
|
||||
if let Some(res) = Imm12::maybe_from_u64(val as u64) {
|
||||
res
|
||||
|
||||
Reference in New Issue
Block a user