Generate match {} statements by merging adjacent MatchVariant trie edges.

This commit is contained in:
Chris Fallin
2021-09-05 17:21:15 -07:00
parent ed4c857082
commit 3ccbaf0f69
3 changed files with 191 additions and 52 deletions

View File

@@ -1,7 +1,3 @@
- Convert series of if-lets from MatchVariants into a match { ... }.
- Should be pretty simple; just need to recognize adjacent edges in
priority-order that are all MatchVariants on the same input.
- Document semantics carefully, especially wrt extractors. - Document semantics carefully, especially wrt extractors.
- Build out an initial set of bindings for Cranelift LowerCtx with extractors
- Verify that priorities work as expected. for instruction info.

View File

@@ -2,7 +2,8 @@
(type A (enum (type A (enum
(A1 (x B) (y B)))) (A1 (x B) (y B))))
(type B (enum (type B (enum
(B1 (x u32)))) (B1 (x u32))
(B2 (x u32))))
(decl A2B (A) B) (decl A2B (A) B)
@@ -14,6 +15,10 @@
(A2B (A.A1 (B.B1 x) _)) (A2B (A.A1 (B.B1 x) _))
(B.B1 x)) (B.B1 x))
(rule 0
(A2B (A.A1 (B.B2 x) _))
(B.B1 x))
(rule -1 (rule -1
(A2B (A.A1 _ _)) (A2B (A.A1 _ _))
(B.B1 42)) (B.B1 42))

View File

@@ -2,7 +2,7 @@
use crate::error::Error; use crate::error::Error;
use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value};
use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId, Variant};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fmt::Write; use std::fmt::Write;
@@ -309,7 +309,7 @@ impl TrieNode {
// ranges. Maybe the last edge we saw with the op // ranges. Maybe the last edge we saw with the op
// we're inserting can have its range expanded, // we're inserting can have its range expanded,
// however. // however.
if last_edge_with_op.is_some() && edges[last_edge_with_op.unwrap()].symbol == op { if last_edge_with_op.is_some() {
// Move it to the end of the run of equal-unit-range ops. // Move it to the end of the run of equal-unit-range ops.
edges.swap(last_edge_with_op.unwrap(), i - 1); edges.swap(last_edge_with_op.unwrap(), i - 1);
edge = Some(i - 1); edge = Some(i - 1);
@@ -901,6 +901,38 @@ impl<'a> Codegen<'a> {
Ok(()) Ok(())
} }
/// Builds the field-binding fragments used inside a struct-variant
/// pattern for `variant`.
///
/// `arg_tys` is zipped with `variant.fields`, so one fragment is
/// produced per (argument type, field) pair, in field order. The
/// value bound by the `i`th fragment is `Value::Pattern { inst: id,
/// output: i }`, and `define_val` is called for it on `ctx` with
/// `is_ref` set according to the argument's type:
///
/// - primitive-typed fields yield `field: value` (`is_ref = false`);
/// - enum-typed fields yield `field: ref value` (`is_ref = true`).
///
/// Returns the fragments; callers join them to form the pattern body.
fn match_variant_binders(
    &self,
    variant: &Variant,
    arg_tys: &[TypeId],
    id: InstId,
    ctx: &mut BodyContext,
) -> Vec<String> {
    let mut binders = Vec::with_capacity(arg_tys.len());
    for (output, (ty, field)) in arg_tys.iter().zip(variant.fields.iter()).enumerate() {
        let bound = Value::Pattern { inst: id, output };
        let bound_name = self.value_name(&bound);
        let field_name = &self.typeenv.syms[field.name.index()];
        // Enum-typed fields are bound with `ref`; primitives are bound
        // directly.
        let is_ref = match &self.typeenv.types[ty.index()] {
            &Type::Primitive(..) => false,
            &Type::Enum { .. } => true,
        };
        self.define_val(&bound, ctx, is_ref);
        let binder = if is_ref {
            format!("{}: ref {}", field_name, bound_name)
        } else {
            format!("{}: {}", field_name, bound_name)
        };
        binders.push(binder);
    }
    binders
}
/// Returns a `bool` indicating whether this pattern inst is /// Returns a `bool` indicating whether this pattern inst is
/// infallible. /// infallible.
fn generate_pattern_inst( fn generate_pattern_inst(
@@ -961,29 +993,7 @@ impl<'a> Codegen<'a> {
let ty_name = self.type_name(input_ty, /* is_ref = */ Some("&")); let ty_name = self.type_name(input_ty, /* is_ref = */ Some("&"));
let variant = &variants[variant.index()]; let variant = &variants[variant.index()];
let variantname = &self.typeenv.syms[variant.name.index()]; let variantname = &self.typeenv.syms[variant.name.index()];
let args = arg_tys let args = self.match_variant_binders(variant, &arg_tys[..], id, ctx);
.iter()
.zip(variant.fields.iter())
.enumerate()
.map(|(i, (ty, field))| {
let value = Value::Pattern {
inst: id,
output: i,
};
let valuename = self.value_name(&value);
let fieldname = &self.typeenv.syms[field.name.index()];
match &self.typeenv.types[ty.index()] {
&Type::Primitive(..) => {
self.define_val(&value, ctx, /* is_ref = */ false);
format!("{}: {}", fieldname, valuename)
}
&Type::Enum { .. } => {
self.define_val(&value, ctx, /* is_ref = */ true);
format!("{}: ref {}", fieldname, valuename)
}
}
})
.collect::<Vec<_>>();
writeln!( writeln!(
code, code,
"{}if let {}::{} {{ {} }} = {} {{", "{}if let {}::{} {{ {} }} = {} {{",
@@ -1097,34 +1107,162 @@ impl<'a> Codegen<'a> {
&TrieNode::Decision { ref edges } => { &TrieNode::Decision { ref edges } => {
let subindent = format!("{} ", indent); let subindent = format!("{} ", indent);
// if this is a decision node, generate each match op // if this is a decision node, generate each match op
// in turn (in priority order). // in turn (in priority order). Sort the ops within
for &TrieEdge { // each priority, and gather together adjacent
ref symbol, // MatchVariant ops with the same input and disjoint
ref node, // variants in order to create a `match` rather than a
.. // chain of if-lets.
} in edges let mut edges = edges.clone();
{ edges.sort_by(|e1, e2| (-e1.range.0, &e1.symbol).cmp(&(-e2.range.0, &e2.symbol)));
match symbol {
&TrieSymbol::EndOfMatch => { let mut i = 0;
returned = self.generate_body(code, depth + 1, node, indent, ctx)?; while i < edges.len() {
} let mut last = i;
&TrieSymbol::Match { ref op } => { let mut adjacent_variants = HashSet::new();
let id = InstId(depth); let mut adjacent_variant_input = None;
let infallible = log::trace!("edge: {:?}", edges[i]);
self.generate_pattern_inst(code, id, op, indent, ctx)?; while last < edges.len() {
let sub_returned = match &edges[last].symbol {
self.generate_body(code, depth + 1, node, &subindent, ctx)?; &TrieSymbol::Match {
writeln!(code, "{}}}", indent)?; op: PatternInst::MatchVariant { input, variant, .. },
if infallible && sub_returned { } => {
returned = true; if adjacent_variant_input.is_none() {
adjacent_variant_input = Some(input);
}
if adjacent_variant_input == Some(input)
&& !adjacent_variants.contains(&variant)
{
adjacent_variants.insert(variant);
last += 1;
} else {
break;
}
}
_ => {
break; break;
} }
} }
} }
// edges[i..last] is a run of adjacent
// MatchVariants (possibly an empty one). Only use
// a `match` form if there are at least two
// adjacent options.
if last - i > 1 {
self.generate_body_matches(code, depth, &edges[i..last], indent, ctx)?;
i = last;
continue;
} else {
let &TrieEdge {
ref symbol,
ref node,
..
} = &edges[i];
i += 1;
match symbol {
&TrieSymbol::EndOfMatch => {
returned =
self.generate_body(code, depth + 1, node, indent, ctx)?;
}
&TrieSymbol::Match { ref op } => {
let id = InstId(depth);
let infallible =
self.generate_pattern_inst(code, id, op, indent, ctx)?;
let sub_returned =
self.generate_body(code, depth + 1, node, &subindent, ctx)?;
writeln!(code, "{}}}", indent)?;
if infallible && sub_returned {
returned = true;
break;
}
}
}
}
} }
} }
} }
Ok(returned) Ok(returned)
} }
/// Emits one Rust `match` statement covering a run of `MatchVariant`
/// trie edges.
///
/// The matched value and its enum type are read from `edges[0]`; every
/// edge in `edges` must be a `TrieSymbol::Match` holding a
/// `PatternInst::MatchVariant` (anything else hits `unreachable!()`),
/// and `edges` must be non-empty because `edges[0]` is indexed
/// unconditionally. For each edge, one arm is written that
/// destructures the edge's variant (binders come from
/// `match_variant_binders` with `InstId(depth)`), and the edge's
/// subtree is generated inside the arm via `generate_body` at
/// `depth + 1`. The `bool` returned by `generate_body` is discarded.
///
/// A trailing `_ => {}` arm is always written, since no exhaustiveness
/// checking is done over the set of matched variants.
///
/// Errors from writing to `code` or from `generate_body` are
/// propagated.
fn generate_body_matches(
    &self,
    code: &mut dyn Write,
    depth: usize,
    edges: &[TrieEdge],
    indent: &str,
    ctx: &mut BodyContext,
) -> Result<(), Error> {
    // The matched value and its type are taken from the first edge.
    let (scrutinee, scrutinee_ty) = match &edges[0].symbol {
        &TrieSymbol::Match {
            op: PatternInst::MatchVariant {
                input, input_ty, ..
            },
        } => (input, input_ty),
        _ => unreachable!(),
    };
    let (enum_sym, enum_variants) = match &self.typeenv.types[scrutinee_ty.index()] {
        &Type::Enum {
            ref name,
            ref variants,
            ..
        } => (name, variants),
        _ => unreachable!(),
    };
    let enum_name = &self.typeenv.syms[enum_sym.index()];

    // Open the `match` on the input value.
    let scrutinee_expr = self.value_by_ref(&scrutinee, ctx);
    writeln!(code, "{}match {} {{", indent, scrutinee_expr)?;

    // Arm bodies are all generated at the same nested indent.
    let arm_body_indent = format!("{} ", indent);

    // One arm per edge.
    for edge in edges {
        let (variant, arg_tys) = match &edge.symbol {
            &TrieSymbol::Match {
                op: PatternInst::MatchVariant {
                    variant,
                    ref arg_tys,
                    ..
                },
            } => (variant, arg_tys),
            _ => unreachable!(),
        };
        let info = &enum_variants[variant.index()];
        let variant_name = &self.typeenv.syms[info.name.index()];
        let binders = self.match_variant_binders(info, arg_tys, InstId(depth), ctx);
        writeln!(
            code,
            "{} &{}::{} {{ {} }} => {{",
            indent,
            enum_name,
            variant_name,
            binders.join(", ")
        )?;
        self.generate_body(code, depth + 1, &edge.node, &arm_body_indent, ctx)?;
        writeln!(code, "{} }}", indent)?;
    }

    // Always add a catch-all arm, because no exhaustiveness checking
    // is done on the MatchVariants.
    writeln!(code, "{} _ => {{}}", indent)?;
    writeln!(code, "{}}}", indent)?;

    Ok(())
}
} }