cranelift-isle: Unify expressions and bindings (#5294)

As it turns out, the distinction between expressions and bindings was
not necessary for this representation. Removing it eliminates some
complexity around wrapping expressions as bindings and vice versa. It
also clears up some confusion about which category to put certain
constructs in (arguments and extractors) by refusing to have different
categories.

While I was writing this patch I also realized that `add_match_variant`
and `normalize_equivalence_classes` both need to do fundamentally the
same things with enum variants, so I refactored them to share code and
make their relationship clearer.

Finally, I reviewed all the comments in this file and fixed some places
where they could be clearer.
This commit is contained in:
Jamey Sharp
2022-11-17 16:00:59 -08:00
committed by GitHub
parent 3b6544dc66
commit 9a44ef7443

View File

@@ -2,7 +2,7 @@
//! to closely reflect the operations we can implement in Rust, to make code generation easy. //! to closely reflect the operations we can implement in Rust, to make code generation easy.
use crate::error::{Error, Source, Span}; use crate::error::{Error, Source, Span};
use crate::lexer::Pos; use crate::lexer::Pos;
use crate::sema::{self, RuleVisitor}; use crate::sema;
use crate::DisjointSets; use crate::DisjointSets;
use std::collections::{hash_map::Entry, HashMap}; use std::collections::{hash_map::Entry, HashMap};
@@ -12,9 +12,6 @@ pub struct TupleIndex(u8);
/// A hash-consed identifier for a binding, stored in a [RuleSet]. /// A hash-consed identifier for a binding, stored in a [RuleSet].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BindingId(u16); pub struct BindingId(u16);
/// A hash-consed identifier for an expression, stored in a [RuleSet].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ExprId(u16);
impl BindingId { impl BindingId {
/// Get the index of this id. /// Get the index of this id.
@@ -23,24 +20,10 @@ impl BindingId {
} }
} }
impl ExprId { /// Bindings are anything which can be bound to a variable name in Rust. This includes expressions,
/// Get the index of this id. /// such as constants or function calls; but it also includes names bound in pattern matches.
pub fn index(self) -> usize {
self.0.into()
}
}
/// Expressions construct new values. Rust pattern matching can only destructure existing values,
/// not call functions or construct new values. So `if-let` and external extractor invocations need
/// to interrupt pattern matching in order to evaluate a suitable expression. These expressions are
/// also used when evaluating the right-hand side of a rule.
#[derive(Clone, Debug, Eq, Hash, PartialEq)] #[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Expr { pub enum Binding {
/// A binding from some sequence of pattern matches, used as an expression.
Binding {
/// Which binding site is being used as an expression?
source: BindingId,
},
/// Evaluates to the given integer literal. /// Evaluates to the given integer literal.
ConstInt { ConstInt {
/// The constant value. /// The constant value.
@@ -61,39 +44,28 @@ pub enum Expr {
/// Which extractor should be called? /// Which extractor should be called?
term: sema::TermId, term: sema::TermId,
/// What expression should be passed to the extractor? /// What expression should be passed to the extractor?
parameter: ExprId, parameter: BindingId,
}, },
/// The result of calling an external constructor. /// The result of calling an external constructor.
Constructor { Constructor {
/// Which constructor should be called? /// Which constructor should be called?
term: sema::TermId, term: sema::TermId,
/// What expressions should be passed to the constructor? /// What expressions should be passed to the constructor?
parameters: Box<[ExprId]>, parameters: Box<[BindingId]>,
}, },
/// The result of constructing an enum variant. /// The result of constructing an enum variant.
Variant { MakeVariant {
/// Which enum type should be constructed? /// Which enum type should be constructed?
ty: sema::TypeId, ty: sema::TypeId,
/// Which variant of that enum should be constructed? /// Which variant of that enum should be constructed?
variant: sema::VariantId, variant: sema::VariantId,
/// What expressions should be provided for this variant's fields? /// What expressions should be provided for this variant's fields?
fields: Box<[ExprId]>, fields: Box<[BindingId]>,
}, },
} /// Pattern-match one of the previous bindings against an enum variant and produce a new binding
/// from one of its fields. There must be a corresponding [Constraint::Variant] for each
/// Binding sites are the result of Rust pattern matching. This is the dual of an expression: while /// `source`/`variant` pair that appears in some `MatchVariant` binding.
/// expressions build up values, bindings take values apart. MatchVariant {
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Binding {
/// A match begins at the result of some expression that produces a Rust value.
Expr {
/// Which expression is being matched?
constructor: ExprId,
},
/// After some sequence of matches, we'll match one of the previous bindings against an enum
/// variant and produce a new binding from one of its fields. There must be a matching
/// [Constraint] for each `source`/`variant` pair that appears in a binding.
Variant {
/// Which binding is being matched? /// Which binding is being matched?
source: BindingId, source: BindingId,
/// Which enum variant are we pulling binding sites from? This is somewhat redundant with /// Which enum variant are we pulling binding sites from? This is somewhat redundant with
@@ -105,17 +77,17 @@ pub enum Binding {
/// get the field names. /// get the field names.
field: TupleIndex, field: TupleIndex,
}, },
/// After some sequence of matches, we'll match one of the previous bindings against /// Pattern-match one of the previous bindings against `Option::Some` and produce a new binding
/// `Option::Some` and produce a new binding from its contents. (This currently only happens /// from its contents. There must be a corresponding [Constraint::Some] for each `source` that
/// with external extractors.) /// appears in a `MatchSome` binding. (This currently only happens with external extractors.)
Some { MatchSome {
/// Which binding is being matched? /// Which binding is being matched?
source: BindingId, source: BindingId,
}, },
/// After some sequence of matches, we'll match one of the previous bindings against a tuple and /// Pattern-match one of the previous bindings against a tuple and produce a new binding from
/// produce a new binding from one of its fields. (This currently only happens with external /// one of its fields. This is an irrefutable pattern match so there is no corresponding
/// extractors.) /// [Constraint]. (This currently only happens with external extractors.)
Tuple { MatchTuple {
/// Which binding is being matched? /// Which binding is being matched?
source: BindingId, source: BindingId,
/// Which tuple field are we projecting out? /// Which tuple field are we projecting out?
@@ -152,14 +124,15 @@ pub enum Constraint {
Some, Some,
} }
/// A term-rewriting rule. All [BindingId]s and [ExprId]s are only meaningful in the context of the /// A term-rewriting rule. All [BindingId]s are only meaningful in the context of the [RuleSet] that
/// [RuleSet] that contains this rule. /// contains this rule.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Rule { pub struct Rule {
/// Where was this rule defined? /// Where was this rule defined?
pub pos: Pos, pub pos: Pos,
/// All of these bindings must match for this rule to apply. Note that within a single rule, if /// All of these bindings must match the given constraints for this rule to apply. Note that
/// a binding site must match two different constants, then the rule can never match. /// within a single rule, if a binding site must match two different constraints, then the rule
/// can never match.
constraints: HashMap<BindingId, Constraint>, constraints: HashMap<BindingId, Constraint>,
/// Sets of bindings which must be equal for this rule to match. /// Sets of bindings which must be equal for this rule to match.
pub equals: DisjointSets<BindingId>, pub equals: DisjointSets<BindingId>,
@@ -167,7 +140,7 @@ pub struct Rule {
/// evaluated. If multiple applicable rules have the same priority, that's an overlap error. /// evaluated. If multiple applicable rules have the same priority, that's an overlap error.
pub prio: i64, pub prio: i64,
/// If this rule applies, the top-level term should evaluate to this expression. /// If this rule applies, the top-level term should evaluate to this expression.
pub result: ExprId, pub result: BindingId,
} }
/// Records whether a given pair of rules can both match on some input. /// Records whether a given pair of rules can both match on some input.
@@ -184,15 +157,13 @@ pub enum Overlap {
}, },
} }
/// A collection of [Rule]s, along with hash-consed [Binding]s and [Expr]s for all of them. /// A collection of [Rule]s, along with hash-consed [Binding]s for all of them.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct RuleSet { pub struct RuleSet {
/// The [Rule]s for a single [sema::Term]. /// The [Rule]s for a single [sema::Term].
pub rules: Vec<Rule>, pub rules: Vec<Rule>,
/// The bindings identified by [BindingId]s within rules. /// The bindings identified by [BindingId]s within rules.
pub bindings: Vec<Binding>, pub bindings: Vec<Binding>,
/// The expressions identified by [ExprId]s within rules.
pub exprs: Vec<Expr>,
} }
/// Construct a [RuleSet] for each term in `termenv` that has rules. /// Construct a [RuleSet] for each term in `termenv` that has rules.
@@ -306,7 +277,6 @@ struct UnreachableError {
struct RuleSetBuilder { struct RuleSetBuilder {
current_rule: Rule, current_rule: Rule,
binding_map: HashMap<Binding, BindingId>, binding_map: HashMap<Binding, BindingId>,
expr_map: HashMap<Expr, ExprId>,
unreachable: Vec<UnreachableError>, unreachable: Vec<UnreachableError>,
rules: RuleSet, rules: RuleSet,
} }
@@ -370,7 +340,7 @@ impl RuleSetBuilder {
// First, find all the constraints that need to be copied to other binding sites in their // First, find all the constraints that need to be copied to other binding sites in their
// respective equivalence classes. Note: do not remove these constraints here! Yes, we'll // respective equivalence classes. Note: do not remove these constraints here! Yes, we'll
// put them back later, but we rely on still having them around so that // put them back later, but we rely on still having them around so that
// `set_constraint_or_error` can detect conflicting constraints. // `set_constraint` can detect conflicting constraints.
let mut deferred_constraints = Vec::new(); let mut deferred_constraints = Vec::new();
for (&binding, &constraint) in self.current_rule.constraints.iter() { for (&binding, &constraint) in self.current_rule.constraints.iter() {
if let Some(root) = self.current_rule.equals.find_mut(binding) { if let Some(root) = self.current_rule.equals.find_mut(binding) {
@@ -387,8 +357,8 @@ impl RuleSetBuilder {
// Remove the entire equivalence class and instead add copies of this constraint to // Remove the entire equivalence class and instead add copies of this constraint to
// every binding site in the class. If there are constraints on other binding sites in // every binding site in the class. If there are constraints on other binding sites in
// this class, then when we try to copy this constraint to those binding sites, // this class, then when we try to copy this constraint to those binding sites,
// `set_constraint_or_error` will check that the constraints are equal and record an // `set_constraint` will check that the constraints are equal and record an appropriate
// appropriate error otherwise. // error otherwise.
// //
// Later, we'll re-visit those other binding sites because they're still in // Later, we'll re-visit those other binding sites because they're still in
// `deferred_constraints`, but `set` will be empty because we already deleted the // `deferred_constraints`, but `set` will be empty because we already deleted the
@@ -409,14 +379,23 @@ impl RuleSetBuilder {
}, },
Some((&base, rest)), Some((&base, rest)),
) => { ) => {
let base_fields = let mut defer = |this: &Self, binding| {
self.field_bindings(base, fields, variant, &mut deferred_constraints); // We're adding equality constraints to binding sites that may not have had
// one already. If that binding site already had a concrete constraint, then
// we need to "recursively" propagate that constraint through the new
// equivalence class too.
if let Some(constraint) = this.current_rule.get_constraint(binding) {
deferred_constraints.push((binding, constraint));
}
};
let base_fields = self.variant_bindings(base, fields, variant);
base_fields.iter().for_each(|&x| defer(self, x));
for &binding in rest { for &binding in rest {
for (&x, &y) in self for (&x, y) in base_fields
.field_bindings(binding, fields, variant, &mut deferred_constraints)
.iter() .iter()
.zip(base_fields.iter()) .zip(self.variant_bindings(binding, fields, variant))
{ {
defer(self, y);
self.current_rule.equals.merge(x, y); self.current_rule.equals.merge(x, y);
} }
} }
@@ -433,33 +412,24 @@ impl RuleSetBuilder {
} }
for binding in set { for binding in set {
self.set_constraint_or_error(binding, constraint); self.set_constraint(binding, constraint);
} }
} }
} }
fn field_bindings( fn variant_bindings(
&mut self, &mut self,
binding: BindingId, binding: BindingId,
fields: TupleIndex, fields: TupleIndex,
variant: sema::VariantId, variant: sema::VariantId,
deferred_constraints: &mut Vec<(BindingId, Constraint)>, ) -> Vec<BindingId> {
) -> Box<[BindingId]> {
(0..fields.0) (0..fields.0)
.map(TupleIndex) .map(|field| {
.map(move |field| { self.dedup_binding(Binding::MatchVariant {
let binding = self.dedup_binding(Binding::Variant {
source: binding, source: binding,
variant, variant,
field, field: TupleIndex(field),
}); })
// We've just added an equality constraint to a binding site that may not have had
// one already. If that binding site already had a concrete constraint, then we need
// to "recursively" propagate that constraint through the new equivalence class too.
if let Some(constraint) = self.current_rule.get_constraint(binding) {
deferred_constraints.push((binding, constraint));
}
binding
}) })
.collect() .collect()
} }
@@ -475,24 +445,7 @@ impl RuleSetBuilder {
} }
} }
fn dedup_expr(&mut self, expr: Expr) -> ExprId { fn set_constraint(&mut self, input: BindingId, constraint: Constraint) {
if let Some(expr) = self.expr_map.get(&expr) {
*expr
} else {
let id = ExprId(self.rules.exprs.len().try_into().unwrap());
self.rules.exprs.push(expr.clone());
self.expr_map.insert(expr, id);
id
}
}
fn set_constraint(&mut self, input: Binding, constraint: Constraint) -> BindingId {
let input = self.dedup_binding(input);
self.set_constraint_or_error(input, constraint);
input
}
fn set_constraint_or_error(&mut self, input: BindingId, constraint: Constraint) {
if let Err(e) = self.current_rule.set_constraint(input, constraint) { if let Err(e) = self.current_rule.set_constraint(input, constraint) {
self.unreachable.push(e); self.unreachable.push(e);
} }
@@ -500,37 +453,32 @@ impl RuleSetBuilder {
} }
impl sema::PatternVisitor for RuleSetBuilder { impl sema::PatternVisitor for RuleSetBuilder {
/// The "identifier" this visitor uses for binding sites is a [Binding], not a [BindingId]. type PatternId = BindingId;
/// Either choice would work but this approach avoids adding bindings to the [RuleSet] if they
/// are never used in any rule.
type PatternId = Binding;
fn add_match_equal(&mut self, a: Binding, b: Binding, _ty: sema::TypeId) { fn add_match_equal(&mut self, a: BindingId, b: BindingId, _ty: sema::TypeId) {
let a = self.dedup_binding(a);
let b = self.dedup_binding(b);
// If both bindings represent the same binding site, they're implicitly equal. // If both bindings represent the same binding site, they're implicitly equal.
if a != b { if a != b {
self.current_rule.equals.merge(a, b); self.current_rule.equals.merge(a, b);
} }
} }
fn add_match_int(&mut self, input: Binding, _ty: sema::TypeId, val: i128) { fn add_match_int(&mut self, input: BindingId, _ty: sema::TypeId, val: i128) {
self.set_constraint(input, Constraint::ConstInt { val }); self.set_constraint(input, Constraint::ConstInt { val });
} }
fn add_match_prim(&mut self, input: Binding, _ty: sema::TypeId, val: sema::Sym) { fn add_match_prim(&mut self, input: BindingId, _ty: sema::TypeId, val: sema::Sym) {
self.set_constraint(input, Constraint::ConstPrim { val }); self.set_constraint(input, Constraint::ConstPrim { val });
} }
fn add_match_variant( fn add_match_variant(
&mut self, &mut self,
input: Binding, input: BindingId,
input_ty: sema::TypeId, input_ty: sema::TypeId,
arg_tys: &[sema::TypeId], arg_tys: &[sema::TypeId],
variant: sema::VariantId, variant: sema::VariantId,
) -> Vec<Binding> { ) -> Vec<BindingId> {
let fields = TupleIndex(arg_tys.len().try_into().unwrap()); let fields = TupleIndex(arg_tys.len().try_into().unwrap());
let source = self.set_constraint( self.set_constraint(
input, input,
Constraint::Variant { Constraint::Variant {
fields, fields,
@@ -538,79 +486,61 @@ impl sema::PatternVisitor for RuleSetBuilder {
variant, variant,
}, },
); );
(0..fields.0) self.variant_bindings(input, fields, variant)
.map(TupleIndex)
.map(|field| Binding::Variant {
source,
variant,
field,
})
.collect()
} }
fn add_extract( fn add_extract(
&mut self, &mut self,
input: Binding, input: BindingId,
_input_ty: sema::TypeId, _input_ty: sema::TypeId,
output_tys: Vec<sema::TypeId>, output_tys: Vec<sema::TypeId>,
term: sema::TermId, term: sema::TermId,
infallible: bool, infallible: bool,
_multi: bool, _multi: bool,
) -> Vec<Binding> { ) -> Vec<BindingId> {
// ISLE treats external extractors as patterns, but in this representation they're let source = self.dedup_binding(Binding::Extractor {
// expressions, because Rust doesn't support calling functions during pattern matching. To
// glue the two representations together we have to introduce suitable adapter nodes.
let input = self.pattern_as_expr(input);
let input = self.dedup_expr(Expr::Extractor {
term, term,
parameter: input, parameter: input,
}); });
let input = self.expr_as_pattern(input);
// If the extractor is fallible, build a pattern and constraint for `Some` // If the extractor is fallible, build a pattern and constraint for `Some`
let source = if infallible { let source = if infallible {
input source
} else { } else {
let source = self.set_constraint(input, Constraint::Some); self.set_constraint(source, Constraint::Some);
Binding::Some { source } self.dedup_binding(Binding::MatchSome { source })
}; };
// If the extractor has multiple outputs, create a separate binding for each // If the extractor has multiple outputs, create a separate binding for each
match output_tys.len().try_into().unwrap() { match output_tys.len().try_into().unwrap() {
0 => vec![], 0 => vec![],
1 => vec![source], 1 => vec![source],
outputs => { outputs => (0..outputs)
let source = self.dedup_binding(source); .map(TupleIndex)
(0..outputs) .map(|field| self.dedup_binding(Binding::MatchTuple { source, field }))
.map(TupleIndex) .collect(),
.map(|field| Binding::Tuple { source, field })
.collect()
}
} }
} }
} }
impl sema::ExprVisitor for RuleSetBuilder { impl sema::ExprVisitor for RuleSetBuilder {
/// Unlike the `PatternVisitor` implementation, we use [ExprId] to identify intermediate type ExprId = BindingId;
/// expressions, not [Expr]. Visited expressions are always used so we might as well deduplicate
/// them eagerly.
type ExprId = ExprId;
fn add_const_int(&mut self, _ty: sema::TypeId, val: i128) -> ExprId { fn add_const_int(&mut self, _ty: sema::TypeId, val: i128) -> BindingId {
self.dedup_expr(Expr::ConstInt { val }) self.dedup_binding(Binding::ConstInt { val })
} }
fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> ExprId { fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> BindingId {
self.dedup_expr(Expr::ConstPrim { val }) self.dedup_binding(Binding::ConstPrim { val })
} }
fn add_create_variant( fn add_create_variant(
&mut self, &mut self,
inputs: Vec<(ExprId, sema::TypeId)>, inputs: Vec<(BindingId, sema::TypeId)>,
ty: sema::TypeId, ty: sema::TypeId,
variant: sema::VariantId, variant: sema::VariantId,
) -> ExprId { ) -> BindingId {
self.dedup_expr(Expr::Variant { self.dedup_binding(Binding::MakeVariant {
ty, ty,
variant, variant,
fields: inputs.into_iter().map(|(expr, _)| expr).collect(), fields: inputs.into_iter().map(|(expr, _)| expr).collect(),
@@ -619,13 +549,13 @@ impl sema::ExprVisitor for RuleSetBuilder {
fn add_construct( fn add_construct(
&mut self, &mut self,
inputs: Vec<(ExprId, sema::TypeId)>, inputs: Vec<(BindingId, sema::TypeId)>,
_ty: sema::TypeId, _ty: sema::TypeId,
term: sema::TermId, term: sema::TermId,
_infallible: bool, _infallible: bool,
_multi: bool, _multi: bool,
) -> ExprId { ) -> BindingId {
self.dedup_expr(Expr::Constructor { self.dedup_binding(Binding::Constructor {
term, term,
parameters: inputs.into_iter().map(|(expr, _)| expr).collect(), parameters: inputs.into_iter().map(|(expr, _)| expr).collect(),
}) })
@@ -635,42 +565,29 @@ impl sema::ExprVisitor for RuleSetBuilder {
impl sema::RuleVisitor for RuleSetBuilder { impl sema::RuleVisitor for RuleSetBuilder {
type PatternVisitor = Self; type PatternVisitor = Self;
type ExprVisitor = Self; type ExprVisitor = Self;
type Expr = ExprId; type Expr = BindingId;
fn add_arg(&mut self, index: usize, _ty: sema::TypeId) -> Binding { fn add_arg(&mut self, index: usize, _ty: sema::TypeId) -> BindingId {
// Arguments don't need to be pattern-matched to reference them, so they're expressions
let index = TupleIndex(index.try_into().unwrap()); let index = TupleIndex(index.try_into().unwrap());
let expr = self.dedup_expr(Expr::Argument { index }); self.dedup_binding(Binding::Argument { index })
Binding::Expr { constructor: expr }
} }
fn add_pattern<F: FnOnce(&mut Self)>(&mut self, visitor: F) { fn add_pattern<F: FnOnce(&mut Self)>(&mut self, visitor: F) {
visitor(self) visitor(self)
} }
fn add_expr<F>(&mut self, visitor: F) -> ExprId fn add_expr<F>(&mut self, visitor: F) -> BindingId
where where
F: FnOnce(&mut Self) -> sema::VisitedExpr<Self>, F: FnOnce(&mut Self) -> sema::VisitedExpr<Self>,
{ {
visitor(self).value visitor(self).value
} }
fn expr_as_pattern(&mut self, expr: ExprId) -> Binding { fn expr_as_pattern(&mut self, expr: BindingId) -> BindingId {
if let &Expr::Binding { source: binding } = &self.rules.exprs[expr.index()] { expr
// Short-circuit wrapping a binding around an expr from another binding
self.rules.bindings[binding.index()]
} else {
Binding::Expr { constructor: expr }
}
} }
fn pattern_as_expr(&mut self, pattern: Binding) -> ExprId { fn pattern_as_expr(&mut self, pattern: BindingId) -> BindingId {
if let Binding::Expr { constructor } = pattern { pattern
// Short-circuit wrapping an expr around a binding from another expr
constructor
} else {
let binding = self.dedup_binding(pattern);
self.dedup_expr(Expr::Binding { source: binding })
}
} }
} }