cranelift-isle: Unify expressions and bindings (#5294)

As it turns out, the distinction between expressions and bindings was not
necessary for this representation. Removing it eliminates some complexity
around wrapping expressions as bindings and vice versa. It also clears up
some confusion about which category certain constructs (arguments and
extractors) belong in, since there is now only one category.
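
For context, a minimal sketch of the adapter wrapping this removes (`OldExpr` and `OldBinding` are illustrative stand-ins for the previous `Expr` and `Binding` types, with only the adapter variants shown; the real definitions are in the diff below):

```rust
// Illustrative stubs only; not the real crate layout.
#[derive(Clone, Copy)]
pub struct BindingId(u16);
#[derive(Clone, Copy)]
pub struct ExprId(u16);

pub enum OldExpr {
    // Using a binding site as an expression required a wrapper...
    Binding { source: BindingId },
    // ...plus constants, constructor calls, etc.
}

pub enum OldBinding {
    // ...and matching on an expression's result required the reverse wrapper.
    Expr { constructor: ExprId },
    // ...plus variant/tuple/`Some` projections, etc.
}

// After unification there is a single `Binding` id space, so both
// conversions collapse to the identity (see `expr_as_pattern` and
// `pattern_as_expr` at the end of the diff).
pub fn expr_as_pattern(expr: BindingId) -> BindingId {
    expr
}
pub fn pattern_as_expr(pattern: BindingId) -> BindingId {
    pattern
}
```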

While I was writing this patch I also realized that `add_match_variant`
and `normalize_equivalence_classes` both need to do fundamentally the
same things with enum variants, so I refactored them to share code and
make their relationship clearer.
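
A simplified, self-contained sketch of that shared piece; the real helper in this patch is `RuleSetBuilder::variant_bindings` (shown in the diff below), which additionally hash-conses each binding into a `BindingId`:

```rust
// Simplified sketch: matching an enum variant with `fields` fields yields one
// projection binding per field. Types are stubbed for illustration; the real
// helper also deduplicates each binding via `dedup_binding`.
#[derive(Clone, Copy)]
struct TupleIndex(u8);
#[derive(Clone, Copy)]
struct BindingId(u16);
#[derive(Clone, Copy)]
struct VariantId(u32); // stand-in for sema::VariantId

enum Binding {
    MatchVariant {
        source: BindingId,
        variant: VariantId,
        field: TupleIndex,
    },
}

fn variant_bindings(source: BindingId, fields: TupleIndex, variant: VariantId) -> Vec<Binding> {
    (0..fields.0)
        .map(|field| Binding::MatchVariant {
            source,
            variant,
            field: TupleIndex(field),
        })
        .collect()
}
```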

Finally, I reviewed all the comments in this file and fixed some places
where they could be clearer.
Author: Jamey Sharp
Date: 2022-11-17 16:00:59 -08:00 (committed by GitHub)
Parent: 3b6544dc66
Commit: 9a44ef7443


@@ -2,7 +2,7 @@
//! to closely reflect the operations we can implement in Rust, to make code generation easy.
use crate::error::{Error, Source, Span};
use crate::lexer::Pos;
use crate::sema::{self, RuleVisitor};
use crate::sema;
use crate::DisjointSets;
use std::collections::{hash_map::Entry, HashMap};
@@ -12,9 +12,6 @@ pub struct TupleIndex(u8);
/// A hash-consed identifier for a binding, stored in a [RuleSet].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BindingId(u16);
/// A hash-consed identifier for an expression, stored in a [RuleSet].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ExprId(u16);
impl BindingId {
/// Get the index of this id.
@@ -23,24 +20,10 @@ impl BindingId {
}
}
impl ExprId {
/// Get the index of this id.
pub fn index(self) -> usize {
self.0.into()
}
}
/// Expressions construct new values. Rust pattern matching can only destructure existing values,
/// not call functions or construct new values. So `if-let` and external extractor invocations need
/// to interrupt pattern matching in order to evaluate a suitable expression. These expressions are
/// also used when evaluating the right-hand side of a rule.
/// Bindings are anything which can be bound to a variable name in Rust. This includes expressions,
/// such as constants or function calls; but it also includes names bound in pattern matches.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Expr {
/// A binding from some sequence of pattern matches, used as an expression.
Binding {
/// Which binding site is being used as an expression?
source: BindingId,
},
pub enum Binding {
/// Evaluates to the given integer literal.
ConstInt {
/// The constant value.
@@ -61,39 +44,28 @@ pub enum Expr {
/// Which extractor should be called?
term: sema::TermId,
/// What expression should be passed to the extractor?
parameter: ExprId,
parameter: BindingId,
},
/// The result of calling an external constructor.
Constructor {
/// Which constructor should be called?
term: sema::TermId,
/// What expressions should be passed to the constructor?
parameters: Box<[ExprId]>,
parameters: Box<[BindingId]>,
},
/// The result of constructing an enum variant.
Variant {
MakeVariant {
/// Which enum type should be constructed?
ty: sema::TypeId,
/// Which variant of that enum should be constructed?
variant: sema::VariantId,
/// What expressions should be provided for this variant's fields?
fields: Box<[ExprId]>,
fields: Box<[BindingId]>,
},
}
/// Binding sites are the result of Rust pattern matching. This is the dual of an expression: while
/// expressions build up values, bindings take values apart.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Binding {
/// A match begins at the result of some expression that produces a Rust value.
Expr {
/// Which expression is being matched?
constructor: ExprId,
},
/// After some sequence of matches, we'll match one of the previous bindings against an enum
/// variant and produce a new binding from one of its fields. There must be a matching
/// [Constraint] for each `source`/`variant` pair that appears in a binding.
Variant {
/// Pattern-match one of the previous bindings against an enum variant and produce a new binding
/// from one of its fields. There must be a corresponding [Constraint::Variant] for each
/// `source`/`variant` pair that appears in some `MatchVariant` binding.
MatchVariant {
/// Which binding is being matched?
source: BindingId,
/// Which enum variant are we pulling binding sites from? This is somewhat redundant with
@@ -105,17 +77,17 @@ pub enum Binding {
/// get the field names.
field: TupleIndex,
},
/// After some sequence of matches, we'll match one of the previous bindings against
/// `Option::Some` and produce a new binding from its contents. (This currently only happens
/// with external extractors.)
Some {
/// Pattern-match one of the previous bindings against `Option::Some` and produce a new binding
/// from its contents. There must be a corresponding [Constraint::Some] for each `source` that
/// appears in a `MatchSome` binding. (This currently only happens with external extractors.)
MatchSome {
/// Which binding is being matched?
source: BindingId,
},
/// After some sequence of matches, we'll match one of the previous bindings against a tuple and
/// produce a new binding from one of its fields. (This currently only happens with external
/// extractors.)
Tuple {
/// Pattern-match one of the previous bindings against a tuple and produce a new binding from
/// one of its fields. This is an irrefutable pattern match so there is no corresponding
/// [Constraint]. (This currently only happens with external extractors.)
MatchTuple {
/// Which binding is being matched?
source: BindingId,
/// Which tuple field are we projecting out?
@@ -152,14 +124,15 @@ pub enum Constraint {
Some,
}
/// A term-rewriting rule. All [BindingId]s and [ExprId]s are only meaningful in the context of the
/// [RuleSet] that contains this rule.
/// A term-rewriting rule. All [BindingId]s are only meaningful in the context of the [RuleSet] that
/// contains this rule.
#[derive(Debug, Default)]
pub struct Rule {
/// Where was this rule defined?
pub pos: Pos,
/// All of these bindings must match for this rule to apply. Note that within a single rule, if
/// a binding site must match two different constants, then the rule can never match.
/// All of these bindings must match the given constraints for this rule to apply. Note that
/// within a single rule, if a binding site must match two different constraints, then the rule
/// can never match.
constraints: HashMap<BindingId, Constraint>,
/// Sets of bindings which must be equal for this rule to match.
pub equals: DisjointSets<BindingId>,
@@ -167,7 +140,7 @@ pub struct Rule {
/// evaluated. If multiple applicable rules have the same priority, that's an overlap error.
pub prio: i64,
/// If this rule applies, the top-level term should evaluate to this expression.
pub result: ExprId,
pub result: BindingId,
}
/// Records whether a given pair of rules can both match on some input.
@@ -184,15 +157,13 @@ pub enum Overlap {
},
}
/// A collection of [Rule]s, along with hash-consed [Binding]s and [Expr]s for all of them.
/// A collection of [Rule]s, along with hash-consed [Binding]s for all of them.
#[derive(Debug, Default)]
pub struct RuleSet {
/// The [Rule]s for a single [sema::Term].
pub rules: Vec<Rule>,
/// The bindings identified by [BindingId]s within rules.
pub bindings: Vec<Binding>,
/// The expressions identified by [ExprId]s within rules.
pub exprs: Vec<Expr>,
}
/// Construct a [RuleSet] for each term in `termenv` that has rules.
@@ -306,7 +277,6 @@ struct UnreachableError {
struct RuleSetBuilder {
current_rule: Rule,
binding_map: HashMap<Binding, BindingId>,
expr_map: HashMap<Expr, ExprId>,
unreachable: Vec<UnreachableError>,
rules: RuleSet,
}
@@ -370,7 +340,7 @@ impl RuleSetBuilder {
// First, find all the constraints that need to be copied to other binding sites in their
// respective equivalence classes. Note: do not remove these constraints here! Yes, we'll
// put them back later, but we rely on still having them around so that
// `set_constraint_or_error` can detect conflicting constraints.
// `set_constraint` can detect conflicting constraints.
let mut deferred_constraints = Vec::new();
for (&binding, &constraint) in self.current_rule.constraints.iter() {
if let Some(root) = self.current_rule.equals.find_mut(binding) {
@@ -387,8 +357,8 @@ impl RuleSetBuilder {
// Remove the entire equivalence class and instead add copies of this constraint to
// every binding site in the class. If there are constraints on other binding sites in
// this class, then when we try to copy this constraint to those binding sites,
// `set_constraint_or_error` will check that the constraints are equal and record an
// appropriate error otherwise.
// `set_constraint` will check that the constraints are equal and record an appropriate
// error otherwise.
//
// Later, we'll re-visit those other binding sites because they're still in
// `deferred_constraints`, but `set` will be empty because we already deleted the
@@ -409,14 +379,23 @@ impl RuleSetBuilder {
},
Some((&base, rest)),
) => {
let base_fields =
self.field_bindings(base, fields, variant, &mut deferred_constraints);
let mut defer = |this: &Self, binding| {
// We're adding equality constraints to binding sites that may not have had
// one already. If that binding site already had a concrete constraint, then
// we need to "recursively" propagate that constraint through the new
// equivalence class too.
if let Some(constraint) = this.current_rule.get_constraint(binding) {
deferred_constraints.push((binding, constraint));
}
};
let base_fields = self.variant_bindings(base, fields, variant);
base_fields.iter().for_each(|&x| defer(self, x));
for &binding in rest {
for (&x, &y) in self
.field_bindings(binding, fields, variant, &mut deferred_constraints)
for (&x, y) in base_fields
.iter()
.zip(base_fields.iter())
.zip(self.variant_bindings(binding, fields, variant))
{
defer(self, y);
self.current_rule.equals.merge(x, y);
}
}
@@ -433,33 +412,24 @@ impl RuleSetBuilder {
}
for binding in set {
self.set_constraint_or_error(binding, constraint);
self.set_constraint(binding, constraint);
}
}
}
fn field_bindings(
fn variant_bindings(
&mut self,
binding: BindingId,
fields: TupleIndex,
variant: sema::VariantId,
deferred_constraints: &mut Vec<(BindingId, Constraint)>,
) -> Box<[BindingId]> {
) -> Vec<BindingId> {
(0..fields.0)
.map(TupleIndex)
.map(move |field| {
let binding = self.dedup_binding(Binding::Variant {
.map(|field| {
self.dedup_binding(Binding::MatchVariant {
source: binding,
variant,
field,
});
// We've just added an equality constraint to a binding site that may not have had
// one already. If that binding site already had a concrete constraint, then we need
// to "recursively" propagate that constraint through the new equivalence class too.
if let Some(constraint) = self.current_rule.get_constraint(binding) {
deferred_constraints.push((binding, constraint));
}
binding
field: TupleIndex(field),
})
})
.collect()
}
@@ -475,24 +445,7 @@ impl RuleSetBuilder {
}
}
fn dedup_expr(&mut self, expr: Expr) -> ExprId {
if let Some(expr) = self.expr_map.get(&expr) {
*expr
} else {
let id = ExprId(self.rules.exprs.len().try_into().unwrap());
self.rules.exprs.push(expr.clone());
self.expr_map.insert(expr, id);
id
}
}
fn set_constraint(&mut self, input: Binding, constraint: Constraint) -> BindingId {
let input = self.dedup_binding(input);
self.set_constraint_or_error(input, constraint);
input
}
fn set_constraint_or_error(&mut self, input: BindingId, constraint: Constraint) {
fn set_constraint(&mut self, input: BindingId, constraint: Constraint) {
if let Err(e) = self.current_rule.set_constraint(input, constraint) {
self.unreachable.push(e);
}
@@ -500,37 +453,32 @@ impl RuleSetBuilder {
}
impl sema::PatternVisitor for RuleSetBuilder {
/// The "identifier" this visitor uses for binding sites is a [Binding], not a [BindingId].
/// Either choice would work but this approach avoids adding bindings to the [RuleSet] if they
/// are never used in any rule.
type PatternId = Binding;
type PatternId = BindingId;
fn add_match_equal(&mut self, a: Binding, b: Binding, _ty: sema::TypeId) {
let a = self.dedup_binding(a);
let b = self.dedup_binding(b);
fn add_match_equal(&mut self, a: BindingId, b: BindingId, _ty: sema::TypeId) {
// If both bindings represent the same binding site, they're implicitly equal.
if a != b {
self.current_rule.equals.merge(a, b);
}
}
fn add_match_int(&mut self, input: Binding, _ty: sema::TypeId, val: i128) {
fn add_match_int(&mut self, input: BindingId, _ty: sema::TypeId, val: i128) {
self.set_constraint(input, Constraint::ConstInt { val });
}
fn add_match_prim(&mut self, input: Binding, _ty: sema::TypeId, val: sema::Sym) {
fn add_match_prim(&mut self, input: BindingId, _ty: sema::TypeId, val: sema::Sym) {
self.set_constraint(input, Constraint::ConstPrim { val });
}
fn add_match_variant(
&mut self,
input: Binding,
input: BindingId,
input_ty: sema::TypeId,
arg_tys: &[sema::TypeId],
variant: sema::VariantId,
) -> Vec<Binding> {
) -> Vec<BindingId> {
let fields = TupleIndex(arg_tys.len().try_into().unwrap());
let source = self.set_constraint(
self.set_constraint(
input,
Constraint::Variant {
fields,
@@ -538,79 +486,61 @@ impl sema::PatternVisitor for RuleSetBuilder {
variant,
},
);
(0..fields.0)
.map(TupleIndex)
.map(|field| Binding::Variant {
source,
variant,
field,
})
.collect()
self.variant_bindings(input, fields, variant)
}
fn add_extract(
&mut self,
input: Binding,
input: BindingId,
_input_ty: sema::TypeId,
output_tys: Vec<sema::TypeId>,
term: sema::TermId,
infallible: bool,
_multi: bool,
) -> Vec<Binding> {
// ISLE treats external extractors as patterns, but in this representation they're
// expressions, because Rust doesn't support calling functions during pattern matching. To
// glue the two representations together we have to introduce suitable adapter nodes.
let input = self.pattern_as_expr(input);
let input = self.dedup_expr(Expr::Extractor {
) -> Vec<BindingId> {
let source = self.dedup_binding(Binding::Extractor {
term,
parameter: input,
});
let input = self.expr_as_pattern(input);
// If the extractor is fallible, build a pattern and constraint for `Some`
let source = if infallible {
input
source
} else {
let source = self.set_constraint(input, Constraint::Some);
Binding::Some { source }
self.set_constraint(source, Constraint::Some);
self.dedup_binding(Binding::MatchSome { source })
};
// If the extractor has multiple outputs, create a separate binding for each
match output_tys.len().try_into().unwrap() {
0 => vec![],
1 => vec![source],
outputs => {
let source = self.dedup_binding(source);
(0..outputs)
.map(TupleIndex)
.map(|field| Binding::Tuple { source, field })
.collect()
}
outputs => (0..outputs)
.map(TupleIndex)
.map(|field| self.dedup_binding(Binding::MatchTuple { source, field }))
.collect(),
}
}
}
impl sema::ExprVisitor for RuleSetBuilder {
/// Unlike the `PatternVisitor` implementation, we use [ExprId] to identify intermediate
/// expressions, not [Expr]. Visited expressions are always used so we might as well deduplicate
/// them eagerly.
type ExprId = ExprId;
type ExprId = BindingId;
fn add_const_int(&mut self, _ty: sema::TypeId, val: i128) -> ExprId {
self.dedup_expr(Expr::ConstInt { val })
fn add_const_int(&mut self, _ty: sema::TypeId, val: i128) -> BindingId {
self.dedup_binding(Binding::ConstInt { val })
}
fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> ExprId {
self.dedup_expr(Expr::ConstPrim { val })
fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> BindingId {
self.dedup_binding(Binding::ConstPrim { val })
}
fn add_create_variant(
&mut self,
inputs: Vec<(ExprId, sema::TypeId)>,
inputs: Vec<(BindingId, sema::TypeId)>,
ty: sema::TypeId,
variant: sema::VariantId,
) -> ExprId {
self.dedup_expr(Expr::Variant {
) -> BindingId {
self.dedup_binding(Binding::MakeVariant {
ty,
variant,
fields: inputs.into_iter().map(|(expr, _)| expr).collect(),
@@ -619,13 +549,13 @@ impl sema::ExprVisitor for RuleSetBuilder {
fn add_construct(
&mut self,
inputs: Vec<(ExprId, sema::TypeId)>,
inputs: Vec<(BindingId, sema::TypeId)>,
_ty: sema::TypeId,
term: sema::TermId,
_infallible: bool,
_multi: bool,
) -> ExprId {
self.dedup_expr(Expr::Constructor {
) -> BindingId {
self.dedup_binding(Binding::Constructor {
term,
parameters: inputs.into_iter().map(|(expr, _)| expr).collect(),
})
@@ -635,42 +565,29 @@ impl sema::ExprVisitor for RuleSetBuilder {
impl sema::RuleVisitor for RuleSetBuilder {
type PatternVisitor = Self;
type ExprVisitor = Self;
type Expr = ExprId;
type Expr = BindingId;
fn add_arg(&mut self, index: usize, _ty: sema::TypeId) -> Binding {
// Arguments don't need to be pattern-matched to reference them, so they're expressions
fn add_arg(&mut self, index: usize, _ty: sema::TypeId) -> BindingId {
let index = TupleIndex(index.try_into().unwrap());
let expr = self.dedup_expr(Expr::Argument { index });
Binding::Expr { constructor: expr }
self.dedup_binding(Binding::Argument { index })
}
fn add_pattern<F: FnOnce(&mut Self)>(&mut self, visitor: F) {
visitor(self)
}
fn add_expr<F>(&mut self, visitor: F) -> ExprId
fn add_expr<F>(&mut self, visitor: F) -> BindingId
where
F: FnOnce(&mut Self) -> sema::VisitedExpr<Self>,
{
visitor(self).value
}
fn expr_as_pattern(&mut self, expr: ExprId) -> Binding {
if let &Expr::Binding { source: binding } = &self.rules.exprs[expr.index()] {
// Short-circuit wrapping a binding around an expr from another binding
self.rules.bindings[binding.index()]
} else {
Binding::Expr { constructor: expr }
}
fn expr_as_pattern(&mut self, expr: BindingId) -> BindingId {
expr
}
fn pattern_as_expr(&mut self, pattern: Binding) -> ExprId {
if let Binding::Expr { constructor } = pattern {
// Short-circuit wrapping an expr around a binding from another expr
constructor
} else {
let binding = self.dedup_binding(pattern);
self.dedup_expr(Expr::Binding { source: binding })
}
fn pattern_as_expr(&mut self, pattern: BindingId) -> BindingId {
pattern
}
}