cranelift-isle: Factor out rule/pattern/expr visitors (#5174)

This makes some rather tricky analysis available to other users besides
the current IR. It shouldn't change current behavior, except if a rule
attempts to bind its root term to a name. There's no Rust value for a
root term, so the existing code silently ignored such bindings and would
panic saying "Variable should already be bound" if a rule attempted to
use such bindings. With this commit, the initial attempt to bind the
name reports the error instead.
This commit is contained in:
Jamey Sharp
2022-11-02 18:18:49 -07:00
committed by GitHub
parent f6a8c81a47
commit 2688b44915
2 changed files with 381 additions and 268 deletions

View File

@@ -3,7 +3,6 @@
use crate::lexer::Pos;
use crate::log;
use crate::sema::*;
use crate::StableMap;
declare_id!(
/// The id of an instruction in a `PatternSequence`.
@@ -239,33 +238,70 @@ impl ExprSequence {
}
}
#[derive(Clone, Copy, Debug)]
enum ValueOrArgs {
Value(Value),
ImplicitTermFromArgs(TermId),
}
impl ValueOrArgs {
fn to_value(&self) -> Option<Value> {
match self {
&ValueOrArgs::Value(v) => Some(v),
_ => None,
}
}
}
impl PatternSequence {
fn add_inst(&mut self, inst: PatternInst) -> InstId {
let id = InstId(self.insts.len());
self.insts.push(inst);
id
}
}
/// Used as an intermediate representation of expressions in the [RuleVisitor] implementation for
/// [PatternSequence].
pub struct ReturnExpr {
seq: ExprSequence,
output: Value,
output_ty: TypeId,
}
impl RuleVisitor for PatternSequence {
type PatternVisitor = Self;
type ExprVisitor = ExprSequence;
type Expr = ReturnExpr;
fn add_arg(&mut self, index: usize, ty: TypeId) -> Value {
let inst = self.add_inst(PatternInst::Arg { index, ty });
Value::Pattern { inst, output: 0 }
}
fn add_pattern<F: FnOnce(&mut Self)>(&mut self, visitor: F) {
visitor(self)
}
fn add_expr<F>(&mut self, visitor: F) -> ReturnExpr
where
F: FnOnce(&mut ExprSequence) -> VisitedExpr<ExprSequence>,
{
let mut expr = ExprSequence::default();
let VisitedExpr { ty, value } = visitor(&mut expr);
let index = 0;
expr.add_inst(ExprInst::Return { index, ty, value });
ReturnExpr {
seq: expr,
output: value,
output_ty: ty,
}
}
fn expr_as_pattern(&mut self, expr: ReturnExpr) -> Value {
let inst = self.add_inst(PatternInst::Expr {
seq: expr.seq,
output: expr.output,
output_ty: expr.output_ty,
});
// Create values for all outputs.
Value::Pattern { inst, output: 0 }
}
fn pattern_as_expr(&mut self, pattern: Value) -> Value {
pattern
}
}
impl PatternVisitor for PatternSequence {
type PatternId = Value;
fn add_match_equal(&mut self, a: Value, b: Value, ty: TypeId) {
self.add_inst(PatternInst::MatchEqual { a, b, ty });
}
@@ -300,8 +336,8 @@ impl PatternSequence {
fn add_extract(
&mut self,
inputs: Vec<Value>,
input_tys: Vec<TypeId>,
input: Value,
input_ty: TypeId,
output_tys: Vec<TypeId>,
term: TermId,
infallible: bool,
@@ -309,8 +345,8 @@ impl PatternSequence {
) -> Vec<Value> {
let outputs = output_tys.len();
let inst = self.add_inst(PatternInst::Extract {
inputs,
input_tys,
inputs: vec![input],
input_tys: vec![input_ty],
output_tys,
term,
infallible,
@@ -320,131 +356,6 @@ impl PatternSequence {
.map(|output| Value::Pattern { inst, output })
.collect()
}
fn add_expr_seq(&mut self, seq: ExprSequence, output: Value, output_ty: TypeId) -> Value {
let inst = self.add_inst(PatternInst::Expr {
seq,
output,
output_ty,
});
// Create values for all outputs.
Value::Pattern { inst, output: 0 }
}
/// Generate PatternInsts to match the given (sub)pattern. Works
/// recursively down the AST.
fn gen_pattern(
&mut self,
input: ValueOrArgs,
termenv: &TermEnv,
pat: &Pattern,
vars: &mut StableMap<VarId, Value>,
) {
match pat {
&Pattern::BindPattern(_ty, var, ref subpat) => {
// Bind the appropriate variable and recurse.
assert!(!vars.contains_key(&var));
if let Some(v) = input.to_value() {
vars.insert(var, v);
}
self.gen_pattern(input, termenv, subpat, vars);
}
&Pattern::Var(ty, var) => {
// Assert that the value matches the existing bound var.
let var_val = vars
.get(&var)
.cloned()
.expect("Variable should already be bound");
let input_val = input
.to_value()
.expect("Cannot match an =var pattern against root term");
self.add_match_equal(input_val, var_val, ty);
}
&Pattern::ConstInt(ty, value) => {
// Assert that the value matches the constant integer.
let input_val = input
.to_value()
.expect("Cannot match an integer pattern against root term");
self.add_match_int(input_val, ty, value);
}
&Pattern::ConstPrim(ty, value) => {
let input_val = input
.to_value()
.expect("Cannot match a constant-primitive pattern against root term");
self.add_match_prim(input_val, ty, value);
}
&Pattern::Term(ty, term, ref args) => {
match input {
ValueOrArgs::ImplicitTermFromArgs(termid) => {
assert_eq!(
termid, term,
"Cannot match a different term against root pattern"
);
let termdata = &termenv.terms[term.index()];
let arg_tys = &termdata.arg_tys[..];
for (i, subpat) in args.iter().enumerate() {
let value = self.add_arg(i, arg_tys[i]);
self.gen_pattern(ValueOrArgs::Value(value), termenv, subpat, vars);
}
}
ValueOrArgs::Value(input) => {
// Determine whether the term has an external extractor or not.
let termdata = &termenv.terms[term.index()];
let arg_values = match &termdata.kind {
TermKind::EnumVariant { variant } => {
self.add_match_variant(input, ty, &termdata.arg_tys, *variant)
}
TermKind::Decl {
extractor_kind: None,
..
} => {
panic!("Pattern invocation of undefined term body")
}
TermKind::Decl {
extractor_kind: Some(ExtractorKind::InternalExtractor { .. }),
..
} => {
panic!("Should have been expanded away")
}
TermKind::Decl {
multi,
extractor_kind:
Some(ExtractorKind::ExternalExtractor { infallible, .. }),
..
} => {
// Evaluate all `input` args.
let inputs = vec![input];
let input_tys = vec![termdata.ret_ty];
let output_tys = args.iter().map(|arg| arg.ty()).collect();
// Invoke the extractor.
self.add_extract(
inputs,
input_tys,
output_tys,
term,
*infallible && !*multi,
*multi,
)
}
};
for (pat, val) in args.iter().zip(arg_values) {
self.gen_pattern(ValueOrArgs::Value(val), termenv, pat, vars);
}
}
}
}
&Pattern::And(_ty, ref children) => {
for child in children {
self.gen_pattern(input, termenv, child, vars);
}
}
&Pattern::Wildcard(_ty) => {
// Nothing!
}
}
}
}
impl ExprSequence {
@@ -453,6 +364,10 @@ impl ExprSequence {
self.insts.push(inst);
id
}
}
impl ExprVisitor for ExprSequence {
type ExprId = Value;
fn add_const_int(&mut self, ty: TypeId, val: i128) -> Value {
let inst = self.add_inst(ExprInst::ConstInt { ty, val });
@@ -495,134 +410,15 @@ impl ExprSequence {
});
Value::Expr { inst, output: 0 }
}
fn add_return(&mut self, ty: TypeId, value: Value) {
self.add_inst(ExprInst::Return {
index: 0,
ty,
value,
});
}
/// Creates a sequence of ExprInsts to generate the given
/// expression value. Returns the value ID as well as the root
/// term ID, if any.
fn gen_expr(
&mut self,
termenv: &TermEnv,
expr: &Expr,
vars: &StableMap<VarId, Value>,
) -> Value {
log!("gen_expr: expr {:?}", expr);
match expr {
&Expr::ConstInt(ty, val) => self.add_const_int(ty, val),
&Expr::ConstPrim(ty, val) => self.add_const_prim(ty, val),
&Expr::Let {
ty: _ty,
ref bindings,
ref body,
} => {
let mut vars = vars.clone();
for &(var, _var_ty, ref var_expr) in bindings {
let var_value = self.gen_expr(termenv, var_expr, &vars);
vars.insert(var, var_value);
}
self.gen_expr(termenv, body, &vars)
}
&Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(),
&Expr::Term(ty, term, ref arg_exprs) => {
let termdata = &termenv.terms[term.index()];
let arg_values_tys = arg_exprs
.iter()
.map(|arg_expr| self.gen_expr(termenv, arg_expr, vars))
.zip(termdata.arg_tys.iter().copied())
.collect();
match &termdata.kind {
TermKind::EnumVariant { variant } => {
self.add_create_variant(arg_values_tys, ty, *variant)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::InternalConstructor),
multi,
..
} => {
self.add_construct(
arg_values_tys,
ty,
term,
/* infallible = */ false,
*multi,
)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::ExternalConstructor { .. }),
pure,
multi,
..
} => {
self.add_construct(
arg_values_tys,
ty,
term,
/* infallible = */ !pure,
*multi,
)
}
TermKind::Decl {
constructor_kind: None,
..
} => panic!("Should have been caught by typechecking"),
}
}
}
}
}
/// Build a sequence from a rule.
pub fn lower_rule(termenv: &TermEnv, rule: RuleId) -> (PatternSequence, ExprSequence) {
let mut pattern_seq: PatternSequence = Default::default();
let mut expr_seq: ExprSequence = Default::default();
let ruledata = &termenv.rules[rule.index()];
log!("lower_rule: ruledata {:?}", ruledata);
let mut pattern_seq = PatternSequence::default();
let mut expr_seq = ruledata.visit(&mut pattern_seq, termenv).seq;
expr_seq.pos = ruledata.pos;
let mut vars = StableMap::new();
let root_term = ruledata
.lhs
.root_term()
.expect("Pattern must have a term at the root");
log!("lower_rule: ruledata {:?}", ruledata,);
// Lower the pattern, starting from the root input value.
pattern_seq.gen_pattern(
ValueOrArgs::ImplicitTermFromArgs(root_term),
termenv,
&ruledata.lhs,
&mut vars,
);
// Lower the `if-let` clauses into the pattern seq, using
// `PatternInst::Expr` for the sub-exprs (right-hand sides).
for iflet in &ruledata.iflets {
let mut subexpr_seq: ExprSequence = Default::default();
let subexpr_ret_value = subexpr_seq.gen_expr(termenv, &iflet.rhs, &mut vars);
subexpr_seq.add_return(iflet.rhs.ty(), subexpr_ret_value);
let pattern_value =
pattern_seq.add_expr_seq(subexpr_seq, subexpr_ret_value, iflet.rhs.ty());
pattern_seq.gen_pattern(
ValueOrArgs::Value(pattern_value),
termenv,
&iflet.lhs,
&mut vars,
);
}
// Lower the expression, making use of the bound variables
// from the pattern.
let rhs_root_val = expr_seq.gen_expr(termenv, &ruledata.rhs, &vars);
// Return the root RHS value.
let output_ty = ruledata.rhs.ty();
expr_seq.add_return(output_ty, rhs_root_val);
(pattern_seq, expr_seq)
}

View File

@@ -22,6 +22,7 @@ use crate::{StableMap, StableSet};
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::sync::Arc;
declare_id!(
@@ -500,6 +501,43 @@ pub enum Expr {
},
}
/// Visitor interface for [Pattern]s. Visitors can assign an arbitrary identifier to each
/// subpattern, which is threaded through to subsequent calls into the visitor.
pub trait PatternVisitor {
/// The type of subpattern identifiers.
type PatternId: Copy;
/// Match if `a` and `b` have equal values.
fn add_match_equal(&mut self, a: Self::PatternId, b: Self::PatternId, ty: TypeId);
/// Match if `input` is the given integer constant.
fn add_match_int(&mut self, input: Self::PatternId, ty: TypeId, int_val: i128);
/// Match if `input` is the given primitive constant.
fn add_match_prim(&mut self, input: Self::PatternId, ty: TypeId, val: Sym);
/// Match if `input` is the given enum variant. Returns an identifier for each field within the
/// enum variant. The length of the return list must equal the length of `arg_tys`.
fn add_match_variant(
&mut self,
input: Self::PatternId,
input_ty: TypeId,
arg_tys: &[TypeId],
variant: VariantId,
) -> Vec<Self::PatternId>;
/// Match if the given external extractor succeeds on `input`. Returns an identifier for each
/// return value from the external extractor. The length of the return list must equal the
/// length of `output_tys`.
fn add_extract(
&mut self,
input: Self::PatternId,
input_ty: TypeId,
output_tys: Vec<TypeId>,
term: TermId,
infallible: bool,
multi: bool,
) -> Vec<Self::PatternId>;
}
impl Pattern {
/// Get this pattern's type.
pub fn ty(&self) -> TypeId {
@@ -522,6 +560,114 @@ impl Pattern {
_ => None,
}
}
/// Recursively visit every sub-pattern.
pub fn visit<V: PatternVisitor>(
&self,
visitor: &mut V,
input: V::PatternId,
termenv: &TermEnv,
vars: &mut HashMap<VarId, V::PatternId>,
) {
match self {
&Pattern::BindPattern(_ty, var, ref subpat) => {
// Bind the appropriate variable and recurse.
assert!(!vars.contains_key(&var));
vars.insert(var, input);
subpat.visit(visitor, input, termenv, vars);
}
&Pattern::Var(ty, var) => {
// Assert that the value matches the existing bound var.
let var_val = vars
.get(&var)
.copied()
.expect("Variable should already be bound");
visitor.add_match_equal(input, var_val, ty);
}
&Pattern::ConstInt(ty, value) => visitor.add_match_int(input, ty, value),
&Pattern::ConstPrim(ty, value) => visitor.add_match_prim(input, ty, value),
&Pattern::Term(ty, term, ref args) => {
// Determine whether the term has an external extractor or not.
let termdata = &termenv.terms[term.index()];
let arg_values = match &termdata.kind {
TermKind::EnumVariant { variant } => {
visitor.add_match_variant(input, ty, &termdata.arg_tys, *variant)
}
TermKind::Decl {
extractor_kind: None,
..
} => {
panic!("Pattern invocation of undefined term body")
}
TermKind::Decl {
extractor_kind: Some(ExtractorKind::InternalExtractor { .. }),
..
} => {
panic!("Should have been expanded away")
}
TermKind::Decl {
multi,
extractor_kind: Some(ExtractorKind::ExternalExtractor { infallible, .. }),
..
} => {
// Evaluate all `input` args.
let output_tys = args.iter().map(|arg| arg.ty()).collect();
// Invoke the extractor.
visitor.add_extract(
input,
termdata.ret_ty,
output_tys,
term,
*infallible && !*multi,
*multi,
)
}
};
for (pat, val) in args.iter().zip(arg_values) {
pat.visit(visitor, val, termenv, vars);
}
}
&Pattern::And(_ty, ref children) => {
for child in children {
child.visit(visitor, input, termenv, vars);
}
}
&Pattern::Wildcard(_ty) => {
// Nothing!
}
}
}
}
/// Visitor interface for [Expr]s. Visitors can return an arbitrary identifier for each
/// subexpression, which is threaded through to subsequent calls into the visitor.
pub trait ExprVisitor {
/// The type of subexpression identifiers.
type ExprId: Copy;
/// Construct a constant integer.
fn add_const_int(&mut self, ty: TypeId, val: i128) -> Self::ExprId;
/// Construct a primitive constant.
fn add_const_prim(&mut self, ty: TypeId, val: Sym) -> Self::ExprId;
/// Construct an enum variant with the given `inputs` assigned to the variant's fields in order.
fn add_create_variant(
&mut self,
inputs: Vec<(Self::ExprId, TypeId)>,
ty: TypeId,
variant: VariantId,
) -> Self::ExprId;
/// Call an external constructor with the given `inputs` as arguments.
fn add_construct(
&mut self,
inputs: Vec<(Self::ExprId, TypeId)>,
ty: TypeId,
term: TermId,
infallible: bool,
multi: bool,
) -> Self::ExprId;
}
impl Expr {
@@ -535,6 +681,177 @@ impl Expr {
&Self::Let { ty: t, .. } => t,
}
}
/// Recursively visit every subexpression.
pub fn visit<V: ExprVisitor>(
&self,
visitor: &mut V,
termenv: &TermEnv,
vars: &HashMap<VarId, V::ExprId>,
) -> V::ExprId {
log!("Expr::visit: expr {:?}", self);
match self {
&Expr::ConstInt(ty, val) => visitor.add_const_int(ty, val),
&Expr::ConstPrim(ty, val) => visitor.add_const_prim(ty, val),
&Expr::Let {
ty: _ty,
ref bindings,
ref body,
} => {
let mut vars = vars.clone();
for &(var, _var_ty, ref var_expr) in bindings {
let var_value = var_expr.visit(visitor, termenv, &vars);
vars.insert(var, var_value);
}
body.visit(visitor, termenv, &vars)
}
&Expr::Var(_ty, var_id) => *vars.get(&var_id).unwrap(),
&Expr::Term(ty, term, ref arg_exprs) => {
let termdata = &termenv.terms[term.index()];
let arg_values_tys = arg_exprs
.iter()
.map(|arg_expr| arg_expr.visit(visitor, termenv, vars))
.zip(termdata.arg_tys.iter().copied())
.collect();
match &termdata.kind {
TermKind::EnumVariant { variant } => {
visitor.add_create_variant(arg_values_tys, ty, *variant)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::InternalConstructor),
multi,
..
} => {
visitor.add_construct(
arg_values_tys,
ty,
term,
/* infallible = */ false,
*multi,
)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::ExternalConstructor { .. }),
pure,
multi,
..
} => {
visitor.add_construct(
arg_values_tys,
ty,
term,
/* infallible = */ !pure,
*multi,
)
}
TermKind::Decl {
constructor_kind: None,
..
} => panic!("Should have been caught by typechecking"),
}
}
}
}
fn visit_in_rule<V: RuleVisitor>(
&self,
visitor: &mut V,
termenv: &TermEnv,
vars: &HashMap<VarId, <V::PatternVisitor as PatternVisitor>::PatternId>,
) -> V::Expr {
let var_exprs = vars
.iter()
.map(|(&var, &val)| (var, visitor.pattern_as_expr(val)))
.collect();
visitor.add_expr(|visitor| VisitedExpr {
ty: self.ty(),
value: self.visit(visitor, termenv, &var_exprs),
})
}
}
/// Information about an expression after it has been fully visited in [RuleVisitor::add_expr].
#[derive(Clone, Copy)]
pub struct VisitedExpr<V: ExprVisitor> {
/// The type of the top-level expression.
pub ty: TypeId,
/// The identifier returned by the visitor for the top-level expression.
pub value: V::ExprId,
}
/// Visitor interface for [Rule]s. Visitors must be able to visit patterns by implementing
/// [PatternVisitor], and to visit expressions by providing a type that implements [ExprVisitor].
pub trait RuleVisitor {
/// The type of pattern visitors constructed by [RuleVisitor::add_pattern].
type PatternVisitor: PatternVisitor;
/// The type of expression visitors constructed by [RuleVisitor::add_expr].
type ExprVisitor: ExprVisitor;
/// The type returned from [RuleVisitor::add_expr], which may be exchanged for a subpattern
/// identifier using [RuleVisitor::expr_as_pattern].
type Expr;
/// Visit one of the arguments to the top-level pattern.
fn add_arg(
&mut self,
index: usize,
ty: TypeId,
) -> <Self::PatternVisitor as PatternVisitor>::PatternId;
/// Visit a pattern, used once for the rule's left-hand side and once for each if-let. You can
/// determine which part of the rule the pattern comes from based on whether the `PatternId`
/// passed to the first call to this visitor came from `add_arg` or `expr_as_pattern`.
fn add_pattern<F>(&mut self, visitor: F)
where
F: FnOnce(&mut Self::PatternVisitor);
/// Visit an expression, used once for each if-let and once for the rule's right-hand side.
fn add_expr<F>(&mut self, visitor: F) -> Self::Expr
where
F: FnOnce(&mut Self::ExprVisitor) -> VisitedExpr<Self::ExprVisitor>;
/// Given an expression from [RuleVisitor::add_expr], return an identifier that can be used with
/// a pattern visitor in [RuleVisitor::add_pattern].
fn expr_as_pattern(
&mut self,
expr: Self::Expr,
) -> <Self::PatternVisitor as PatternVisitor>::PatternId;
/// Given an identifier from the pattern visitor, return an identifier that can be used with
/// the expression visitor.
fn pattern_as_expr(
&mut self,
pattern: <Self::PatternVisitor as PatternVisitor>::PatternId,
) -> <Self::ExprVisitor as ExprVisitor>::ExprId;
}
impl Rule {
/// Recursively visit every pattern and expression in this rule. Returns the [RuleVisitor::Expr]
/// that was returned from [RuleVisitor::add_expr] when that function was called on the rule's
/// right-hand side.
pub fn visit<V: RuleVisitor>(&self, visitor: &mut V, termenv: &TermEnv) -> V::Expr {
let mut vars = HashMap::new();
// Visit the pattern, starting from the root input value.
if let &Pattern::Term(_, term, ref args) = &self.lhs {
let termdata = &termenv.terms[term.index()];
for (i, (subpat, &arg_ty)) in args.iter().zip(termdata.arg_tys.iter()).enumerate() {
let value = visitor.add_arg(i, arg_ty);
visitor.add_pattern(|visitor| subpat.visit(visitor, value, termenv, &mut vars));
}
} else {
unreachable!("Pattern must have a term at the root");
}
// Visit the `if-let` clauses, using `V::ExprVisitor` for the sub-exprs (right-hand sides).
for iflet in self.iflets.iter() {
let subexpr = iflet.rhs.visit_in_rule(visitor, termenv, &vars);
let value = visitor.expr_as_pattern(subexpr);
visitor.add_pattern(|visitor| iflet.lhs.visit(visitor, value, termenv, &mut vars));
}
// Visit the rule's right-hand side, making use of the bound variables from the pattern.
self.rhs.visit_in_rule(visitor, termenv, &vars)
}
}
/// Given an `Option<T>`, unwrap the inner `T` value, or `continue` if it is