Rework the switch module in cranelift-frontend in terms of brif (#5644)
Rework the compilation strategy for switch to:
* use brif instead of brz and brnz
* generate tables inline, rather than delaying them until after the decision tree has been generated
* avoid allocating new vectors by using slices into the sorted contiguous ranges
* avoid generating some unconditional jumps
* output differences in test output using the similar crate for easier debugging
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -622,6 +622,7 @@ dependencies = [
|
||||
"cranelift-codegen",
|
||||
"hashbrown",
|
||||
"log",
|
||||
"similar",
|
||||
"smallvec",
|
||||
"target-lexicon",
|
||||
]
|
||||
|
||||
@@ -17,6 +17,9 @@ log = { workspace = true }
|
||||
hashbrown = { workspace = true, optional = true }
|
||||
smallvec = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
similar = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["cranelift-codegen/std"]
|
||||
|
||||
@@ -108,27 +108,19 @@ impl Switch {
|
||||
}
|
||||
|
||||
/// Binary search for the right `ContiguousCaseRange`.
|
||||
fn build_search_tree(
|
||||
fn build_search_tree<'a>(
|
||||
bx: &mut FunctionBuilder,
|
||||
val: Value,
|
||||
otherwise: Block,
|
||||
contiguous_case_ranges: Vec<ContiguousCaseRange>,
|
||||
) -> Vec<(EntryIndex, Block, Vec<Block>)> {
|
||||
let mut cases_and_jt_blocks = Vec::new();
|
||||
|
||||
contiguous_case_ranges: &'a [ContiguousCaseRange],
|
||||
) {
|
||||
// Avoid allocation in the common case
|
||||
if contiguous_case_ranges.len() <= 3 {
|
||||
Self::build_search_branches(
|
||||
bx,
|
||||
val,
|
||||
otherwise,
|
||||
contiguous_case_ranges,
|
||||
&mut cases_and_jt_blocks,
|
||||
);
|
||||
return cases_and_jt_blocks;
|
||||
Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
|
||||
return;
|
||||
}
|
||||
|
||||
let mut stack: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new();
|
||||
let mut stack = Vec::new();
|
||||
stack.push((None, contiguous_case_ranges));
|
||||
|
||||
while let Some((block, contiguous_case_ranges)) = stack.pop() {
|
||||
@@ -137,17 +129,10 @@ impl Switch {
|
||||
}
|
||||
|
||||
if contiguous_case_ranges.len() <= 3 {
|
||||
Self::build_search_branches(
|
||||
bx,
|
||||
val,
|
||||
otherwise,
|
||||
contiguous_case_ranges,
|
||||
&mut cases_and_jt_blocks,
|
||||
);
|
||||
Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
|
||||
} else {
|
||||
let split_point = contiguous_case_ranges.len() / 2;
|
||||
let mut left = contiguous_case_ranges;
|
||||
let right = left.split_off(split_point);
|
||||
let (left, right) = contiguous_case_ranges.split_at(split_point);
|
||||
|
||||
let left_block = bx.create_block();
|
||||
let right_block = bx.create_block();
|
||||
@@ -155,8 +140,8 @@ impl Switch {
|
||||
let first_index = right[0].first_index;
|
||||
let should_take_right_side =
|
||||
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
|
||||
bx.ins().brnz(should_take_right_side, right_block, &[]);
|
||||
bx.ins().jump(left_block, &[]);
|
||||
bx.ins()
|
||||
.brif(should_take_right_side, right_block, &[], left_block, &[]);
|
||||
|
||||
bx.seal_block(left_block);
|
||||
bx.seal_block(right_block);
|
||||
@@ -165,77 +150,64 @@ impl Switch {
|
||||
stack.push((Some(right_block), right));
|
||||
}
|
||||
}
|
||||
|
||||
cases_and_jt_blocks
|
||||
}
|
||||
|
||||
/// Linear search for the right `ContiguousCaseRange`.
|
||||
fn build_search_branches(
|
||||
fn build_search_branches<'a>(
|
||||
bx: &mut FunctionBuilder,
|
||||
val: Value,
|
||||
otherwise: Block,
|
||||
contiguous_case_ranges: Vec<ContiguousCaseRange>,
|
||||
cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
|
||||
contiguous_case_ranges: &'a [ContiguousCaseRange],
|
||||
) {
|
||||
let mut was_branch = false;
|
||||
let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| {
|
||||
if was_branch {
|
||||
let block = bx.create_block();
|
||||
bx.ins().jump(block, &[]);
|
||||
bx.seal_block(block);
|
||||
bx.switch_to_block(block);
|
||||
}
|
||||
let last_ix = contiguous_case_ranges.len() - 1;
|
||||
for (ix, range) in contiguous_case_ranges.iter().rev().enumerate() {
|
||||
let alternate = if ix == last_ix {
|
||||
otherwise
|
||||
} else {
|
||||
bx.create_block()
|
||||
};
|
||||
for ContiguousCaseRange {
|
||||
first_index,
|
||||
blocks,
|
||||
} in contiguous_case_ranges.into_iter().rev()
|
||||
{
|
||||
match (blocks.len(), first_index) {
|
||||
(1, 0) => {
|
||||
ins_fallthrough_jump(was_branch, bx);
|
||||
bx.ins().brz(val, blocks[0], &[]);
|
||||
|
||||
if range.first_index == 0 {
|
||||
assert_eq!(alternate, otherwise);
|
||||
|
||||
if let Some(block) = range.single_block() {
|
||||
bx.ins().brif(val, otherwise, &[], block, &[]);
|
||||
} else {
|
||||
Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
|
||||
}
|
||||
(1, _) => {
|
||||
ins_fallthrough_jump(was_branch, bx);
|
||||
let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, first_index);
|
||||
bx.ins().brnz(is_good_val, blocks[0], &[]);
|
||||
}
|
||||
(_, 0) => {
|
||||
// if `first_index` is 0, then `icmp_imm uge val, first_index` is trivially true
|
||||
} else {
|
||||
if let Some(block) = range.single_block() {
|
||||
let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
|
||||
bx.ins().brif(is_good_val, block, &[], alternate, &[]);
|
||||
} else {
|
||||
let is_good_val = icmp_imm_u128(
|
||||
bx,
|
||||
IntCC::UnsignedGreaterThanOrEqual,
|
||||
val,
|
||||
range.first_index,
|
||||
);
|
||||
let jt_block = bx.create_block();
|
||||
bx.ins().jump(jt_block, &[]);
|
||||
bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
|
||||
bx.seal_block(jt_block);
|
||||
cases_and_jt_blocks.push((first_index, jt_block, blocks));
|
||||
// `jump otherwise` below must not be hit, because the current block has been
|
||||
// filled above. This is the last iteration anyway, as 0 is the smallest
|
||||
// unsigned int, so just return here.
|
||||
return;
|
||||
bx.switch_to_block(jt_block);
|
||||
Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
|
||||
}
|
||||
(_, _) => {
|
||||
ins_fallthrough_jump(was_branch, bx);
|
||||
let jt_block = bx.create_block();
|
||||
let is_good_val =
|
||||
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
|
||||
bx.ins().brnz(is_good_val, jt_block, &[]);
|
||||
bx.seal_block(jt_block);
|
||||
cases_and_jt_blocks.push((first_index, jt_block, blocks));
|
||||
}
|
||||
}
|
||||
was_branch = true;
|
||||
}
|
||||
|
||||
bx.ins().jump(otherwise, &[]);
|
||||
if alternate != otherwise {
|
||||
bx.seal_block(alternate);
|
||||
bx.switch_to_block(alternate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For every item in `cases_and_jt_blocks` this will create a jump table in the specified block.
|
||||
fn build_jump_tables(
|
||||
fn build_jump_table(
|
||||
bx: &mut FunctionBuilder,
|
||||
val: Value,
|
||||
otherwise: Block,
|
||||
cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>,
|
||||
first_index: EntryIndex,
|
||||
blocks: &[Block],
|
||||
) {
|
||||
for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
|
||||
// There are currently no 128bit systems supported by rustc, but once we do ensure that
|
||||
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
|
||||
assert!(
|
||||
@@ -243,13 +215,9 @@ impl Switch {
|
||||
"Jump tables bigger than 2^32-1 are not yet supported"
|
||||
);
|
||||
|
||||
let mut jt_data = JumpTableData::new();
|
||||
for block in blocks {
|
||||
jt_data.push_entry(block);
|
||||
}
|
||||
let jt_data = JumpTableData::with_blocks(Vec::from(blocks));
|
||||
let jump_table = bx.create_jump_table(jt_data);
|
||||
|
||||
bx.switch_to_block(jt_block);
|
||||
let discr = if first_index == 0 {
|
||||
val
|
||||
} else {
|
||||
@@ -271,8 +239,8 @@ impl Switch {
|
||||
let bigger_than_u32 =
|
||||
bx.ins()
|
||||
.icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
|
||||
bx.ins().brnz(bigger_than_u32, otherwise, &[]);
|
||||
bx.ins().jump(new_block, &[]);
|
||||
bx.ins()
|
||||
.brif(bigger_than_u32, otherwise, &[], new_block, &[]);
|
||||
bx.seal_block(new_block);
|
||||
bx.switch_to_block(new_block);
|
||||
|
||||
@@ -285,7 +253,6 @@ impl Switch {
|
||||
|
||||
bx.ins().br_table(discr, otherwise, jump_table);
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the switch
|
||||
///
|
||||
@@ -307,9 +274,7 @@ impl Switch {
|
||||
}
|
||||
|
||||
let contiguous_case_ranges = self.collect_contiguous_case_ranges();
|
||||
let cases_and_jt_blocks =
|
||||
Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges);
|
||||
Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks);
|
||||
Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,6 +316,15 @@ impl ContiguousCaseRange {
|
||||
blocks: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `Some` block when there is only a single block in this range.
|
||||
fn single_block(&self) -> Option<Block> {
|
||||
if self.blocks.len() == 1 {
|
||||
Some(self.blocks[0])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -384,43 +358,52 @@ mod tests {
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! assert_eq_output {
|
||||
($actual:ident, $expected:literal) => {
|
||||
if $actual != $expected {
|
||||
assert!(
|
||||
false,
|
||||
"\n{}",
|
||||
similar::TextDiff::from_lines($expected, &$actual)
|
||||
.unified_diff()
|
||||
.header("expected", "actual")
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn switch_zero() {
|
||||
let func = setup!(0, [0,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
"block0:
|
||||
v0 = iconst.i8 0
|
||||
brz v0, block1 ; v0 = 0
|
||||
jump block0"
|
||||
brif v0, block0, block1 ; v0 = 0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn switch_single() {
|
||||
let func = setup!(0, [1,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
"block0:
|
||||
v0 = iconst.i8 0
|
||||
v1 = icmp_imm eq v0, 1 ; v0 = 0
|
||||
brnz v1, block1
|
||||
jump block0"
|
||||
brif v1, block1, block0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn switch_bool() {
|
||||
let func = setup!(0, [0, 1,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
" jt0 = jump_table [block1, block2]
|
||||
|
||||
block0:
|
||||
v0 = iconst.i8 0
|
||||
jump block3
|
||||
|
||||
block3:
|
||||
v1 = uextend.i32 v0 ; v0 = 0
|
||||
br_table v1, block0, jt0"
|
||||
);
|
||||
@@ -429,56 +412,50 @@ block3:
|
||||
#[test]
|
||||
fn switch_two_gap() {
|
||||
let func = setup!(0, [0, 2,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
"block0:
|
||||
v0 = iconst.i8 0
|
||||
v1 = icmp_imm eq v0, 2 ; v0 = 0
|
||||
brnz v1, block2
|
||||
jump block3
|
||||
brif v1, block2, block3
|
||||
|
||||
block3:
|
||||
brz.i8 v0, block1 ; v0 = 0
|
||||
jump block0"
|
||||
brif.i8 v0, block0, block1 ; v0 = 0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn switch_many() {
|
||||
let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
" jt0 = jump_table [block1, block2]
|
||||
jt1 = jump_table [block5, block6, block7]
|
||||
" jt0 = jump_table [block5, block6, block7]
|
||||
jt1 = jump_table [block1, block2]
|
||||
|
||||
block0:
|
||||
v0 = iconst.i8 0
|
||||
v1 = icmp_imm uge v0, 7 ; v0 = 0
|
||||
brnz v1, block9
|
||||
jump block8
|
||||
brif v1, block9, block8
|
||||
|
||||
block9:
|
||||
v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0
|
||||
brnz v2, block10
|
||||
jump block11
|
||||
brif v2, block11, block10
|
||||
|
||||
block11:
|
||||
v3 = icmp_imm.i8 eq v0, 7 ; v0 = 0
|
||||
brnz v3, block4
|
||||
jump block0
|
||||
|
||||
block8:
|
||||
v4 = icmp_imm.i8 eq v0, 5 ; v0 = 0
|
||||
brnz v4, block3
|
||||
jump block12
|
||||
|
||||
block12:
|
||||
v5 = uextend.i32 v0 ; v0 = 0
|
||||
br_table v5, block0, jt0
|
||||
v3 = iadd_imm.i8 v0, -10 ; v0 = 0
|
||||
v4 = uextend.i32 v3
|
||||
br_table v4, block0, jt0
|
||||
|
||||
block10:
|
||||
v6 = iadd_imm.i8 v0, -10 ; v0 = 0
|
||||
v7 = uextend.i32 v6
|
||||
v5 = icmp_imm.i8 eq v0, 7 ; v0 = 0
|
||||
brif v5, block4, block0
|
||||
|
||||
block8:
|
||||
v6 = icmp_imm.i8 eq v0, 5 ; v0 = 0
|
||||
brif v6, block3, block12
|
||||
|
||||
block12:
|
||||
v7 = uextend.i32 v0 ; v0 = 0
|
||||
br_table v7, block0, jt1"
|
||||
);
|
||||
}
|
||||
@@ -486,51 +463,46 @@ block10:
|
||||
#[test]
|
||||
fn switch_min_index_value() {
|
||||
let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
"block0:
|
||||
v0 = iconst.i8 0
|
||||
v1 = icmp_imm eq v0, 128 ; v0 = 0
|
||||
brnz v1, block1
|
||||
jump block3
|
||||
brif v1, block1, block3
|
||||
|
||||
block3:
|
||||
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
|
||||
brnz v2, block2
|
||||
jump block0"
|
||||
brif v2, block2, block0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn switch_max_index_value() {
|
||||
let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
"block0:
|
||||
v0 = iconst.i8 0
|
||||
v1 = icmp_imm eq v0, 127 ; v0 = 0
|
||||
brnz v1, block1
|
||||
jump block3
|
||||
brif v1, block1, block3
|
||||
|
||||
block3:
|
||||
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
|
||||
brnz v2, block2
|
||||
jump block0"
|
||||
brif v2, block2, block0"
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn switch_optimal_codegen() {
|
||||
let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
" jt0 = jump_table [block2, block3]
|
||||
|
||||
block0:
|
||||
v0 = iconst.i8 0
|
||||
v1 = icmp_imm eq v0, 255 ; v0 = 0
|
||||
brnz v1, block1
|
||||
jump block4
|
||||
brif v1, block1, block4
|
||||
|
||||
block4:
|
||||
v2 = uextend.i32 v0 ; v0 = 0
|
||||
@@ -617,20 +589,16 @@ block4:
|
||||
.trim_start_matches("function u0:0() fast {\n")
|
||||
.trim_end_matches("\n}\n")
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
" jt0 = jump_table [block2, block1]
|
||||
|
||||
block0:
|
||||
v0 = iconst.i64 0
|
||||
jump block4
|
||||
v1 = icmp_imm ugt v0, 0xffff_ffff ; v0 = 0
|
||||
brif v1, block3, block4
|
||||
|
||||
block4:
|
||||
v1 = icmp_imm.i64 ugt v0, 0xffff_ffff ; v0 = 0
|
||||
brnz v1, block3
|
||||
jump block5
|
||||
|
||||
block5:
|
||||
v2 = ireduce.i32 v0 ; v0 = 0
|
||||
br_table v2, block3, jt0"
|
||||
);
|
||||
@@ -659,21 +627,17 @@ block5:
|
||||
.trim_start_matches("function u0:0() fast {\n")
|
||||
.trim_end_matches("\n}\n")
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
assert_eq_output!(
|
||||
func,
|
||||
" jt0 = jump_table [block2, block1]
|
||||
|
||||
block0:
|
||||
v0 = iconst.i64 0
|
||||
v1 = uextend.i128 v0 ; v0 = 0
|
||||
jump block4
|
||||
v2 = icmp_imm ugt v1, 0xffff_ffff
|
||||
brif v2, block3, block4
|
||||
|
||||
block4:
|
||||
v2 = icmp_imm.i128 ugt v1, 0xffff_ffff
|
||||
brnz v2, block3
|
||||
jump block5
|
||||
|
||||
block5:
|
||||
v3 = ireduce.i32 v1
|
||||
br_table v3, block3, jt0"
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user