riscv64: Improve ctz/clz/cls codegen (#5854)

* cranelift: Add extra runtests for `clz`/`ctz`

* riscv64: Restrict lowering rules for `ctz`/`clz`

* cranelift: Add `u64` isle helpers

* riscv64: Improve `ctz` codegen

* riscv64: Improve `clz` codegen

* riscv64: Improve `cls` codegen

* riscv64: Improve `clz.i128` codegen

Instead of checking if we have 64 zeros in the top half. Check
if it *is* 0, that way we avoid loading the `64` constant.

* riscv64: Improve `ctz.i128` codegen

Instead of checking if we have 64 zeros in the bottom half. Check
if it *is* 0, that way we avoid loading the `64` constant.

* riscv64: Use extended value in `lower_cls`

* riscv64: Use pattern matches on `bseti`
This commit is contained in:
Afonso Bordado
2023-03-21 23:15:14 +00:00
committed by GitHub
parent ff6f17ca52
commit 7a3df7dcc0
14 changed files with 617 additions and 167 deletions

View File

@@ -808,6 +808,28 @@
(decl imm12_from_u64 (Imm12) u64)
(extern extractor imm12_from_u64 imm12_from_u64)
;;;; Instruction Helpers ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; bseti: Set a single bit in a register, indexed by a constant.
(decl bseti (Reg u64) Reg)
(rule (bseti val bit)
(if-let $false (has_zbs))
(if-let $false (u64_le bit 12))
(let ((const Reg (load_u64_constant (u64_shl 1 bit))))
(alu_rrr (AluOPRRR.Or) val const)))
(rule (bseti val bit)
(if-let $false (has_zbs))
(if-let $true (u64_le bit 12))
(alu_rr_imm12 (AluOPRRI.Ori) val (imm12_const (u64_as_i32 (u64_shl 1 bit)))))
(rule (bseti val bit)
(if-let $true (has_zbs))
(alu_rr_imm12 (AluOPRRI.Bseti) val (imm12_const (u64_as_i32 bit))))
;; Float Helpers
(decl gen_default_frm () OptionFloatRoundingMode)
@@ -948,89 +970,103 @@
(decl lower_ctz (Type Reg) Reg)
(rule
(lower_ctz ty x)
(if-let $false (has_zbb))
(rule (lower_ctz ty x)
(gen_cltz $false x ty))
(rule 2
(lower_ctz $I64 x)
(rule 1 (lower_ctz (fits_in_16 ty) x)
(if-let $true (has_zbb))
(alu_rr_funct12 (AluOPRRI.Ctz) x))
(let ((tmp Reg (bseti x (ty_bits ty))))
(alu_rr_funct12 (AluOPRRI.Ctzw) tmp)))
(rule 2
(lower_ctz $I32 x)
(rule 2 (lower_ctz $I32 x)
(if-let $true (has_zbb))
(alu_rr_funct12 (AluOPRRI.Ctzw) x))
;;;; for I8 and I16
(rule 1
(lower_ctz ty x)
(rule 2 (lower_ctz $I64 x)
(if-let $true (has_zbb))
(if-let $true (has_zbs))
(let
((tmp Reg (alu_rr_imm12 (AluOPRRI.Bseti) x (imm12_const (ty_bits ty)))))
(alu_rr_funct12 (AluOPRRI.Ctzw) x)))
(alu_rr_funct12 (AluOPRRI.Ctz) x))
;;;;
;; Count trailing zeros from a i128 bit value.
;; We count both halves separately and conditionally add them if it makes sense.
(decl lower_ctz_128 (ValueRegs) ValueRegs)
(rule
(lower_ctz_128 x)
(let
(;; count the low part.
(low Reg (lower_ctz $I64 (value_regs_get x 0)))
;; count the high part.
(high_part Reg (lower_ctz $I64 (value_regs_get x 1)))
;;;
(constant_64 Reg (load_u64_constant 64))
;;;
(high Reg (gen_select_reg (IntCC.Equal) constant_64 low high_part (zero_reg)))
(rule (lower_ctz_128 x)
(let ((x_lo Reg (value_regs_get x 0))
(x_hi Reg (value_regs_get x 1))
;; Count both halves
(high Reg (lower_ctz $I64 x_hi))
(low Reg (lower_ctz $I64 x_lo))
;; Only add the top half if the bottom is zero
(high Reg (gen_select_reg (IntCC.Equal) x_lo (zero_reg) high (zero_reg)))
(result Reg (alu_add low high)))
(zext result $I64 $I128)))
;; add low and high together.
(result Reg (alu_add low high)))
(value_regs result (load_u64_constant 0))))
(decl lower_clz (Type Reg) Reg)
(rule
(lower_clz ty rs)
(if-let $false (has_zbb))
(rule (lower_clz ty rs)
(gen_cltz $true rs ty))
(rule 2
(lower_clz $I64 r)
(rule 1 (lower_clz (fits_in_16 ty) r)
(if-let $true (has_zbb))
(alu_rr_funct12 (AluOPRRI.Clz) r))
(rule 2
(lower_clz $I32 r)
(let ((tmp Reg (zext r ty $I64))
(count Reg (alu_rr_funct12 (AluOPRRI.Clz) tmp))
;; We always do the operation on the full 64-bit register, so subtract 64 from the result.
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const_add (ty_bits ty) -64))))
result))
(rule 2 (lower_clz $I32 r)
(if-let $true (has_zbb))
(alu_rr_funct12 (AluOPRRI.Clzw) r))
;;; for I8 and I16
(rule 1
(lower_clz ty r)
(rule 2 (lower_clz $I64 r)
(if-let $true (has_zbb))
(let
( ;; narrow int make all upper bits are zeros.
(tmp Reg (ext_int_if_need $false r ty ))
;;
(count Reg (alu_rr_funct12 (AluOPRRI.Clz) tmp))
;;make result
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const_add (ty_bits ty) -64))))
result))
(alu_rr_funct12 (AluOPRRI.Clz) r))
;; Count leading zeros from a i128 bit value.
;; We count both halves separately and conditionally add them if it makes sense.
(decl lower_clz_i128 (ValueRegs) ValueRegs)
(rule
(lower_clz_i128 x)
(let
( ;; count high part.
(high Reg (lower_clz $I64 (value_regs_get x 1)))
;; coumt low part.
(low_part Reg (lower_clz $I64 (value_regs_get x 0)))
;;; load constant 64.
(constant_64 Reg (load_u64_constant 64))
(low Reg (gen_select_reg (IntCC.Equal) constant_64 high low_part (zero_reg)))
;; add low and high together.
(result Reg (alu_add high low)))
(value_regs result (load_u64_constant 0))))
(rule (lower_clz_i128 x)
(let ((x_lo Reg (value_regs_get x 0))
(x_hi Reg (value_regs_get x 1))
;; Count both halves
(high Reg (lower_clz $I64 x_hi))
(low Reg (lower_clz $I64 x_lo))
;; Only add the bottom zeros if the top half is zero
(low Reg (gen_select_reg (IntCC.Equal) x_hi (zero_reg) low (zero_reg)))
(result Reg (alu_add high low)))
(zext result $I64 $I128)))
(decl lower_cls (Type Reg) Reg)
(rule (lower_cls ty r)
(let ((tmp Reg (ext_int_if_need $true r ty))
(tmp2 Reg (gen_select_reg (IntCC.SignedLessThan) tmp (zero_reg) (gen_bit_not tmp) tmp))
(tmp3 Reg (lower_clz ty tmp2)))
(alu_rr_imm12 (AluOPRRI.Addi) tmp3 (imm12_const -1))))
;; If the sign bit is set, we count the leading zeros of the inverted value.
;; Otherwise we can just count the leading zeros of the original value.
;; Subtract 1 since the sign bit does not count.
(decl lower_cls_i128 (ValueRegs) ValueRegs)
(rule (lower_cls_i128 x)
(let ((low Reg (value_regs_get x 0))
(high Reg (value_regs_get x 1))
(low Reg (gen_select_reg (IntCC.SignedLessThan) high (zero_reg) (gen_bit_not low) low))
(high Reg (gen_select_reg (IntCC.SignedLessThan) high (zero_reg) (gen_bit_not high) high))
(tmp ValueRegs (lower_clz_i128 (value_regs low high)))
(count Reg (value_regs_get tmp 0))
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const -1))))
(zext result $I64 $I128)))
(decl gen_cltz (bool Reg Type) Reg)
(rule (gen_cltz leading rs ty)
(let ((tmp WritableReg (temp_writable_reg $I64))
(step WritableReg (temp_writable_reg $I64))
(sum WritableReg (temp_writable_reg $I64))
(_ Unit (emit (MInst.Cltz leading sum step tmp rs ty))))
sum))
;; Extends an integer if it is smaller than 64 bits.
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
@@ -1267,27 +1303,6 @@
(part3 Reg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) part2)))
(alu_rrr (AluOPRRR.Or) part1 part3)))
(decl lower_cls (Reg Type) Reg)
(rule
(lower_cls r ty)
(let
( ;; extract sign bit.
(tmp Reg (ext_int_if_need $true r ty))
;;
(tmp2 Reg (gen_select_reg (IntCC.SignedLessThan) tmp (zero_reg) (gen_bit_not r) r))
;;
(tmp3 Reg (lower_clz ty tmp2)))
(alu_rr_imm12 (AluOPRRI.Addi) tmp3 (imm12_const -1))))
(decl gen_cltz (bool Reg Type) Reg)
(rule
(gen_cltz leading rs ty)
(let
((tmp WritableReg (temp_writable_reg $I64))
(step WritableReg (temp_writable_reg $I64))
(sum WritableReg (temp_writable_reg $I64))
(_ Unit (emit (MInst.Cltz leading sum step tmp rs ty))))
(writable_reg_to_reg sum)))
(decl gen_popcnt (Reg Type) Reg)
(rule
@@ -1454,24 +1469,6 @@
(gen_select_reg (IntCC.UnsignedGreaterThanOrEqual) shamt_128 const64 high_replacement high))))
(decl lower_cls_i128 (ValueRegs) ValueRegs)
(rule
(lower_cls_i128 x)
(let
( ;;; we use clz to implement cls
;;; if value is negtive we need inverse all bits.
(low Reg
(gen_select_reg (IntCC.SignedLessThan) (value_regs_get x 1) (zero_reg) (gen_bit_not (value_regs_get x 0)) (value_regs_get x 0)))
;;;
(high Reg
(gen_select_reg (IntCC.SignedLessThan) (value_regs_get x 1) (zero_reg) (gen_bit_not (value_regs_get x 1)) (value_regs_get x 1)))
;; count leading zeros.
(tmp ValueRegs (lower_clz_i128 (value_regs low high)))
(count Reg (value_regs_get tmp 0))
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const -1))))
(value_regs result (load_u64_constant 0))))
(decl gen_amode (Reg Offset32 Type) AMode)
(extern constructor gen_amode gen_amode)

View File

@@ -327,14 +327,14 @@
;;;; Rules for `ctz` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (ctz x)))
(rule (lower (has_type (fits_in_64 ty) (ctz x)))
(lower_ctz ty x))
(rule 1 (lower (has_type $I128 (ctz x)))
(lower_ctz_128 x))
;;;; Rules for `clz` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (clz x)))
(rule (lower (has_type (fits_in_64 ty) (clz x)))
(lower_clz ty x))
(rule 1 (lower (has_type $I128 (clz x)))
@@ -342,7 +342,7 @@
;;;; Rules for `cls` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type (fits_in_64 ty) (cls x)))
(lower_cls x ty))
(lower_cls ty x))
(rule 1 (lower (has_type $I128 (cls x)))
(lower_cls_i128 x))

View File

@@ -38,6 +38,11 @@ macro_rules! isle_common_prelude_methods {
x as u64
}
#[inline]
fn u64_as_i32(&mut self, x: u64) -> i32 {
x as i32
}
#[inline]
fn i64_neg(&mut self, x: i64) -> i64 {
x.wrapping_neg()

View File

@@ -102,6 +102,9 @@
(decl u64_as_u32 (u32) u64)
(extern extractor u64_as_u32 u64_as_u32)
(decl pure u64_as_i32 (u64) i32)
(extern constructor u64_as_i32 u64_as_i32)
;;;; Primitive Arithmetic ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(decl pure u8_and (u8 u8) u8)