diff --git a/Cargo.lock b/Cargo.lock index a5c941a1b6..6780b4cfc9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -622,6 +622,7 @@ dependencies = [ "cranelift-codegen", "hashbrown", "log", + "similar", "smallvec", "target-lexicon", ] diff --git a/cranelift/frontend/Cargo.toml b/cranelift/frontend/Cargo.toml index c3f0aa5bb6..bb68f5e8d1 100644 --- a/cranelift/frontend/Cargo.toml +++ b/cranelift/frontend/Cargo.toml @@ -17,6 +17,9 @@ log = { workspace = true } hashbrown = { workspace = true, optional = true } smallvec = { workspace = true } +[dev-dependencies] +similar = { workspace = true } + [features] default = ["std"] std = ["cranelift-codegen/std"] diff --git a/cranelift/frontend/src/switch.rs b/cranelift/frontend/src/switch.rs index 4a1abef46b..3920ba8537 100644 --- a/cranelift/frontend/src/switch.rs +++ b/cranelift/frontend/src/switch.rs @@ -108,27 +108,19 @@ impl Switch { } /// Binary search for the right `ContiguousCaseRange`. - fn build_search_tree( + fn build_search_tree<'a>( bx: &mut FunctionBuilder, val: Value, otherwise: Block, - contiguous_case_ranges: Vec, - ) -> Vec<(EntryIndex, Block, Vec)> { - let mut cases_and_jt_blocks = Vec::new(); - + contiguous_case_ranges: &'a [ContiguousCaseRange], + ) { // Avoid allocation in the common case if contiguous_case_ranges.len() <= 3 { - Self::build_search_branches( - bx, - val, - otherwise, - contiguous_case_ranges, - &mut cases_and_jt_blocks, - ); - return cases_and_jt_blocks; + Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges); + return; } - let mut stack: Vec<(Option, Vec)> = Vec::new(); + let mut stack = Vec::new(); stack.push((None, contiguous_case_ranges)); while let Some((block, contiguous_case_ranges)) = stack.pop() { @@ -137,17 +129,10 @@ impl Switch { } if contiguous_case_ranges.len() <= 3 { - Self::build_search_branches( - bx, - val, - otherwise, - contiguous_case_ranges, - &mut cases_and_jt_blocks, - ); + Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges); } else { let split_point = contiguous_case_ranges.len() / 2; - let mut left = contiguous_case_ranges; - let right = left.split_off(split_point); + let (left, right) = contiguous_case_ranges.split_at(split_point); let left_block = bx.create_block(); let right_block = bx.create_block(); @@ -155,8 +140,8 @@ impl Switch { let first_index = right[0].first_index; let should_take_right_side = icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index); - bx.ins().brnz(should_take_right_side, right_block, &[]); - bx.ins().jump(left_block, &[]); + bx.ins() + .brif(should_take_right_side, right_block, &[], left_block, &[]); bx.seal_block(left_block); bx.seal_block(right_block); @@ -165,126 +150,108 @@ impl Switch { stack.push((Some(right_block), right)); } } - - cases_and_jt_blocks } /// Linear search for the right `ContiguousCaseRange`. - fn build_search_branches( + fn build_search_branches<'a>( bx: &mut FunctionBuilder, val: Value, otherwise: Block, - contiguous_case_ranges: Vec, - cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec)>, + contiguous_case_ranges: &'a [ContiguousCaseRange], ) { - let mut was_branch = false; - let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| { - if was_branch { - let block = bx.create_block(); - bx.ins().jump(block, &[]); - bx.seal_block(block); - bx.switch_to_block(block); - } - }; - for ContiguousCaseRange { - first_index, - blocks, - } in contiguous_case_ranges.into_iter().rev() - { - match (blocks.len(), first_index) { - (1, 0) => { - ins_fallthrough_jump(was_branch, bx); - bx.ins().brz(val, blocks[0], &[]); - } - (1, _) => { - ins_fallthrough_jump(was_branch, bx); - let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, first_index); - bx.ins().brnz(is_good_val, blocks[0], &[]); - } - (_, 0) => { - // if `first_index` is 0, then `icmp_imm uge val, first_index` is trivially true - let jt_block = bx.create_block(); - bx.ins().jump(jt_block, &[]); - bx.seal_block(jt_block); - cases_and_jt_blocks.push((first_index, jt_block, blocks)); - // `jump otherwise` below must not be hit, because the current block has been - // filled above. This is the last iteration anyway, as 0 is the smallest - // unsigned int, so just return here. - return; - } - (_, _) => { - ins_fallthrough_jump(was_branch, bx); - let jt_block = bx.create_block(); - let is_good_val = - icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index); - bx.ins().brnz(is_good_val, jt_block, &[]); - bx.seal_block(jt_block); - cases_and_jt_blocks.push((first_index, jt_block, blocks)); - } - } - was_branch = true; - } + let last_ix = contiguous_case_ranges.len() - 1; + for (ix, range) in contiguous_case_ranges.iter().rev().enumerate() { + let alternate = if ix == last_ix { + otherwise + } else { + bx.create_block() + }; - bx.ins().jump(otherwise, &[]); + if range.first_index == 0 { + assert_eq!(alternate, otherwise); + + if let Some(block) = range.single_block() { + bx.ins().brif(val, otherwise, &[], block, &[]); + } else { + Self::build_jump_table(bx, val, otherwise, 0, &range.blocks); + } + } else { + if let Some(block) = range.single_block() { + let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index); + bx.ins().brif(is_good_val, block, &[], alternate, &[]); + } else { + let is_good_val = icmp_imm_u128( + bx, + IntCC::UnsignedGreaterThanOrEqual, + val, + range.first_index, + ); + let jt_block = bx.create_block(); + bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]); + bx.seal_block(jt_block); + bx.switch_to_block(jt_block); + Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks); + } + } + + if alternate != otherwise { + bx.seal_block(alternate); + bx.switch_to_block(alternate); + } + } } - /// For every item in `cases_and_jt_blocks` this will create a jump table in the specified block. - fn build_jump_tables( + fn build_jump_table( bx: &mut FunctionBuilder, val: Value, otherwise: Block, - cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec)>, + first_index: EntryIndex, + blocks: &[Block], ) { - for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() { - // There are currently no 128bit systems supported by rustc, but once we do ensure that - // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems. - assert!( - u32::try_from(blocks.len()).is_ok(), - "Jump tables bigger than 2^32-1 are not yet supported" - ); + // There are currently no 128bit systems supported by rustc, but once we do ensure that + // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems. + assert!( + u32::try_from(blocks.len()).is_ok(), + "Jump tables bigger than 2^32-1 are not yet supported" + ); - let mut jt_data = JumpTableData::new(); - for block in blocks { - jt_data.push_entry(block); - } - let jump_table = bx.create_jump_table(jt_data); + let jt_data = JumpTableData::with_blocks(Vec::from(blocks)); + let jump_table = bx.create_jump_table(jt_data); - bx.switch_to_block(jt_block); - let discr = if first_index == 0 { - val + let discr = if first_index == 0 { + val + } else { + if let Ok(first_index) = u64::try_from(first_index) { + bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg()) } else { - if let Ok(first_index) = u64::try_from(first_index) { - bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg()) - } else { - let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64); - let lsb = bx.ins().iconst(types::I64, lsb as i64); - let msb = bx.ins().iconst(types::I64, msb as i64); - let index = bx.ins().iconcat(lsb, msb); - bx.ins().isub(val, index) - } - }; + let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64); + let lsb = bx.ins().iconst(types::I64, lsb as i64); + let msb = bx.ins().iconst(types::I64, msb as i64); + let index = bx.ins().iconcat(lsb, msb); + bx.ins().isub(val, index) + } + }; - let discr = match bx.func.dfg.value_type(discr).bits() { - bits if bits > 32 => { - // Check for overflow of cast to u32. This is the max supported jump table entries. - let new_block = bx.create_block(); - let bigger_than_u32 = - bx.ins() - .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64); - bx.ins().brnz(bigger_than_u32, otherwise, &[]); - bx.ins().jump(new_block, &[]); - bx.seal_block(new_block); - bx.switch_to_block(new_block); + let discr = match bx.func.dfg.value_type(discr).bits() { + bits if bits > 32 => { + // Check for overflow of cast to u32. This is the max supported jump table entries. + let new_block = bx.create_block(); + let bigger_than_u32 = + bx.ins() + .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64); + bx.ins() + .brif(bigger_than_u32, otherwise, &[], new_block, &[]); + bx.seal_block(new_block); + bx.switch_to_block(new_block); - // Cast to i32, as br_table is not implemented for i64/i128 - bx.ins().ireduce(types::I32, discr) - } - bits if bits < 32 => bx.ins().uextend(types::I32, discr), - _ => discr, - }; + // Cast to i32, as br_table is not implemented for i64/i128 + bx.ins().ireduce(types::I32, discr) + } + bits if bits < 32 => bx.ins().uextend(types::I32, discr), + _ => discr, + }; - bx.ins().br_table(discr, otherwise, jump_table); - } + bx.ins().br_table(discr, otherwise, jump_table); } /// Build the switch @@ -307,9 +274,7 @@ impl Switch { } let contiguous_case_ranges = self.collect_contiguous_case_ranges(); - let cases_and_jt_blocks = - Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges); - Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks); + Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges); } } @@ -351,6 +316,15 @@ impl ContiguousCaseRange { blocks: Vec::new(), } } + + /// Returns `Some` block when there is only a single block in this range. + fn single_block(&self) -> Option { + if self.blocks.len() == 1 { + Some(self.blocks[0]) + } else { + None + } + } } #[cfg(test)] @@ -384,43 +358,52 @@ mod tests { }}; } + macro_rules! assert_eq_output { + ($actual:ident, $expected:literal) => { + if $actual != $expected { + assert!( + false, + "\n{}", + similar::TextDiff::from_lines($expected, &$actual) + .unified_diff() + .header("expected", "actual") + ); + } + }; + } + #[test] fn switch_zero() { let func = setup!(0, [0,]); - assert_eq!( + assert_eq_output!( func, "block0: v0 = iconst.i8 0 - brz v0, block1 ; v0 = 0 - jump block0" + brif v0, block0, block1 ; v0 = 0" ); } #[test] fn switch_single() { let func = setup!(0, [1,]); - assert_eq!( + assert_eq_output!( func, "block0: v0 = iconst.i8 0 v1 = icmp_imm eq v0, 1 ; v0 = 0 - brnz v1, block1 - jump block0" + brif v1, block1, block0" ); } #[test] fn switch_bool() { let func = setup!(0, [0, 1,]); - assert_eq!( + assert_eq_output!( func, " jt0 = jump_table [block1, block2] block0: v0 = iconst.i8 0 - jump block3 - -block3: v1 = uextend.i32 v0 ; v0 = 0 br_table v1, block0, jt0" ); @@ -429,56 +412,50 @@ block3: #[test] fn switch_two_gap() { let func = setup!(0, [0, 2,]); - assert_eq!( + assert_eq_output!( func, "block0: v0 = iconst.i8 0 v1 = icmp_imm eq v0, 2 ; v0 = 0 - brnz v1, block2 - jump block3 + brif v1, block2, block3 block3: - brz.i8 v0, block1 ; v0 = 0 - jump block0" + brif.i8 v0, block0, block1 ; v0 = 0" ); } #[test] fn switch_many() { let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]); - assert_eq!( + assert_eq_output!( func, - " jt0 = jump_table [block1, block2] - jt1 = jump_table [block5, block6, block7] + " jt0 = jump_table [block5, block6, block7] + jt1 = jump_table [block1, block2] block0: v0 = iconst.i8 0 v1 = icmp_imm uge v0, 7 ; v0 = 0 - brnz v1, block9 - jump block8 + brif v1, block9, block8 block9: v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0 - brnz v2, block10 - jump block11 + brif v2, block11, block10 block11: - v3 = icmp_imm.i8 eq v0, 7 ; v0 = 0 - brnz v3, block4 - jump block0 - -block8: - v4 = icmp_imm.i8 eq v0, 5 ; v0 = 0 - brnz v4, block3 - jump block12 - -block12: - v5 = uextend.i32 v0 ; v0 = 0 - br_table v5, block0, jt0 + v3 = iadd_imm.i8 v0, -10 ; v0 = 0 + v4 = uextend.i32 v3 + br_table v4, block0, jt0 block10: - v6 = iadd_imm.i8 v0, -10 ; v0 = 0 - v7 = uextend.i32 v6 + v5 = icmp_imm.i8 eq v0, 7 ; v0 = 0 + brif v5, block4, block0 + +block8: + v6 = icmp_imm.i8 eq v0, 5 ; v0 = 0 + brif v6, block3, block12 + +block12: + v7 = uextend.i32 v0 ; v0 = 0 br_table v7, block0, jt1" ); } @@ -486,51 +463,46 @@ block10: #[test] fn switch_min_index_value() { let func = setup!(0, [i8::MIN as u8 as u128, 1,]); - assert_eq!( + assert_eq_output!( func, "block0: v0 = iconst.i8 0 v1 = icmp_imm eq v0, 128 ; v0 = 0 - brnz v1, block1 - jump block3 + brif v1, block1, block3 block3: v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0 - brnz v2, block2 - jump block0" + brif v2, block2, block0" ); } #[test] fn switch_max_index_value() { let func = setup!(0, [i8::MAX as u8 as u128, 1,]); - assert_eq!( + assert_eq_output!( func, "block0: v0 = iconst.i8 0 v1 = icmp_imm eq v0, 127 ; v0 = 0 - brnz v1, block1 - jump block3 + brif v1, block1, block3 block3: v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0 - brnz v2, block2 - jump block0" + brif v2, block2, block0" ) } #[test] fn switch_optimal_codegen() { let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]); - assert_eq!( + assert_eq_output!( func, " jt0 = jump_table [block2, block3] block0: v0 = iconst.i8 0 v1 = icmp_imm eq v0, 255 ; v0 = 0 - brnz v1, block1 - jump block4 + brif v1, block1, block4 block4: v2 = uextend.i32 v0 ; v0 = 0 @@ -617,20 +589,16 @@ block4: .trim_start_matches("function u0:0() fast {\n") .trim_end_matches("\n}\n") .to_string(); - assert_eq!( + assert_eq_output!( func, " jt0 = jump_table [block2, block1] block0: v0 = iconst.i64 0 - jump block4 + v1 = icmp_imm ugt v0, 0xffff_ffff ; v0 = 0 + brif v1, block3, block4 block4: - v1 = icmp_imm.i64 ugt v0, 0xffff_ffff ; v0 = 0 - brnz v1, block3 - jump block5 - -block5: v2 = ireduce.i32 v0 ; v0 = 0 br_table v2, block3, jt0" ); @@ -659,21 +627,17 @@ block5: .trim_start_matches("function u0:0() fast {\n") .trim_end_matches("\n}\n") .to_string(); - assert_eq!( + assert_eq_output!( func, " jt0 = jump_table [block2, block1] block0: v0 = iconst.i64 0 v1 = uextend.i128 v0 ; v0 = 0 - jump block4 + v2 = icmp_imm ugt v1, 0xffff_ffff + brif v2, block3, block4 block4: - v2 = icmp_imm.i128 ugt v1, 0xffff_ffff - brnz v2, block3 - jump block5 - -block5: v3 = ireduce.i32 v1 br_table v3, block3, jt0" );