Generate match {} statements by merging adjacent MatchVariant trie edges.
This commit is contained in:
@@ -1,7 +1,3 @@
|
||||
- Convert series of if-lets from MatchVariants into a match { ... }.
|
||||
- Should be pretty simple; just need to recognize adjacent edges in
|
||||
priority-order that are all MatchVariants on the same input.
|
||||
|
||||
- Document semantics carefully, especially wrt extractors.
|
||||
|
||||
- Verify that priorities work as expected.
|
||||
- Build out an initial set of bindings for Cranelift LowerCtx with extractors
|
||||
for instruction info.
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
(type A (enum
|
||||
(A1 (x B) (y B))))
|
||||
(type B (enum
|
||||
(B1 (x u32))))
|
||||
(B1 (x u32))
|
||||
(B2 (x u32))))
|
||||
|
||||
(decl A2B (A) B)
|
||||
|
||||
@@ -14,6 +15,10 @@
|
||||
(A2B (A.A1 (B.B1 x) _))
|
||||
(B.B1 x))
|
||||
|
||||
(rule 0
|
||||
(A2B (A.A1 (B.B2 x) _))
|
||||
(B.B1 x))
|
||||
|
||||
(rule -1
|
||||
(A2B (A.A1 _ _))
|
||||
(B.B1 42))
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value};
|
||||
use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId};
|
||||
use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId, Variant};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Write;
|
||||
|
||||
@@ -309,7 +309,7 @@ impl TrieNode {
|
||||
// ranges. Maybe the last edge we saw with the op
|
||||
// we're inserting can have its range expanded,
|
||||
// however.
|
||||
if last_edge_with_op.is_some() && edges[last_edge_with_op.unwrap()].symbol == op {
|
||||
if last_edge_with_op.is_some() {
|
||||
// Move it to the end of the run of equal-unit-range ops.
|
||||
edges.swap(last_edge_with_op.unwrap(), i - 1);
|
||||
edge = Some(i - 1);
|
||||
@@ -901,6 +901,38 @@ impl<'a> Codegen<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn match_variant_binders(
|
||||
&self,
|
||||
variant: &Variant,
|
||||
arg_tys: &[TypeId],
|
||||
id: InstId,
|
||||
ctx: &mut BodyContext,
|
||||
) -> Vec<String> {
|
||||
arg_tys
|
||||
.iter()
|
||||
.zip(variant.fields.iter())
|
||||
.enumerate()
|
||||
.map(|(i, (ty, field))| {
|
||||
let value = Value::Pattern {
|
||||
inst: id,
|
||||
output: i,
|
||||
};
|
||||
let valuename = self.value_name(&value);
|
||||
let fieldname = &self.typeenv.syms[field.name.index()];
|
||||
match &self.typeenv.types[ty.index()] {
|
||||
&Type::Primitive(..) => {
|
||||
self.define_val(&value, ctx, /* is_ref = */ false);
|
||||
format!("{}: {}", fieldname, valuename)
|
||||
}
|
||||
&Type::Enum { .. } => {
|
||||
self.define_val(&value, ctx, /* is_ref = */ true);
|
||||
format!("{}: ref {}", fieldname, valuename)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Returns a `bool` indicating whether this pattern inst is
|
||||
/// infallible.
|
||||
fn generate_pattern_inst(
|
||||
@@ -961,29 +993,7 @@ impl<'a> Codegen<'a> {
|
||||
let ty_name = self.type_name(input_ty, /* is_ref = */ Some("&"));
|
||||
let variant = &variants[variant.index()];
|
||||
let variantname = &self.typeenv.syms[variant.name.index()];
|
||||
let args = arg_tys
|
||||
.iter()
|
||||
.zip(variant.fields.iter())
|
||||
.enumerate()
|
||||
.map(|(i, (ty, field))| {
|
||||
let value = Value::Pattern {
|
||||
inst: id,
|
||||
output: i,
|
||||
};
|
||||
let valuename = self.value_name(&value);
|
||||
let fieldname = &self.typeenv.syms[field.name.index()];
|
||||
match &self.typeenv.types[ty.index()] {
|
||||
&Type::Primitive(..) => {
|
||||
self.define_val(&value, ctx, /* is_ref = */ false);
|
||||
format!("{}: {}", fieldname, valuename)
|
||||
}
|
||||
&Type::Enum { .. } => {
|
||||
self.define_val(&value, ctx, /* is_ref = */ true);
|
||||
format!("{}: ref {}", fieldname, valuename)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let args = self.match_variant_binders(variant, &arg_tys[..], id, ctx);
|
||||
writeln!(
|
||||
code,
|
||||
"{}if let {}::{} {{ {} }} = {} {{",
|
||||
@@ -1097,34 +1107,162 @@ impl<'a> Codegen<'a> {
|
||||
&TrieNode::Decision { ref edges } => {
|
||||
let subindent = format!("{} ", indent);
|
||||
// if this is a decision node, generate each match op
|
||||
// in turn (in priority order).
|
||||
for &TrieEdge {
|
||||
ref symbol,
|
||||
ref node,
|
||||
..
|
||||
} in edges
|
||||
{
|
||||
match symbol {
|
||||
&TrieSymbol::EndOfMatch => {
|
||||
returned = self.generate_body(code, depth + 1, node, indent, ctx)?;
|
||||
}
|
||||
&TrieSymbol::Match { ref op } => {
|
||||
let id = InstId(depth);
|
||||
let infallible =
|
||||
self.generate_pattern_inst(code, id, op, indent, ctx)?;
|
||||
let sub_returned =
|
||||
self.generate_body(code, depth + 1, node, &subindent, ctx)?;
|
||||
writeln!(code, "{}}}", indent)?;
|
||||
if infallible && sub_returned {
|
||||
returned = true;
|
||||
// in turn (in priority order). Sort the ops within
|
||||
// each priority, and gather together adjacent
|
||||
// MatchVariant ops with the same input and disjoint
|
||||
// variants in order to create a `match` rather than a
|
||||
// chain of if-lets.
|
||||
let mut edges = edges.clone();
|
||||
edges.sort_by(|e1, e2| (-e1.range.0, &e1.symbol).cmp(&(-e2.range.0, &e2.symbol)));
|
||||
|
||||
let mut i = 0;
|
||||
while i < edges.len() {
|
||||
let mut last = i;
|
||||
let mut adjacent_variants = HashSet::new();
|
||||
let mut adjacent_variant_input = None;
|
||||
log::trace!("edge: {:?}", edges[i]);
|
||||
while last < edges.len() {
|
||||
match &edges[last].symbol {
|
||||
&TrieSymbol::Match {
|
||||
op: PatternInst::MatchVariant { input, variant, .. },
|
||||
} => {
|
||||
if adjacent_variant_input.is_none() {
|
||||
adjacent_variant_input = Some(input);
|
||||
}
|
||||
if adjacent_variant_input == Some(input)
|
||||
&& !adjacent_variants.contains(&variant)
|
||||
{
|
||||
adjacent_variants.insert(variant);
|
||||
last += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// edges[i..last] is a run of adjacent
|
||||
// MatchVariants (possibly an empty one). Only use
|
||||
// a `match` form if there are at least two
|
||||
// adjacent options.
|
||||
if last - i > 1 {
|
||||
self.generate_body_matches(code, depth, &edges[i..last], indent, ctx)?;
|
||||
i = last;
|
||||
continue;
|
||||
} else {
|
||||
let &TrieEdge {
|
||||
ref symbol,
|
||||
ref node,
|
||||
..
|
||||
} = &edges[i];
|
||||
i += 1;
|
||||
|
||||
match symbol {
|
||||
&TrieSymbol::EndOfMatch => {
|
||||
returned =
|
||||
self.generate_body(code, depth + 1, node, indent, ctx)?;
|
||||
}
|
||||
&TrieSymbol::Match { ref op } => {
|
||||
let id = InstId(depth);
|
||||
let infallible =
|
||||
self.generate_pattern_inst(code, id, op, indent, ctx)?;
|
||||
let sub_returned =
|
||||
self.generate_body(code, depth + 1, node, &subindent, ctx)?;
|
||||
writeln!(code, "{}}}", indent)?;
|
||||
if infallible && sub_returned {
|
||||
returned = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(returned)
|
||||
}
|
||||
|
||||
fn generate_body_matches(
|
||||
&self,
|
||||
code: &mut dyn Write,
|
||||
depth: usize,
|
||||
edges: &[TrieEdge],
|
||||
indent: &str,
|
||||
ctx: &mut BodyContext,
|
||||
) -> Result<(), Error> {
|
||||
let (input, input_ty) = match &edges[0].symbol {
|
||||
&TrieSymbol::Match {
|
||||
op:
|
||||
PatternInst::MatchVariant {
|
||||
input, input_ty, ..
|
||||
},
|
||||
} => (input, input_ty),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let (input_ty_sym, variants) = match &self.typeenv.types[input_ty.index()] {
|
||||
&Type::Enum {
|
||||
ref name,
|
||||
ref variants,
|
||||
..
|
||||
} => (name, variants),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let input_ty_name = &self.typeenv.syms[input_ty_sym.index()];
|
||||
|
||||
// Emit the `match`.
|
||||
writeln!(
|
||||
code,
|
||||
"{}match {} {{",
|
||||
indent,
|
||||
self.value_by_ref(&input, ctx)
|
||||
)?;
|
||||
|
||||
// Emit each case.
|
||||
for &TrieEdge {
|
||||
ref symbol,
|
||||
ref node,
|
||||
..
|
||||
} in edges
|
||||
{
|
||||
let id = InstId(depth);
|
||||
let (variant, arg_tys) = match symbol {
|
||||
&TrieSymbol::Match {
|
||||
op:
|
||||
PatternInst::MatchVariant {
|
||||
variant,
|
||||
ref arg_tys,
|
||||
..
|
||||
},
|
||||
} => (variant, arg_tys),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let variantinfo = &variants[variant.index()];
|
||||
let variantname = &self.typeenv.syms[variantinfo.name.index()];
|
||||
let fields = self.match_variant_binders(variantinfo, arg_tys, id, ctx);
|
||||
writeln!(
|
||||
code,
|
||||
"{} &{}::{} {{ {} }} => {{",
|
||||
indent,
|
||||
input_ty_name,
|
||||
variantname,
|
||||
fields.join(", ")
|
||||
)?;
|
||||
let subindent = format!("{} ", indent);
|
||||
self.generate_body(code, depth + 1, node, &subindent, ctx)?;
|
||||
writeln!(code, "{} }}", indent)?;
|
||||
}
|
||||
|
||||
// Always add a catchall, because we don't do exhaustiveness
|
||||
// checking on the MatcHVariants.
|
||||
writeln!(code, "{} _ => {{}}", indent)?;
|
||||
|
||||
writeln!(code, "{}}}", indent)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user