cranelift-isle: Specialize for Term at rule root (#5295)

In #5174 we decided it doesn't make sense for a rule to have a
bind-pattern at the root of its left-hand side. There's no Rust value
corresponding to the root value of such a term, because it actually
represents a function declaration with one or more arguments.

This commit takes that to its logical conclusion.

`sema::Rule` previously had an `lhs` field whose value must always be a
`Pattern::Term` variant, and anyone using that structure had to deal
with the possibility of finding the wrong variant there.

Now the relevant fields from that variant are stored directly in `Rule`
instead. Also, the (tiny!) portion of `translate_pattern` which applied
when the pattern was the root term is now inlined in `collect_rules`.

Because `translate_pattern` no longer has to special-case the root term,
we can delete its `rule_term` and `is_root` arguments. That brings it
down to a more manageable four arguments, which means many calls fit on
one line now.
This commit is contained in:
Jamey Sharp
2022-11-18 11:21:08 -08:00
committed by GitHub
parent 4fcbd5bf23
commit 54207d343e
4 changed files with 57 additions and 108 deletions

View File

@@ -152,7 +152,6 @@ pub enum Pattern {
impl Pattern {
pub fn root_term(&self) -> Option<&Ident> {
match self {
&Pattern::BindPattern { ref subpat, .. } => subpat.root_term(),
&Pattern::Term { ref sym, .. } => Some(sym),
_ => None,
}

View File

@@ -436,7 +436,9 @@ pub struct Rule {
/// This rule's id.
pub id: RuleId,
/// The left-hand side pattern that this rule matches.
pub lhs: Pattern,
pub root_term: TermId,
/// Patterns to test against the root term's arguments.
pub args: Vec<Pattern>,
/// Any subpattern "if-let" clauses.
pub iflets: Vec<IfLet>,
/// The right-hand side expression that this rule evaluates upon successful
@@ -578,15 +580,6 @@ impl Pattern {
}
}
/// Get the root term of this pattern, if any.
pub fn root_term(&self) -> Option<TermId> {
match self {
&Pattern::Term(_, term, _) => Some(term),
&Pattern::BindPattern(_, _, ref subpat) => subpat.root_term(),
_ => None,
}
}
/// Recursively visit every sub-pattern.
pub fn visit<V: PatternVisitor>(
&self,
@@ -858,15 +851,11 @@ impl Rule {
let mut vars = HashMap::new();
// Visit the pattern, starting from the root input value.
if let &Pattern::Term(_, term, ref args) = &self.lhs {
let termdata = &termenv.terms[term.index()];
for (i, (subpat, &arg_ty)) in args.iter().zip(termdata.arg_tys.iter()).enumerate() {
let termdata = &termenv.terms[self.root_term.index()];
for (i, (subpat, &arg_ty)) in self.args.iter().zip(termdata.arg_tys.iter()).enumerate() {
let value = visitor.add_arg(i, arg_ty);
visitor.add_pattern(|visitor| subpat.visit(visitor, value, termenv, &mut vars));
}
} else {
unreachable!("Pattern must have a term at the root");
}
// Visit the `if-let` clauses, using `V::ExprVisitor` for the sub-exprs (right-hand sides).
for iflet in self.iflets.iter() {
@@ -1648,28 +1637,30 @@ impl TermEnv {
let pos = rule.pos;
let mut bindings = Bindings::default();
let rule_term = match rule.pattern.root_term() {
Some(name) => match self.get_term_by_name(tyenv, name) {
Some(term) => term,
None => {
tyenv.report_error(
pos,
"Cannot define a rule for an unknown term".to_string(),
);
continue;
}
},
None => {
let (sym, args) = if let ast::Pattern::Term { sym, args, .. } = &rule.pattern {
(sym, args)
} else {
tyenv.report_error(
pos,
"Rule does not have a term at the root of its left-hand side"
.to_string(),
);
continue;
}
};
let pure = match &self.terms[rule_term.index()].kind {
let root_term = if let Some(term) = self.get_term_by_name(tyenv, sym) {
term
} else {
tyenv.report_error(
pos,
"Cannot define a rule for an unknown term".to_string(),
);
continue;
};
let termdata = &self.terms[root_term.index()];
let pure = match &termdata.kind {
&TermKind::Decl { pure, .. } => pure,
_ => {
tyenv.report_error(
@@ -1681,25 +1672,18 @@ impl TermEnv {
}
};
let (lhs, ty) = unwrap_or_continue!(self.translate_pattern(
tyenv,
rule_term,
&rule.pattern,
None,
&mut bindings,
/* is_root = */ true,
));
termdata.check_args_count(args, tyenv, pos, sym);
let args = self.translate_args(args, termdata, tyenv, &mut bindings);
let iflets = rule
.iflets
.iter()
.filter_map(|iflet| {
self.translate_iflet(tyenv, rule_term, iflet, &mut bindings)
})
.filter_map(|iflet| self.translate_iflet(tyenv, iflet, &mut bindings))
.collect();
let rhs = unwrap_or_continue!(self.translate_expr(
tyenv,
&rule.expr,
Some(ty),
Some(termdata.ret_ty),
&mut bindings,
pure,
));
@@ -1707,7 +1691,8 @@ impl TermEnv {
let rid = RuleId(self.rules.len());
self.rules.push(Rule {
id: rid,
lhs,
root_term,
args,
iflets,
rhs,
vars: bindings.seen,
@@ -1798,11 +1783,9 @@ impl TermEnv {
fn translate_pattern(
&self,
tyenv: &mut TypeEnv,
rule_term: TermId,
pat: &ast::Pattern,
expected_ty: Option<TypeId>,
bindings: &mut Bindings,
is_root: bool,
) -> Option<(Pattern, TypeId)> {
log!("translate_pattern: {:?}", pat);
log!("translate_pattern: bindings = {:?}", bindings);
@@ -1858,11 +1841,9 @@ impl TermEnv {
for subpat in subpats {
let (subpat, ty) = unwrap_or_continue!(self.translate_pattern(
tyenv,
rule_term,
&*subpat,
subpat,
expected_ty,
bindings,
/* is_root = */ false,
));
expected_ty = expected_ty.or(Some(ty));
@@ -1885,14 +1866,7 @@ impl TermEnv {
pos,
} => {
// Do the subpattern first so we can resolve the type for sure.
let (subpat, ty) = self.translate_pattern(
tyenv,
rule_term,
&*subpat,
expected_ty,
bindings,
/* is_root = */ false,
)?;
let (subpat, ty) = self.translate_pattern(tyenv, subpat, expected_ty, bindings)?;
let name = tyenv.intern_mut(var);
if bindings.lookup(name).is_some() {
@@ -1984,11 +1958,9 @@ impl TermEnv {
{
return self.translate_pattern(
tyenv,
rule_term,
&expanded_pattern,
Some(expected_ty),
bindings,
/* is_root = */ false,
);
}
@@ -2005,10 +1977,6 @@ impl TermEnv {
termdata.check_args_count(args, tyenv, pos, sym);
match &termdata.kind {
TermKind::Decl {
constructor_kind: Some(ConstructorKind::InternalConstructor),
..
} if is_root && tid == rule_term => {}
TermKind::EnumVariant { .. } => {}
TermKind::Decl {
extractor_kind: Some(ExtractorKind::ExternalExtractor { .. }),
@@ -2024,14 +1992,7 @@ impl TermEnv {
// substitutions.
log!("internal extractor macro args = {:?}", args);
let pat = template.subst_macro_args(&args)?;
return self.translate_pattern(
tyenv,
rule_term,
&pat,
expected_ty,
bindings,
/* is_root = */ false,
);
return self.translate_pattern(tyenv, &pat, expected_ty, bindings);
}
TermKind::Decl {
extractor_kind: None,
@@ -2048,29 +2009,27 @@ impl TermEnv {
}
}
// Resolve subpatterns.
let subpats = args
.iter()
.zip(termdata.arg_tys.iter())
.filter_map(|(arg, &arg_ty)| {
self.translate_pattern(
tyenv,
rule_term,
arg,
Some(arg_ty),
bindings,
/* is_root = */ false,
)
})
.map(|(subpat, _)| subpat)
.collect();
let subpats = self.translate_args(args, termdata, tyenv, bindings);
Some((Pattern::Term(ty, tid, subpats), ty))
}
&ast::Pattern::MacroArg { .. } => unreachable!(),
}
}
fn translate_args(
&self,
args: &Vec<ast::Pattern>,
termdata: &Term,
tyenv: &mut TypeEnv,
bindings: &mut Bindings,
) -> Vec<Pattern> {
args.iter()
.zip(termdata.arg_tys.iter())
.filter_map(|(arg, &arg_ty)| self.translate_pattern(tyenv, arg, Some(arg_ty), bindings))
.map(|(subpat, _)| subpat)
.collect()
}
fn maybe_implicit_convert_expr(
&self,
tyenv: &mut TypeEnv,
@@ -2321,7 +2280,6 @@ impl TermEnv {
fn translate_iflet(
&self,
tyenv: &mut TypeEnv,
rule_term: TermId,
iflet: &ast::IfLet,
bindings: &mut Bindings,
) -> Option<IfLet> {
@@ -2329,14 +2287,7 @@ impl TermEnv {
let rhs =
self.translate_expr(tyenv, &iflet.expr, None, bindings, /* pure = */ true)?;
let ty = rhs.ty();
let (lhs, _lhs_ty) = self.translate_pattern(
tyenv,
rule_term,
&iflet.pattern,
Some(ty),
bindings,
/* is_root = */ false,
)?;
let (lhs, _lhs_ty) = self.translate_pattern(tyenv, &iflet.pattern, Some(ty), bindings)?;
Some(IfLet { lhs, rhs })
}

View File

@@ -290,7 +290,6 @@ impl TermFunctionsBuilder {
log!("termenv: {:?}", termenv);
for rule in termenv.rules.iter() {
let (pattern, expr) = lower_rule(termenv, rule.id);
let root_term = rule.lhs.root_term().unwrap();
log!(
"build:\n- rule {:?}\n- pattern {:?}\n- expr {:?}",
@@ -306,7 +305,7 @@ impl TermFunctionsBuilder {
.chain(std::iter::once(TrieSymbol::EndOfMatch));
self.builders_by_term
.entry(root_term)
.entry(rule.root_term)
.or_insert(TrieNode::Empty)
.insert(rule.prio, symbols, expr);
}

View File

@@ -174,7 +174,7 @@ pub fn build(
let mut errors = Vec::new();
let mut term = HashMap::new();
for rule in termenv.rules.iter() {
term.entry(rule.lhs.root_term().unwrap())
term.entry(rule.root_term)
.or_insert_with(RuleSetBuilder::default)
.add_rule(rule, termenv, tyenv, &mut errors);
}