Rework the switch module in cranelift-frontend in terms of brif (#5644)

Rework the compilation strategy for switch to:
* use brif instead of brz and brnz
* generate tables inline, rather than delyaing them to after the decision tree has been generated
* avoid allocating new vectors by using slices into the sorted contiguous ranges
* avoid generating some unconditional jumps
* output differences in test output using the similar crate for easier debugging
This commit is contained in:
Trevor Elliott
2023-01-27 16:00:40 -08:00
committed by GitHub
parent 0f8393508a
commit b47006d432
3 changed files with 159 additions and 191 deletions

1
Cargo.lock generated
View File

@@ -622,6 +622,7 @@ dependencies = [
"cranelift-codegen", "cranelift-codegen",
"hashbrown", "hashbrown",
"log", "log",
"similar",
"smallvec", "smallvec",
"target-lexicon", "target-lexicon",
] ]

View File

@@ -17,6 +17,9 @@ log = { workspace = true }
hashbrown = { workspace = true, optional = true } hashbrown = { workspace = true, optional = true }
smallvec = { workspace = true } smallvec = { workspace = true }
[dev-dependencies]
similar = { workspace = true }
[features] [features]
default = ["std"] default = ["std"]
std = ["cranelift-codegen/std"] std = ["cranelift-codegen/std"]

View File

@@ -108,27 +108,19 @@ impl Switch {
} }
/// Binary search for the right `ContiguousCaseRange`. /// Binary search for the right `ContiguousCaseRange`.
fn build_search_tree( fn build_search_tree<'a>(
bx: &mut FunctionBuilder, bx: &mut FunctionBuilder,
val: Value, val: Value,
otherwise: Block, otherwise: Block,
contiguous_case_ranges: Vec<ContiguousCaseRange>, contiguous_case_ranges: &'a [ContiguousCaseRange],
) -> Vec<(EntryIndex, Block, Vec<Block>)> { ) {
let mut cases_and_jt_blocks = Vec::new();
// Avoid allocation in the common case // Avoid allocation in the common case
if contiguous_case_ranges.len() <= 3 { if contiguous_case_ranges.len() <= 3 {
Self::build_search_branches( Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
bx, return;
val,
otherwise,
contiguous_case_ranges,
&mut cases_and_jt_blocks,
);
return cases_and_jt_blocks;
} }
let mut stack: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new(); let mut stack = Vec::new();
stack.push((None, contiguous_case_ranges)); stack.push((None, contiguous_case_ranges));
while let Some((block, contiguous_case_ranges)) = stack.pop() { while let Some((block, contiguous_case_ranges)) = stack.pop() {
@@ -137,17 +129,10 @@ impl Switch {
} }
if contiguous_case_ranges.len() <= 3 { if contiguous_case_ranges.len() <= 3 {
Self::build_search_branches( Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
bx,
val,
otherwise,
contiguous_case_ranges,
&mut cases_and_jt_blocks,
);
} else { } else {
let split_point = contiguous_case_ranges.len() / 2; let split_point = contiguous_case_ranges.len() / 2;
let mut left = contiguous_case_ranges; let (left, right) = contiguous_case_ranges.split_at(split_point);
let right = left.split_off(split_point);
let left_block = bx.create_block(); let left_block = bx.create_block();
let right_block = bx.create_block(); let right_block = bx.create_block();
@@ -155,8 +140,8 @@ impl Switch {
let first_index = right[0].first_index; let first_index = right[0].first_index;
let should_take_right_side = let should_take_right_side =
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index); icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
bx.ins().brnz(should_take_right_side, right_block, &[]); bx.ins()
bx.ins().jump(left_block, &[]); .brif(should_take_right_side, right_block, &[], left_block, &[]);
bx.seal_block(left_block); bx.seal_block(left_block);
bx.seal_block(right_block); bx.seal_block(right_block);
@@ -165,77 +150,64 @@ impl Switch {
stack.push((Some(right_block), right)); stack.push((Some(right_block), right));
} }
} }
cases_and_jt_blocks
} }
/// Linear search for the right `ContiguousCaseRange`. /// Linear search for the right `ContiguousCaseRange`.
fn build_search_branches( fn build_search_branches<'a>(
bx: &mut FunctionBuilder, bx: &mut FunctionBuilder,
val: Value, val: Value,
otherwise: Block, otherwise: Block,
contiguous_case_ranges: Vec<ContiguousCaseRange>, contiguous_case_ranges: &'a [ContiguousCaseRange],
cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
) { ) {
let mut was_branch = false; let last_ix = contiguous_case_ranges.len() - 1;
let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| { for (ix, range) in contiguous_case_ranges.iter().rev().enumerate() {
if was_branch { let alternate = if ix == last_ix {
let block = bx.create_block(); otherwise
bx.ins().jump(block, &[]); } else {
bx.seal_block(block); bx.create_block()
bx.switch_to_block(block);
}
}; };
for ContiguousCaseRange {
first_index, if range.first_index == 0 {
blocks, assert_eq!(alternate, otherwise);
} in contiguous_case_ranges.into_iter().rev()
{ if let Some(block) = range.single_block() {
match (blocks.len(), first_index) { bx.ins().brif(val, otherwise, &[], block, &[]);
(1, 0) => { } else {
ins_fallthrough_jump(was_branch, bx); Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
bx.ins().brz(val, blocks[0], &[]);
} }
(1, _) => { } else {
ins_fallthrough_jump(was_branch, bx); if let Some(block) = range.single_block() {
let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, first_index); let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
bx.ins().brnz(is_good_val, blocks[0], &[]); bx.ins().brif(is_good_val, block, &[], alternate, &[]);
} } else {
(_, 0) => { let is_good_val = icmp_imm_u128(
// if `first_index` is 0, then `icmp_imm uge val, first_index` is trivially true bx,
IntCC::UnsignedGreaterThanOrEqual,
val,
range.first_index,
);
let jt_block = bx.create_block(); let jt_block = bx.create_block();
bx.ins().jump(jt_block, &[]); bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
bx.seal_block(jt_block); bx.seal_block(jt_block);
cases_and_jt_blocks.push((first_index, jt_block, blocks)); bx.switch_to_block(jt_block);
// `jump otherwise` below must not be hit, because the current block has been Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
// filled above. This is the last iteration anyway, as 0 is the smallest
// unsigned int, so just return here.
return;
} }
(_, _) => {
ins_fallthrough_jump(was_branch, bx);
let jt_block = bx.create_block();
let is_good_val =
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
bx.ins().brnz(is_good_val, jt_block, &[]);
bx.seal_block(jt_block);
cases_and_jt_blocks.push((first_index, jt_block, blocks));
}
}
was_branch = true;
} }
bx.ins().jump(otherwise, &[]); if alternate != otherwise {
bx.seal_block(alternate);
bx.switch_to_block(alternate);
}
}
} }
/// For every item in `cases_and_jt_blocks` this will create a jump table in the specified block. fn build_jump_table(
fn build_jump_tables(
bx: &mut FunctionBuilder, bx: &mut FunctionBuilder,
val: Value, val: Value,
otherwise: Block, otherwise: Block,
cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>, first_index: EntryIndex,
blocks: &[Block],
) { ) {
for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
// There are currently no 128bit systems supported by rustc, but once we do ensure that // There are currently no 128bit systems supported by rustc, but once we do ensure that
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems. // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
assert!( assert!(
@@ -243,13 +215,9 @@ impl Switch {
"Jump tables bigger than 2^32-1 are not yet supported" "Jump tables bigger than 2^32-1 are not yet supported"
); );
let mut jt_data = JumpTableData::new(); let jt_data = JumpTableData::with_blocks(Vec::from(blocks));
for block in blocks {
jt_data.push_entry(block);
}
let jump_table = bx.create_jump_table(jt_data); let jump_table = bx.create_jump_table(jt_data);
bx.switch_to_block(jt_block);
let discr = if first_index == 0 { let discr = if first_index == 0 {
val val
} else { } else {
@@ -271,8 +239,8 @@ impl Switch {
let bigger_than_u32 = let bigger_than_u32 =
bx.ins() bx.ins()
.icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64); .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
bx.ins().brnz(bigger_than_u32, otherwise, &[]); bx.ins()
bx.ins().jump(new_block, &[]); .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
bx.seal_block(new_block); bx.seal_block(new_block);
bx.switch_to_block(new_block); bx.switch_to_block(new_block);
@@ -285,7 +253,6 @@ impl Switch {
bx.ins().br_table(discr, otherwise, jump_table); bx.ins().br_table(discr, otherwise, jump_table);
} }
}
/// Build the switch /// Build the switch
/// ///
@@ -307,9 +274,7 @@ impl Switch {
} }
let contiguous_case_ranges = self.collect_contiguous_case_ranges(); let contiguous_case_ranges = self.collect_contiguous_case_ranges();
let cases_and_jt_blocks = Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges);
Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks);
} }
} }
@@ -351,6 +316,15 @@ impl ContiguousCaseRange {
blocks: Vec::new(), blocks: Vec::new(),
} }
} }
/// Returns `Some` block when there is only a single block in this range.
fn single_block(&self) -> Option<Block> {
if self.blocks.len() == 1 {
Some(self.blocks[0])
} else {
None
}
}
} }
#[cfg(test)] #[cfg(test)]
@@ -384,43 +358,52 @@ mod tests {
}}; }};
} }
macro_rules! assert_eq_output {
($actual:ident, $expected:literal) => {
if $actual != $expected {
assert!(
false,
"\n{}",
similar::TextDiff::from_lines($expected, &$actual)
.unified_diff()
.header("expected", "actual")
);
}
};
}
#[test] #[test]
fn switch_zero() { fn switch_zero() {
let func = setup!(0, [0,]); let func = setup!(0, [0,]);
assert_eq!( assert_eq_output!(
func, func,
"block0: "block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
brz v0, block1 ; v0 = 0 brif v0, block0, block1 ; v0 = 0"
jump block0"
); );
} }
#[test] #[test]
fn switch_single() { fn switch_single() {
let func = setup!(0, [1,]); let func = setup!(0, [1,]);
assert_eq!( assert_eq_output!(
func, func,
"block0: "block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
v1 = icmp_imm eq v0, 1 ; v0 = 0 v1 = icmp_imm eq v0, 1 ; v0 = 0
brnz v1, block1 brif v1, block1, block0"
jump block0"
); );
} }
#[test] #[test]
fn switch_bool() { fn switch_bool() {
let func = setup!(0, [0, 1,]); let func = setup!(0, [0, 1,]);
assert_eq!( assert_eq_output!(
func, func,
" jt0 = jump_table [block1, block2] " jt0 = jump_table [block1, block2]
block0: block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
jump block3
block3:
v1 = uextend.i32 v0 ; v0 = 0 v1 = uextend.i32 v0 ; v0 = 0
br_table v1, block0, jt0" br_table v1, block0, jt0"
); );
@@ -429,56 +412,50 @@ block3:
#[test] #[test]
fn switch_two_gap() { fn switch_two_gap() {
let func = setup!(0, [0, 2,]); let func = setup!(0, [0, 2,]);
assert_eq!( assert_eq_output!(
func, func,
"block0: "block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
v1 = icmp_imm eq v0, 2 ; v0 = 0 v1 = icmp_imm eq v0, 2 ; v0 = 0
brnz v1, block2 brif v1, block2, block3
jump block3
block3: block3:
brz.i8 v0, block1 ; v0 = 0 brif.i8 v0, block0, block1 ; v0 = 0"
jump block0"
); );
} }
#[test] #[test]
fn switch_many() { fn switch_many() {
let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]); let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
assert_eq!( assert_eq_output!(
func, func,
" jt0 = jump_table [block1, block2] " jt0 = jump_table [block5, block6, block7]
jt1 = jump_table [block5, block6, block7] jt1 = jump_table [block1, block2]
block0: block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
v1 = icmp_imm uge v0, 7 ; v0 = 0 v1 = icmp_imm uge v0, 7 ; v0 = 0
brnz v1, block9 brif v1, block9, block8
jump block8
block9: block9:
v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0 v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0
brnz v2, block10 brif v2, block11, block10
jump block11
block11: block11:
v3 = icmp_imm.i8 eq v0, 7 ; v0 = 0 v3 = iadd_imm.i8 v0, -10 ; v0 = 0
brnz v3, block4 v4 = uextend.i32 v3
jump block0 br_table v4, block0, jt0
block8:
v4 = icmp_imm.i8 eq v0, 5 ; v0 = 0
brnz v4, block3
jump block12
block12:
v5 = uextend.i32 v0 ; v0 = 0
br_table v5, block0, jt0
block10: block10:
v6 = iadd_imm.i8 v0, -10 ; v0 = 0 v5 = icmp_imm.i8 eq v0, 7 ; v0 = 0
v7 = uextend.i32 v6 brif v5, block4, block0
block8:
v6 = icmp_imm.i8 eq v0, 5 ; v0 = 0
brif v6, block3, block12
block12:
v7 = uextend.i32 v0 ; v0 = 0
br_table v7, block0, jt1" br_table v7, block0, jt1"
); );
} }
@@ -486,51 +463,46 @@ block10:
#[test] #[test]
fn switch_min_index_value() { fn switch_min_index_value() {
let func = setup!(0, [i8::MIN as u8 as u128, 1,]); let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
assert_eq!( assert_eq_output!(
func, func,
"block0: "block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
v1 = icmp_imm eq v0, 128 ; v0 = 0 v1 = icmp_imm eq v0, 128 ; v0 = 0
brnz v1, block1 brif v1, block1, block3
jump block3
block3: block3:
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0 v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
brnz v2, block2 brif v2, block2, block0"
jump block0"
); );
} }
#[test] #[test]
fn switch_max_index_value() { fn switch_max_index_value() {
let func = setup!(0, [i8::MAX as u8 as u128, 1,]); let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
assert_eq!( assert_eq_output!(
func, func,
"block0: "block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
v1 = icmp_imm eq v0, 127 ; v0 = 0 v1 = icmp_imm eq v0, 127 ; v0 = 0
brnz v1, block1 brif v1, block1, block3
jump block3
block3: block3:
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0 v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
brnz v2, block2 brif v2, block2, block0"
jump block0"
) )
} }
#[test] #[test]
fn switch_optimal_codegen() { fn switch_optimal_codegen() {
let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]); let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
assert_eq!( assert_eq_output!(
func, func,
" jt0 = jump_table [block2, block3] " jt0 = jump_table [block2, block3]
block0: block0:
v0 = iconst.i8 0 v0 = iconst.i8 0
v1 = icmp_imm eq v0, 255 ; v0 = 0 v1 = icmp_imm eq v0, 255 ; v0 = 0
brnz v1, block1 brif v1, block1, block4
jump block4
block4: block4:
v2 = uextend.i32 v0 ; v0 = 0 v2 = uextend.i32 v0 ; v0 = 0
@@ -617,20 +589,16 @@ block4:
.trim_start_matches("function u0:0() fast {\n") .trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n") .trim_end_matches("\n}\n")
.to_string(); .to_string();
assert_eq!( assert_eq_output!(
func, func,
" jt0 = jump_table [block2, block1] " jt0 = jump_table [block2, block1]
block0: block0:
v0 = iconst.i64 0 v0 = iconst.i64 0
jump block4 v1 = icmp_imm ugt v0, 0xffff_ffff ; v0 = 0
brif v1, block3, block4
block4: block4:
v1 = icmp_imm.i64 ugt v0, 0xffff_ffff ; v0 = 0
brnz v1, block3
jump block5
block5:
v2 = ireduce.i32 v0 ; v0 = 0 v2 = ireduce.i32 v0 ; v0 = 0
br_table v2, block3, jt0" br_table v2, block3, jt0"
); );
@@ -659,21 +627,17 @@ block5:
.trim_start_matches("function u0:0() fast {\n") .trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n") .trim_end_matches("\n}\n")
.to_string(); .to_string();
assert_eq!( assert_eq_output!(
func, func,
" jt0 = jump_table [block2, block1] " jt0 = jump_table [block2, block1]
block0: block0:
v0 = iconst.i64 0 v0 = iconst.i64 0
v1 = uextend.i128 v0 ; v0 = 0 v1 = uextend.i128 v0 ; v0 = 0
jump block4 v2 = icmp_imm ugt v1, 0xffff_ffff
brif v2, block3, block4
block4: block4:
v2 = icmp_imm.i128 ugt v1, 0xffff_ffff
brnz v2, block3
jump block5
block5:
v3 = ireduce.i32 v1 v3 = ireduce.i32 v1
br_table v3, block3, jt0" br_table v3, block3, jt0"
); );