From e8f3d03bbe17151530601bac82af912869f09080 Mon Sep 17 00:00:00 2001 From: Afonso Bordado Date: Thu, 27 Oct 2022 17:45:39 +0100 Subject: [PATCH] cranelift: Mask high bits on `bmask` for types smaller than a register (#5118) * aarch64: Fix incorrect masking for small types on bmask `bmask` was accidentally relying on the uppermost bits of the register for small types. This was found by fuzzgen, when it generated a shift left followed by a bmask, the shift left shifted the bits out of the range of the input type (i8), however these are not automatically cleared since they remained inside the 32 bits of the register. That caused issues when the bmask tried to compare the whole register instead of just the bottom bits. The solution here is to mask the upper bits for small types. * aarch64: Emit 32bit cmp on bmask This fixes an issue where bmask was accidentally comparing the upper bits of the register by always using a 64bit cmp. * riscv: Mask high bits in bmask * riscv: Add compile tests for br{z,nz} * riscv: Use shifts to mask 32bit values This produces less code than the AND since that version needs to load an immediate constant from memory. * cranelift: Update test input to hexadecimal values This makes it a bit more clear what is being tested. * riscv: Use addiw for masking 32 bit values Co-authored-by: Trevor Elliott * aarch64: Update bmask rule priority Co-authored-by: Trevor Elliott --- cranelift/codegen/src/isa/aarch64/inst.isle | 19 +- cranelift/codegen/src/isa/riscv64/inst.isle | 35 ++-- cranelift/codegen/src/isa/riscv64/lower.isle | 6 +- .../filetests/isa/aarch64/i128-bmask.clif | 8 +- .../filetests/isa/riscv64/condbr.clif | 164 ++++++++++++++++++ .../filetests/isa/riscv64/i128-bmask.clif | 17 +- .../filetests/filetests/runtests/bmask.clif | 24 +++ 7 files changed, 241 insertions(+), 32 deletions(-) diff --git a/cranelift/codegen/src/isa/aarch64/inst.isle b/cranelift/codegen/src/isa/aarch64/inst.isle index b1b1a561a5..68dd824634 100644 --- a/cranelift/codegen/src/isa/aarch64/inst.isle +++ b/cranelift/codegen/src/isa/aarch64/inst.isle @@ -3466,14 +3466,15 @@ (decl lower_bmask (Type Type ValueRegs) ValueRegs) -;; For conversions that fit in a register, we can use csetm. + +;; For conversions that exactly fit a register, we can use csetm. ;; ;; cmp val, #0 ;; csetm res, ne (rule 0 - (lower_bmask (fits_in_64 _) (fits_in_64 _) val) + (lower_bmask (fits_in_64 _) (ty_32_or_64 in_ty) val) (with_flags_reg - (cmp64_imm (value_regs_get val 0) (u8_into_imm12 0)) + (cmp_imm (operand_size in_ty) (value_regs_get val 0) (u8_into_imm12 0)) (csetm (Cond.Ne)))) ;; For conversions from a 128-bit value into a 64-bit or smaller one, we or the @@ -3506,6 +3507,18 @@ (res Reg (value_regs_get res 0))) (value_regs res res))) +;; For conversions smaller than a register, we need to mask off the high bits, and then +;; we can recurse into the general case. +;; +;; and tmp, val, #ty_mask +;; cmp tmp, #0 +;; csetm res, ne +(rule 4 + (lower_bmask out_ty (fits_in_16 in_ty) val) + (let ((mask_bits ImmLogic (imm_logic_from_u64 $I32 (ty_mask in_ty))) + (masked Reg (and_imm $I32 (value_regs_get val 0) mask_bits))) + (lower_bmask out_ty $I32 masked))) + ;; Exceptional `lower_icmp_into_flags` rules. ;; We need to guarantee that the flags for `cond` are correct, so we ;; compare `dst` with 1. diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index 404231482a..69ae36a904 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -1896,25 +1896,26 @@ (decl lower_brz_or_nz (IntCC ValueRegs VecMachLabel Type) InstOutput) (extern constructor lower_brz_or_nz lower_brz_or_nz) -;; Normalize a value by masking to its bit-size. -(decl normalize_value (Type ValueRegs) ValueRegs) +;; Normalize a value for comparision. +;; +;; This ensures that types smaller than a register don't accidentally +;; pass undefined high bits when being compared as a full register. +(decl normalize_cmp_value (Type ValueRegs) ValueRegs) -(rule (normalize_value $I8 r) +(rule (normalize_cmp_value $I8 r) (value_reg (alu_rr_imm12 (AluOPRRI.Andi) r (imm12_const 255)))) -(rule (normalize_value $I16 r) +(rule (normalize_cmp_value $I16 r) (value_reg (alu_rrr (AluOPRRR.And) r (imm $I16 65535)))) -(rule (normalize_value $I32 r) - (value_reg (alu_rr_imm12 (AluOPRRI.Andi) r (imm12_const -1)))) +(rule (normalize_cmp_value $I32 r) + (value_reg (alu_rr_imm12 (AluOPRRI.Addiw) r (imm12_const 0)))) -(rule (normalize_value $I64 r) r) -(rule (normalize_value $I128 r) r) -(rule (normalize_value $F32 r) r) -(rule (normalize_value $F64 r) r) +(rule (normalize_cmp_value $I64 r) r) +(rule (normalize_cmp_value $I128 r) r) ;;;;; (rule (lower_branch (brz v @ (value_type ty) _ _) targets) - (lower_brz_or_nz (IntCC.Equal) (normalize_value ty v) targets ty)) + (lower_brz_or_nz (IntCC.Equal) (normalize_cmp_value ty v) targets ty)) (rule 1 (lower_branch (brz (icmp cc a @ (value_type ty) b) _ _) targets) @@ -1927,7 +1928,7 @@ ;;;; (rule (lower_branch (brnz v @ (value_type ty) _ _) targets) - (lower_brz_or_nz (IntCC.NotEqual) (normalize_value ty v) targets ty)) + (lower_brz_or_nz (IntCC.NotEqual) (normalize_cmp_value ty v) targets ty)) (rule 1 (lower_branch (brnz (icmp cc a @ (value_type ty) b) _ _) targets) @@ -2097,10 +2098,12 @@ (decl lower_bmask (Type Type ValueRegs) ValueRegs) ;; Produces -1 if the 64-bit value is non-zero, and 0 otherwise. +;; If the type is smaller than 64 bits, we need to mask off the +;; high bits. (rule 0 - (lower_bmask (fits_in_64 _) (fits_in_64 _) val) - (let ((input Reg val) + (lower_bmask (fits_in_64 _) (fits_in_64 in_ty) val) + (let ((input Reg (normalize_cmp_value in_ty val)) (zero Reg (zero_reg)) (ones Reg (load_imm12 -1))) (value_reg (gen_select_reg (IntCC.Equal) zero input zero ones)))) @@ -2119,8 +2122,8 @@ ;; bmask of the 64-bit value into both result registers of the i128. (rule 2 - (lower_bmask $I128 (fits_in_64 _) val) - (let ((res ValueRegs (lower_bmask $I64 $I64 val))) + (lower_bmask $I128 (fits_in_64 in_ty) val) + (let ((res ValueRegs (lower_bmask $I64 in_ty val))) (value_regs (value_regs_get res 0) (value_regs_get res 0)))) ;; Conversion of one 64-bit value to a 128-bit one. Duplicate the result of diff --git a/cranelift/codegen/src/isa/riscv64/lower.isle b/cranelift/codegen/src/isa/riscv64/lower.isle index edc3da22f1..a24f05fc13 100644 --- a/cranelift/codegen/src/isa/riscv64/lower.isle +++ b/cranelift/codegen/src/isa/riscv64/lower.isle @@ -605,13 +605,11 @@ ;;;;; Rules for `select`;;;;;;;;; (rule (lower (has_type ty (select c @ (value_type cty) x y))) - (gen_select ty (normalize_value cty c) x y) -) + (gen_select ty (normalize_cmp_value cty c) x y)) (rule 1 (lower (has_type ty (select (icmp cc a b) x y))) - (gen_select_reg cc a b x y) -) + (gen_select_reg cc a b x y)) ;;;;; Rules for `bitselect`;;;;;;;;; diff --git a/cranelift/filetests/filetests/isa/aarch64/i128-bmask.clif b/cranelift/filetests/filetests/isa/aarch64/i128-bmask.clif index 869d4b71e0..f6bff3ae4c 100644 --- a/cranelift/filetests/filetests/isa/aarch64/i128-bmask.clif +++ b/cranelift/filetests/filetests/isa/aarch64/i128-bmask.clif @@ -81,7 +81,7 @@ block0(v0: i32): } ; block0: -; subs xzr, x0, #0 +; subs wzr, w0, #0 ; csetm x1, ne ; mov x0, x1 ; ret @@ -93,7 +93,8 @@ block0(v0: i16): } ; block0: -; subs xzr, x0, #0 +; and w4, w0, #65535 +; subs wzr, w4, #0 ; csetm x1, ne ; mov x0, x1 ; ret @@ -105,7 +106,8 @@ block0(v0: i8): } ; block0: -; subs xzr, x0, #0 +; and w4, w0, #255 +; subs wzr, w4, #0 ; csetm x1, ne ; mov x0, x1 ; ret diff --git a/cranelift/filetests/filetests/isa/riscv64/condbr.clif b/cranelift/filetests/filetests/isa/riscv64/condbr.clif index d648805a60..cc9946faa7 100644 --- a/cranelift/filetests/filetests/isa/riscv64/condbr.clif +++ b/cranelift/filetests/filetests/isa/riscv64/condbr.clif @@ -399,3 +399,167 @@ block1: ; block3: ; ret + + +function %i8_brz(i8){ +block0(v0: i8): + brz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; andi t2,a0,255 +; beq t2,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i8_brnz(i8){ +block0(v0: i8): + brnz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; andi t2,a0,255 +; bne t2,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i16_brz(i16){ +block0(v0: i16): + brz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; lui t2,16 +; addi t2,t2,4095 +; and a2,a0,t2 +; beq a2,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i16_brnz(i16){ +block0(v0: i16): + brnz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; lui t2,16 +; addi t2,t2,4095 +; and a2,a0,t2 +; bne a2,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i32_brz(i32){ +block0(v0: i32): + brz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; addiw t2,a0,0 +; beq t2,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i32_brnz(i32){ +block0(v0: i32): + brnz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; addiw t2,a0,0 +; bne t2,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i64_brz(i64){ +block0(v0: i64): + brz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; beq a0,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + +function %i64_brnz(i64){ +block0(v0: i64): + brnz v0, block1 + jump block1 + +block1: + nop + return +} + +; block0: +; bne a0,zero,taken(label1),not_taken(label2) +; block1: +; j label3 +; block2: +; j label3 +; block3: +; ret + diff --git a/cranelift/filetests/filetests/isa/riscv64/i128-bmask.clif b/cranelift/filetests/filetests/isa/riscv64/i128-bmask.clif index 558b8b3697..7bc84cc08f 100644 --- a/cranelift/filetests/filetests/isa/riscv64/i128-bmask.clif +++ b/cranelift/filetests/filetests/isa/riscv64/i128-bmask.clif @@ -82,8 +82,9 @@ block0(v0: i32): } ; block0: -; li a1,-1 -; select_reg a1,zero,a1##condition=(zero eq a0) +; addiw a1,a0,0 +; li a3,-1 +; select_reg a1,zero,a3##condition=(zero eq a1) ; mv a0,a1 ; ret @@ -94,8 +95,11 @@ block0(v0: i16): } ; block0: -; li a1,-1 -; select_reg a1,zero,a1##condition=(zero eq a0) +; lui a1,16 +; addi a1,a1,4095 +; and a4,a0,a1 +; li a6,-1 +; select_reg a1,zero,a6##condition=(zero eq a4) ; mv a0,a1 ; ret @@ -106,8 +110,9 @@ block0(v0: i8): } ; block0: -; li a1,-1 -; select_reg a1,zero,a1##condition=(zero eq a0) +; andi a1,a0,255 +; li a3,-1 +; select_reg a1,zero,a3##condition=(zero eq a1) ; mv a0,a1 ; ret diff --git a/cranelift/filetests/filetests/runtests/bmask.clif b/cranelift/filetests/filetests/runtests/bmask.clif index 5e004108cf..e762bbd078 100644 --- a/cranelift/filetests/filetests/runtests/bmask.clif +++ b/cranelift/filetests/filetests/runtests/bmask.clif @@ -131,3 +131,27 @@ block0(v0: i8): } ; run: %bmask_i8_i8(1) == -1 ; run: %bmask_i8_i8(0) == 0 + + +; This is a regression test for AArch64, where the high bits weren't +; correctly being masked off for smaller types +function %bmask_masks_small_types() -> i8 { +block0: + v0 = iconst.i8 120 + v1 = iconst.i8 7 + v2 = ishl.i8 v0, v1 + v3 = bmask.i8 v2 + return v3 +} +; run: %bmask_masks_small_types() == 0 + +; Similar to the above, this issue happened due to us always using a 64 bit +; comparison, even on a 32 bit type. This is triggered by ireduce since it +; doesn't actually produce any instructions, but is just a "type cast". +function %bmask_uses_32bit_cmp(i64) -> i8 { +block0(v0: i64): + v1 = ireduce.i32 v0 + v2 = bmask.i8 v1 + return v2 +} +; run: %bmask_uses_32bit_cmp(0x2520B6E9_00000000) == 0