riscv64: Improve ctz/clz/cls codegen (#5854)
* cranelift: Add extra runtests for `clz`/`ctz` * riscv64: Restrict lowering rules for `ctz`/`clz` * cranelift: Add `u64` isle helpers * riscv64: Improve `ctz` codegen * riscv64: Improve `clz` codegen * riscv64: Improve `cls` codegen * riscv64: Improve `clz.i128` codegen Instead of checking if we have 64 zeros in the top half. Check if it *is* 0, that way we avoid loading the `64` constant. * riscv64: Improve `ctz.i128` codegen Instead of checking if we have 64 zeros in the bottom half. Check if it *is* 0, that way we avoid loading the `64` constant. * riscv64: Use extended value in `lower_cls` * riscv64: Use pattern matches on `bseti`
This commit is contained in:
@@ -808,6 +808,28 @@
|
||||
(decl imm12_from_u64 (Imm12) u64)
|
||||
(extern extractor imm12_from_u64 imm12_from_u64)
|
||||
|
||||
|
||||
;;;; Instruction Helpers ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
|
||||
|
||||
|
||||
;; bseti: Set a single bit in a register, indexed by a constant.
|
||||
(decl bseti (Reg u64) Reg)
|
||||
(rule (bseti val bit)
|
||||
(if-let $false (has_zbs))
|
||||
(if-let $false (u64_le bit 12))
|
||||
(let ((const Reg (load_u64_constant (u64_shl 1 bit))))
|
||||
(alu_rrr (AluOPRRR.Or) val const)))
|
||||
|
||||
(rule (bseti val bit)
|
||||
(if-let $false (has_zbs))
|
||||
(if-let $true (u64_le bit 12))
|
||||
(alu_rr_imm12 (AluOPRRI.Ori) val (imm12_const (u64_as_i32 (u64_shl 1 bit)))))
|
||||
|
||||
(rule (bseti val bit)
|
||||
(if-let $true (has_zbs))
|
||||
(alu_rr_imm12 (AluOPRRI.Bseti) val (imm12_const (u64_as_i32 bit))))
|
||||
|
||||
|
||||
;; Float Helpers
|
||||
|
||||
(decl gen_default_frm () OptionFloatRoundingMode)
|
||||
@@ -948,89 +970,103 @@
|
||||
|
||||
|
||||
(decl lower_ctz (Type Reg) Reg)
|
||||
(rule
|
||||
(lower_ctz ty x)
|
||||
(if-let $false (has_zbb))
|
||||
(rule (lower_ctz ty x)
|
||||
(gen_cltz $false x ty))
|
||||
|
||||
(rule 2
|
||||
(lower_ctz $I64 x)
|
||||
(rule 1 (lower_ctz (fits_in_16 ty) x)
|
||||
(if-let $true (has_zbb))
|
||||
(alu_rr_funct12 (AluOPRRI.Ctz) x))
|
||||
(let ((tmp Reg (bseti x (ty_bits ty))))
|
||||
(alu_rr_funct12 (AluOPRRI.Ctzw) tmp)))
|
||||
|
||||
(rule 2
|
||||
(lower_ctz $I32 x)
|
||||
(rule 2 (lower_ctz $I32 x)
|
||||
(if-let $true (has_zbb))
|
||||
(alu_rr_funct12 (AluOPRRI.Ctzw) x))
|
||||
|
||||
;;;; for I8 and I16
|
||||
(rule 1
|
||||
(lower_ctz ty x)
|
||||
(rule 2 (lower_ctz $I64 x)
|
||||
(if-let $true (has_zbb))
|
||||
(if-let $true (has_zbs))
|
||||
(let
|
||||
((tmp Reg (alu_rr_imm12 (AluOPRRI.Bseti) x (imm12_const (ty_bits ty)))))
|
||||
(alu_rr_funct12 (AluOPRRI.Ctzw) x)))
|
||||
(alu_rr_funct12 (AluOPRRI.Ctz) x))
|
||||
|
||||
;;;;
|
||||
;; Count trailing zeros from a i128 bit value.
|
||||
;; We count both halves separately and conditionally add them if it makes sense.
|
||||
(decl lower_ctz_128 (ValueRegs) ValueRegs)
|
||||
(rule
|
||||
(lower_ctz_128 x)
|
||||
(let
|
||||
(;; count the low part.
|
||||
(low Reg (lower_ctz $I64 (value_regs_get x 0)))
|
||||
;; count the high part.
|
||||
(high_part Reg (lower_ctz $I64 (value_regs_get x 1)))
|
||||
;;;
|
||||
(constant_64 Reg (load_u64_constant 64))
|
||||
;;;
|
||||
(high Reg (gen_select_reg (IntCC.Equal) constant_64 low high_part (zero_reg)))
|
||||
(rule (lower_ctz_128 x)
|
||||
(let ((x_lo Reg (value_regs_get x 0))
|
||||
(x_hi Reg (value_regs_get x 1))
|
||||
;; Count both halves
|
||||
(high Reg (lower_ctz $I64 x_hi))
|
||||
(low Reg (lower_ctz $I64 x_lo))
|
||||
;; Only add the top half if the bottom is zero
|
||||
(high Reg (gen_select_reg (IntCC.Equal) x_lo (zero_reg) high (zero_reg)))
|
||||
(result Reg (alu_add low high)))
|
||||
(zext result $I64 $I128)))
|
||||
|
||||
|
||||
;; add low and high together.
|
||||
(result Reg (alu_add low high)))
|
||||
(value_regs result (load_u64_constant 0))))
|
||||
|
||||
(decl lower_clz (Type Reg) Reg)
|
||||
(rule
|
||||
(lower_clz ty rs)
|
||||
(if-let $false (has_zbb))
|
||||
(rule (lower_clz ty rs)
|
||||
(gen_cltz $true rs ty))
|
||||
(rule 2
|
||||
(lower_clz $I64 r)
|
||||
|
||||
(rule 1 (lower_clz (fits_in_16 ty) r)
|
||||
(if-let $true (has_zbb))
|
||||
(alu_rr_funct12 (AluOPRRI.Clz) r))
|
||||
(rule 2
|
||||
(lower_clz $I32 r)
|
||||
(let ((tmp Reg (zext r ty $I64))
|
||||
(count Reg (alu_rr_funct12 (AluOPRRI.Clz) tmp))
|
||||
;; We always do the operation on the full 64-bit register, so subtract 64 from the result.
|
||||
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const_add (ty_bits ty) -64))))
|
||||
result))
|
||||
|
||||
(rule 2 (lower_clz $I32 r)
|
||||
(if-let $true (has_zbb))
|
||||
(alu_rr_funct12 (AluOPRRI.Clzw) r))
|
||||
|
||||
;;; for I8 and I16
|
||||
(rule 1
|
||||
(lower_clz ty r)
|
||||
(rule 2 (lower_clz $I64 r)
|
||||
(if-let $true (has_zbb))
|
||||
(let
|
||||
( ;; narrow int make all upper bits are zeros.
|
||||
(tmp Reg (ext_int_if_need $false r ty ))
|
||||
;;
|
||||
(count Reg (alu_rr_funct12 (AluOPRRI.Clz) tmp))
|
||||
;;make result
|
||||
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const_add (ty_bits ty) -64))))
|
||||
result))
|
||||
(alu_rr_funct12 (AluOPRRI.Clz) r))
|
||||
|
||||
;; Count leading zeros from a i128 bit value.
|
||||
;; We count both halves separately and conditionally add them if it makes sense.
|
||||
(decl lower_clz_i128 (ValueRegs) ValueRegs)
|
||||
(rule
|
||||
(lower_clz_i128 x)
|
||||
(let
|
||||
( ;; count high part.
|
||||
(high Reg (lower_clz $I64 (value_regs_get x 1)))
|
||||
;; coumt low part.
|
||||
(low_part Reg (lower_clz $I64 (value_regs_get x 0)))
|
||||
;;; load constant 64.
|
||||
(constant_64 Reg (load_u64_constant 64))
|
||||
(low Reg (gen_select_reg (IntCC.Equal) constant_64 high low_part (zero_reg)))
|
||||
;; add low and high together.
|
||||
(result Reg (alu_add high low)))
|
||||
(value_regs result (load_u64_constant 0))))
|
||||
(rule (lower_clz_i128 x)
|
||||
(let ((x_lo Reg (value_regs_get x 0))
|
||||
(x_hi Reg (value_regs_get x 1))
|
||||
;; Count both halves
|
||||
(high Reg (lower_clz $I64 x_hi))
|
||||
(low Reg (lower_clz $I64 x_lo))
|
||||
;; Only add the bottom zeros if the top half is zero
|
||||
(low Reg (gen_select_reg (IntCC.Equal) x_hi (zero_reg) low (zero_reg)))
|
||||
(result Reg (alu_add high low)))
|
||||
(zext result $I64 $I128)))
|
||||
|
||||
|
||||
(decl lower_cls (Type Reg) Reg)
|
||||
(rule (lower_cls ty r)
|
||||
(let ((tmp Reg (ext_int_if_need $true r ty))
|
||||
(tmp2 Reg (gen_select_reg (IntCC.SignedLessThan) tmp (zero_reg) (gen_bit_not tmp) tmp))
|
||||
(tmp3 Reg (lower_clz ty tmp2)))
|
||||
(alu_rr_imm12 (AluOPRRI.Addi) tmp3 (imm12_const -1))))
|
||||
|
||||
;; If the sign bit is set, we count the leading zeros of the inverted value.
|
||||
;; Otherwise we can just count the leading zeros of the original value.
|
||||
;; Subtract 1 since the sign bit does not count.
|
||||
(decl lower_cls_i128 (ValueRegs) ValueRegs)
|
||||
(rule (lower_cls_i128 x)
|
||||
(let ((low Reg (value_regs_get x 0))
|
||||
(high Reg (value_regs_get x 1))
|
||||
(low Reg (gen_select_reg (IntCC.SignedLessThan) high (zero_reg) (gen_bit_not low) low))
|
||||
(high Reg (gen_select_reg (IntCC.SignedLessThan) high (zero_reg) (gen_bit_not high) high))
|
||||
(tmp ValueRegs (lower_clz_i128 (value_regs low high)))
|
||||
(count Reg (value_regs_get tmp 0))
|
||||
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const -1))))
|
||||
(zext result $I64 $I128)))
|
||||
|
||||
|
||||
(decl gen_cltz (bool Reg Type) Reg)
|
||||
(rule (gen_cltz leading rs ty)
|
||||
(let ((tmp WritableReg (temp_writable_reg $I64))
|
||||
(step WritableReg (temp_writable_reg $I64))
|
||||
(sum WritableReg (temp_writable_reg $I64))
|
||||
(_ Unit (emit (MInst.Cltz leading sum step tmp rs ty))))
|
||||
sum))
|
||||
|
||||
|
||||
;; Extends an integer if it is smaller than 64 bits.
|
||||
(decl ext_int_if_need (bool ValueRegs Type) ValueRegs)
|
||||
@@ -1267,27 +1303,6 @@
|
||||
(part3 Reg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) part2)))
|
||||
(alu_rrr (AluOPRRR.Or) part1 part3)))
|
||||
|
||||
(decl lower_cls (Reg Type) Reg)
|
||||
(rule
|
||||
(lower_cls r ty)
|
||||
(let
|
||||
( ;; extract sign bit.
|
||||
(tmp Reg (ext_int_if_need $true r ty))
|
||||
;;
|
||||
(tmp2 Reg (gen_select_reg (IntCC.SignedLessThan) tmp (zero_reg) (gen_bit_not r) r))
|
||||
;;
|
||||
(tmp3 Reg (lower_clz ty tmp2)))
|
||||
(alu_rr_imm12 (AluOPRRI.Addi) tmp3 (imm12_const -1))))
|
||||
|
||||
(decl gen_cltz (bool Reg Type) Reg)
|
||||
(rule
|
||||
(gen_cltz leading rs ty)
|
||||
(let
|
||||
((tmp WritableReg (temp_writable_reg $I64))
|
||||
(step WritableReg (temp_writable_reg $I64))
|
||||
(sum WritableReg (temp_writable_reg $I64))
|
||||
(_ Unit (emit (MInst.Cltz leading sum step tmp rs ty))))
|
||||
(writable_reg_to_reg sum)))
|
||||
|
||||
(decl gen_popcnt (Reg Type) Reg)
|
||||
(rule
|
||||
@@ -1454,24 +1469,6 @@
|
||||
(gen_select_reg (IntCC.UnsignedGreaterThanOrEqual) shamt_128 const64 high_replacement high))))
|
||||
|
||||
|
||||
(decl lower_cls_i128 (ValueRegs) ValueRegs)
|
||||
(rule
|
||||
(lower_cls_i128 x)
|
||||
(let
|
||||
( ;;; we use clz to implement cls
|
||||
;;; if value is negtive we need inverse all bits.
|
||||
(low Reg
|
||||
(gen_select_reg (IntCC.SignedLessThan) (value_regs_get x 1) (zero_reg) (gen_bit_not (value_regs_get x 0)) (value_regs_get x 0)))
|
||||
;;;
|
||||
(high Reg
|
||||
(gen_select_reg (IntCC.SignedLessThan) (value_regs_get x 1) (zero_reg) (gen_bit_not (value_regs_get x 1)) (value_regs_get x 1)))
|
||||
;; count leading zeros.
|
||||
(tmp ValueRegs (lower_clz_i128 (value_regs low high)))
|
||||
(count Reg (value_regs_get tmp 0))
|
||||
(result Reg (alu_rr_imm12 (AluOPRRI.Addi) count (imm12_const -1))))
|
||||
(value_regs result (load_u64_constant 0))))
|
||||
|
||||
|
||||
(decl gen_amode (Reg Offset32 Type) AMode)
|
||||
(extern constructor gen_amode gen_amode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user