cranelift: Mask high bits on bmask for types smaller than a register (#5118)

* aarch64: Fix incorrect masking for small types on bmask

`bmask` was accidentally relying on the uppermost bits of the register
for small types.

This was found by fuzzgen: it generated a shift left followed by
a bmask, and the shift left moved bits out of the range of the input
type (i8). Those bits are not automatically cleared, since they
remain inside the 32 bits of the register.

That caused issues when the bmask tried to compare the whole register
instead of just the bottom bits. The solution here is to mask the upper
bits for small types.

* aarch64: Emit 32bit cmp on bmask

This fixes an issue where bmask was accidentally comparing the
upper bits of the register by always using a 64bit cmp.

* riscv: Mask high bits in bmask

* riscv: Add compile tests for br{z,nz}

* riscv: Use shifts to mask 32bit values

This produces less code than the AND since that version needs to
load an immediate constant from memory.

* cranelift: Update test input to hexadecimal values

This makes it clearer what is being tested.

* riscv: Use addiw for masking 32 bit values

Co-authored-by: Trevor Elliott <telliott@fastly.com>

* aarch64: Update bmask rule priority

Co-authored-by: Trevor Elliott <telliott@fastly.com>
This commit is contained in:
Afonso Bordado
2022-10-27 17:45:39 +01:00
committed by GitHub
parent 02620441c3
commit e8f3d03bbe
7 changed files with 241 additions and 32 deletions

View File

@@ -3466,14 +3466,15 @@
(decl lower_bmask (Type Type ValueRegs) ValueRegs)
;; For conversions that fit in a register, we can use csetm.
;; For conversions that exactly fit a register, we can use csetm.
;;
;; cmp val, #0
;; csetm res, ne
(rule 0
(lower_bmask (fits_in_64 _) (fits_in_64 _) val)
(lower_bmask (fits_in_64 _) (ty_32_or_64 in_ty) val)
(with_flags_reg
(cmp64_imm (value_regs_get val 0) (u8_into_imm12 0))
(cmp_imm (operand_size in_ty) (value_regs_get val 0) (u8_into_imm12 0))
(csetm (Cond.Ne))))
;; For conversions from a 128-bit value into a 64-bit or smaller one, we or the
@@ -3506,6 +3507,18 @@
(res Reg (value_regs_get res 0)))
(value_regs res res)))
;; For input types of 16 bits or fewer, we need to mask off the high bits of the
;; register first, and then we can recurse into the 32-bit input case.
;;
;; and tmp, val, #ty_mask
;; cmp tmp, #0
;; csetm res, ne
(rule 4
(lower_bmask out_ty (fits_in_16 in_ty) val)
(let ((mask_bits ImmLogic (imm_logic_from_u64 $I32 (ty_mask in_ty)))
(masked Reg (and_imm $I32 (value_regs_get val 0) mask_bits)))
(lower_bmask out_ty $I32 masked)))
;; Exceptional `lower_icmp_into_flags` rules.
;; We need to guarantee that the flags for `cond` are correct, so we
;; compare `dst` with 1.

View File

@@ -1896,25 +1896,26 @@
(decl lower_brz_or_nz (IntCC ValueRegs VecMachLabel Type) InstOutput)
(extern constructor lower_brz_or_nz lower_brz_or_nz)
;; Normalize a value by masking to its bit-size.
(decl normalize_value (Type ValueRegs) ValueRegs)
;; Normalize a value for comparison.
;;
;; This ensures that types smaller than a register don't accidentally
;; pass undefined high bits when being compared as a full register.
(decl normalize_cmp_value (Type ValueRegs) ValueRegs)
(rule (normalize_value $I8 r)
(rule (normalize_cmp_value $I8 r)
(value_reg (alu_rr_imm12 (AluOPRRI.Andi) r (imm12_const 255))))
(rule (normalize_value $I16 r)
(rule (normalize_cmp_value $I16 r)
(value_reg (alu_rrr (AluOPRRR.And) r (imm $I16 65535))))
(rule (normalize_value $I32 r)
(value_reg (alu_rr_imm12 (AluOPRRI.Andi) r (imm12_const -1))))
(rule (normalize_cmp_value $I32 r)
(value_reg (alu_rr_imm12 (AluOPRRI.Addiw) r (imm12_const 0))))
(rule (normalize_value $I64 r) r)
(rule (normalize_value $I128 r) r)
(rule (normalize_value $F32 r) r)
(rule (normalize_value $F64 r) r)
(rule (normalize_cmp_value $I64 r) r)
(rule (normalize_cmp_value $I128 r) r)
;;;;;
(rule
(lower_branch (brz v @ (value_type ty) _ _) targets)
(lower_brz_or_nz (IntCC.Equal) (normalize_value ty v) targets ty))
(lower_brz_or_nz (IntCC.Equal) (normalize_cmp_value ty v) targets ty))
(rule 1
(lower_branch (brz (icmp cc a @ (value_type ty) b) _ _) targets)
@@ -1927,7 +1928,7 @@
;;;;
(rule
(lower_branch (brnz v @ (value_type ty) _ _) targets)
(lower_brz_or_nz (IntCC.NotEqual) (normalize_value ty v) targets ty))
(lower_brz_or_nz (IntCC.NotEqual) (normalize_cmp_value ty v) targets ty))
(rule 1
(lower_branch (brnz (icmp cc a @ (value_type ty) b) _ _) targets)
@@ -2097,10 +2098,12 @@
(decl lower_bmask (Type Type ValueRegs) ValueRegs)
;; Produces -1 if the 64-bit value is non-zero, and 0 otherwise.
;; If the type is smaller than 64 bits, we need to mask off the
;; high bits.
(rule
0
(lower_bmask (fits_in_64 _) (fits_in_64 _) val)
(let ((input Reg val)
(lower_bmask (fits_in_64 _) (fits_in_64 in_ty) val)
(let ((input Reg (normalize_cmp_value in_ty val))
(zero Reg (zero_reg))
(ones Reg (load_imm12 -1)))
(value_reg (gen_select_reg (IntCC.Equal) zero input zero ones))))
@@ -2119,8 +2122,8 @@
;; bmask of the 64-bit value into both result registers of the i128.
(rule
2
(lower_bmask $I128 (fits_in_64 _) val)
(let ((res ValueRegs (lower_bmask $I64 $I64 val)))
(lower_bmask $I128 (fits_in_64 in_ty) val)
(let ((res ValueRegs (lower_bmask $I64 in_ty val)))
(value_regs (value_regs_get res 0) (value_regs_get res 0))))
;; Conversion of one 64-bit value to a 128-bit one. Duplicate the result of

View File

@@ -605,13 +605,11 @@
;;;;; Rules for `select`;;;;;;;;;
(rule
(lower (has_type ty (select c @ (value_type cty) x y)))
(gen_select ty (normalize_value cty c) x y)
)
(gen_select ty (normalize_cmp_value cty c) x y))
(rule 1
(lower (has_type ty (select (icmp cc a b) x y)))
(gen_select_reg cc a b x y)
)
(gen_select_reg cc a b x y))
;;;;; Rules for `bitselect`;;;;;;;;;

View File

@@ -81,7 +81,7 @@ block0(v0: i32):
}
; block0:
; subs xzr, x0, #0
; subs wzr, w0, #0
; csetm x1, ne
; mov x0, x1
; ret
@@ -93,7 +93,8 @@ block0(v0: i16):
}
; block0:
; subs xzr, x0, #0
; and w4, w0, #65535
; subs wzr, w4, #0
; csetm x1, ne
; mov x0, x1
; ret
@@ -105,7 +106,8 @@ block0(v0: i8):
}
; block0:
; subs xzr, x0, #0
; and w4, w0, #255
; subs wzr, w4, #0
; csetm x1, ne
; mov x0, x1
; ret

View File

@@ -399,3 +399,167 @@ block1:
; block3:
; ret
function %i8_brz(i8){
block0(v0: i8):
brz v0, block1
jump block1
block1:
nop
return
}
; block0:
; andi t2,a0,255
; beq t2,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i8_brnz(i8){
block0(v0: i8):
brnz v0, block1
jump block1
block1:
nop
return
}
; block0:
; andi t2,a0,255
; bne t2,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i16_brz(i16){
block0(v0: i16):
brz v0, block1
jump block1
block1:
nop
return
}
; block0:
; lui t2,16
; addi t2,t2,4095
; and a2,a0,t2
; beq a2,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i16_brnz(i16){
block0(v0: i16):
brnz v0, block1
jump block1
block1:
nop
return
}
; block0:
; lui t2,16
; addi t2,t2,4095
; and a2,a0,t2
; bne a2,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i32_brz(i32){
block0(v0: i32):
brz v0, block1
jump block1
block1:
nop
return
}
; block0:
; addiw t2,a0,0
; beq t2,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i32_brnz(i32){
block0(v0: i32):
brnz v0, block1
jump block1
block1:
nop
return
}
; block0:
; addiw t2,a0,0
; bne t2,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i64_brz(i64){
block0(v0: i64):
brz v0, block1
jump block1
block1:
nop
return
}
; block0:
; beq a0,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret
function %i64_brnz(i64){
block0(v0: i64):
brnz v0, block1
jump block1
block1:
nop
return
}
; block0:
; bne a0,zero,taken(label1),not_taken(label2)
; block1:
; j label3
; block2:
; j label3
; block3:
; ret

View File

@@ -82,8 +82,9 @@ block0(v0: i32):
}
; block0:
; li a1,-1
; select_reg a1,zero,a1##condition=(zero eq a0)
; addiw a1,a0,0
; li a3,-1
; select_reg a1,zero,a3##condition=(zero eq a1)
; mv a0,a1
; ret
@@ -94,8 +95,11 @@ block0(v0: i16):
}
; block0:
; li a1,-1
; select_reg a1,zero,a1##condition=(zero eq a0)
; lui a1,16
; addi a1,a1,4095
; and a4,a0,a1
; li a6,-1
; select_reg a1,zero,a6##condition=(zero eq a4)
; mv a0,a1
; ret
@@ -106,8 +110,9 @@ block0(v0: i8):
}
; block0:
; li a1,-1
; select_reg a1,zero,a1##condition=(zero eq a0)
; andi a1,a0,255
; li a3,-1
; select_reg a1,zero,a3##condition=(zero eq a1)
; mv a0,a1
; ret

View File

@@ -131,3 +131,27 @@ block0(v0: i8):
}
; run: %bmask_i8_i8(1) == -1
; run: %bmask_i8_i8(0) == 0
; This is a regression test for AArch64, where the high bits weren't
; correctly being masked off for smaller types
function %bmask_masks_small_types() -> i8 {
block0:
v0 = iconst.i8 120
v1 = iconst.i8 7
v2 = ishl.i8 v0, v1
v3 = bmask.i8 v2
return v3
}
; run: %bmask_masks_small_types() == 0
; Similar to the above, this issue happened due to us always using a 64 bit
; comparison, even on a 32 bit type. This is triggered by ireduce since it
; doesn't actually produce any instructions, but is just a "type cast".
function %bmask_uses_32bit_cmp(i64) -> i8 {
block0(v0: i64):
v1 = ireduce.i32 v0
v2 = bmask.i8 v1
return v2
}
; run: %bmask_uses_32bit_cmp(0x2520B6E9_00000000) == 0