diff --git a/cranelift/isle/isle/src/error.rs b/cranelift/isle/isle/src/error.rs index 5748c267ab..9115be6367 100644 --- a/cranelift/isle/isle/src/error.rs +++ b/cranelift/isle/isle/src/error.rs @@ -42,6 +42,18 @@ pub enum Error { span: Span, }, + /// The rule can never match any input. + UnreachableError { + /// The error message. + msg: String, + + /// The input ISLE source. + src: Source, + + /// The location of the unreachable rule. + span: Span, + }, + /// The rules mentioned overlap in the input they accept. OverlapError { /// The error message. @@ -119,6 +131,15 @@ impl std::fmt::Display for Error { #[cfg(feature = "miette-errors")] Error::TypeError { msg, .. } => write!(f, "type error: {}", msg), + Error::UnreachableError { src, span, msg } => { + write!( + f, + "{}: unreachable rule: {}", + span.from.pretty_print_with_filename(&*src.name), + msg + ) + } + Error::OverlapError { msg, rules, .. } => { writeln!(f, "overlap error: {}\n{}", msg, OverlappingRules(&rules)) } diff --git a/cranelift/isle/isle/src/lib.rs b/cranelift/isle/isle/src/lib.rs index 140ba0aff1..570aeb4e4c 100644 --- a/cranelift/isle/isle/src/lib.rs +++ b/cranelift/isle/isle/src/lib.rs @@ -91,6 +91,114 @@ impl Index<&K> for StableMap { } } +/// Stores disjoint sets and provides efficient operations to merge two sets, and to find a +/// representative member of a set given any member of that set. In this implementation, sets always +/// have at least two members, and can only be formed by the `merge` operation. +#[derive(Debug, Default)] +pub struct DisjointSets { + parent: HashMap, +} + +impl DisjointSets { + /// Find a representative member of the set containing `x`. If `x` has not been merged with any + /// other items using `merge`, returns `None`. This method updates the data structure to make + /// future queries faster, and takes amortized constant time. 
+ /// + /// ``` + /// let mut sets = cranelift_isle::DisjointSets::default(); + /// sets.merge(1, 2); + /// sets.merge(1, 3); + /// sets.merge(2, 4); + /// assert_eq!(sets.find_mut(3).unwrap(), sets.find_mut(4).unwrap()); + /// assert_eq!(sets.find_mut(10), None); + /// ``` + pub fn find_mut(&mut self, mut x: T) -> Option { + while let Some(node) = self.parent.get(&x) { + if node.0 == x { + return Some(x); + } + let grandparent = self.parent[&node.0].0; + // Re-do the lookup but take a mutable borrow this time + self.parent.get_mut(&x).unwrap().0 = grandparent; + x = grandparent; + } + None + } + + /// Merge the set containing `x` with the set containing `y`. This method takes amortized + /// constant time. + pub fn merge(&mut self, x: T, y: T) { + assert_ne!(x, y); + let mut x = if let Some(x) = self.find_mut(x) { + self.parent[&x] + } else { + self.parent.insert(x, (x, 0)); + (x, 0) + }; + let mut y = if let Some(y) = self.find_mut(y) { + self.parent[&y] + } else { + self.parent.insert(y, (y, 0)); + (y, 0) + }; + + if x == y { + return; + } + + if x.1 < y.1 { + std::mem::swap(&mut x, &mut y); + } + + self.parent.get_mut(&y.0).unwrap().0 = x.0; + if x.1 == y.1 { + let x_rank = &mut self.parent.get_mut(&x.0).unwrap().1; + *x_rank = x_rank.saturating_add(1); + } + } + + /// Remove the set containing the given item, and return all members of that set. The set is + /// returned in sorted order. This method takes time linear in the total size of all sets. 
+ /// + /// ``` + /// let mut sets = cranelift_isle::DisjointSets::default(); + /// sets.merge(1, 2); + /// sets.merge(1, 3); + /// sets.merge(2, 4); + /// assert_eq!(sets.remove_set_of(4), &[1, 2, 3, 4]); + /// assert_eq!(sets.remove_set_of(1), &[]); + /// assert!(sets.is_empty()); + /// ``` + pub fn remove_set_of(&mut self, x: T) -> Vec + where + T: Ord, + { + let mut set = Vec::new(); + if let Some(x) = self.find_mut(x) { + set.extend(self.parent.keys().copied()); + // It's important to use `find_mut` here to avoid quadratic worst-case time. + set.retain(|&y| self.find_mut(y).unwrap() == x); + for y in set.iter() { + self.parent.remove(y); + } + set.sort_unstable(); + } + set + } + + /// Returns true if there are no sets. This method takes constant time. + /// + /// ``` + /// let mut sets = cranelift_isle::DisjointSets::default(); + /// assert!(sets.is_empty()); + /// sets.merge(1, 2); + /// assert!(!sets.is_empty()); + /// ``` + pub fn is_empty(&self) -> bool { + self.parent.is_empty() + } +} + pub mod ast; pub mod codegen; pub mod compile; @@ -102,6 +210,7 @@ pub mod overlap; pub mod parser; pub mod sema; pub mod trie; +pub mod trie_again; #[cfg(feature = "miette-errors")] mod error_miette; diff --git a/cranelift/isle/isle/src/overlap.rs b/cranelift/isle/isle/src/overlap.rs index 789e7d450c..2e0f34a553 100644 --- a/cranelift/isle/isle/src/overlap.rs +++ b/cranelift/isle/isle/src/overlap.rs @@ -5,15 +5,14 @@ use std::collections::{HashMap, HashSet}; use crate::error::{Error, Result, Source, Span}; use crate::lexer::Pos; -use crate::sema::{self, Rule, RuleId, Sym, TermEnv, TermId, TermKind, TypeEnv, VarId}; +use crate::sema::{TermEnv, TermId, TermKind, TypeEnv}; +use crate::trie_again; /// Check for overlap. pub fn check(tyenv: &TypeEnv, termenv: &TermEnv) -> Result<()> { - let mut errors = check_overlaps(termenv).report(tyenv, termenv); - errors.sort_by_key(|err| match err { - Error::OverlapError { rules, .. 
} => rules.first().unwrap().1.from, - _ => Pos::default(), - }); + let (terms, mut errors) = trie_again::build(termenv, tyenv); + errors.append(&mut check_overlaps(terms, termenv).report(tyenv)); + match errors.len() { 0 => Ok(()), 1 => Err(errors.pop().unwrap()), @@ -25,7 +24,7 @@ pub fn check(tyenv: &TypeEnv, termenv: &TermEnv) -> Result<()> { #[derive(Default)] struct Errors { /// Edges between rules indicating overlap. - nodes: HashMap>, + nodes: HashMap>, } impl Errors { @@ -33,30 +32,29 @@ impl Errors { /// nodes from the graph with the highest degree, reporting errors for them and their direct /// connections. The goal with reporting errors this way is to prefer reporting rules that /// overlap with many others first, and then report other more targeted overlaps later. - fn report(mut self, tyenv: &TypeEnv, termenv: &TermEnv) -> Vec { + fn report(mut self, tyenv: &TypeEnv) -> Vec { let mut errors = Vec::new(); - let get_info = |id: RuleId| { - let rule = &termenv.rules[id.0]; - let file = rule.pos.file; + let get_info = |pos: Pos| { + let file = pos.file; let src = Source::new( tyenv.filenames[file].clone(), tyenv.file_texts[file].clone(), ); - let span = Span::new_single(rule.pos); + let span = Span::new_single(pos); (src, span) }; - while let Some((&id, _)) = self + while let Some((&pos, _)) = self .nodes .iter() - .max_by_key(|(id, edges)| (edges.len(), *id)) + .max_by_key(|(pos, edges)| (edges.len(), *pos)) { - let node = self.nodes.remove(&id).unwrap(); + let node = self.nodes.remove(&pos).unwrap(); for other in node.iter() { if let Entry::Occupied(mut entry) = self.nodes.entry(*other) { let back_edges = entry.get_mut(); - back_edges.remove(&id); + back_edges.remove(&pos); if back_edges.is_empty() { entry.remove(); } @@ -64,7 +62,7 @@ impl Errors { } // build the real error - let mut rules = vec![get_info(id)]; + let mut rules = vec![get_info(pos)]; rules.extend(node.into_iter().map(get_info)); @@ -74,271 +72,48 @@ impl Errors { }); } + 
errors.sort_by_key(|err| match err { + Error::OverlapError { rules, .. } => rules.first().unwrap().1.from, + _ => Pos::default(), + }); errors } - /// Add a bidirectional edge between two rules in the graph. - fn add_edge(&mut self, a: RuleId, b: RuleId) { - // edges are undirected - self.nodes.entry(a).or_default().insert(b); - self.nodes.entry(b).or_default().insert(a); + fn check_pair(&mut self, a: &trie_again::Rule, b: &trie_again::Rule) { + if let trie_again::Overlap::Yes { .. } = a.may_overlap(b) { + if a.prio == b.prio { + // edges are undirected + self.nodes.entry(a.pos).or_default().insert(b.pos); + self.nodes.entry(b.pos).or_default().insert(a.pos); + } + } } } -/// Determine if any rules overlap in the input that they accept. This checkes every unique pair of +/// Determine if any rules overlap in the input that they accept. This checks every unique pair of /// rules, as checking rules in aggregate tends to suffer from exponential explosion in the /// presence of wildcard patterns. -fn check_overlaps(env: &TermEnv) -> Errors { - struct RulePatterns<'a> { - rule: &'a Rule, - pats: Box<[Pattern]>, - } - let mut by_term = HashMap::new(); - for rule in env.rules.iter() { - if let sema::Pattern::Term(_, tid, ref vars) = rule.lhs { - let is_multi_ctor = match &env.terms[tid.index()].kind { - &TermKind::Decl { multi, .. } => multi, - _ => false, - }; - if is_multi_ctor { - // Rules for multi-constructors are not checked for - // overlap: the ctor returns *every* match, not just - // the first or highest-priority one, so overlap does - // not actually affect the results. - continue; - } - - // Group rules by term and priority. Only rules within the same group are checked to - // see if they overlap each other. If you want to change the scope of overlap checking, - // change this key. 
- let key = (tid, rule.prio); - - let mut binds = Vec::new(); - let rule = RulePatterns { - rule, - pats: vars - .iter() - .map(|pat| Pattern::from_sema(env, &mut binds, pat)) - .collect(), - }; - by_term.entry(key).or_insert_with(Vec::new).push(rule); - } - } - +fn check_overlaps(terms: Vec<(TermId, trie_again::RuleSet)>, env: &TermEnv) -> Errors { let mut errs = Errors::default(); - for (_, rows) in by_term { - let mut cursor = rows.into_iter(); + for (tid, ruleset) in terms { + let is_multi_ctor = match &env.terms[tid.index()].kind { + &TermKind::Decl { multi, .. } => multi, + _ => false, + }; + if is_multi_ctor { + // Rules for multi-constructors are not checked for + // overlap: the ctor returns *every* match, not just + // the first or highest-priority one, so overlap does + // not actually affect the results. + continue; + } + + let mut cursor = ruleset.rules.iter(); while let Some(left) = cursor.next() { for right in cursor.as_slice() { - if check_overlap_pair(&left.pats, &right.pats) { - errs.add_edge(left.rule.id, right.rule.id); - } + errs.check_pair(left, right); } } } errs } - -/// Check if two rules overlap in the inputs they accept. -fn check_overlap_pair(a: &[Pattern], b: &[Pattern]) -> bool { - debug_assert_eq!(a.len(), b.len()); - let mut worklist: Vec<_> = a.iter().zip(b.iter()).collect(); - - while let Some((a, b)) = worklist.pop() { - // Checking the cross-product of two and-patterns is O(n*m). Merging sorted lists or - // hash-maps might be faster in practice, but: - // - The alternatives are not asymptotically faster, because in theory all the subpatterns - // might have the same extractor or enum variant, and in that case any approach has to - // check all of the cross-product combinations anyway. - // - It's easier to reason about this doubly-nested loop than about merging sorted lists or - // picking the right hash keys. - // - These lists are always so small that performance doesn't matter. 
- for a in a.as_and_subpatterns() { - for b in b.as_and_subpatterns() { - let overlap = match (a, b) { - (Pattern::Int { value: a }, Pattern::Int { value: b }) => a == b, - (Pattern::Const { name: a }, Pattern::Const { name: b }) => a == b, - - // if it's the same variant or same extractor, check all pairs of subterms - ( - Pattern::Variant { - id: a, - pats: a_pats, - }, - Pattern::Variant { - id: b, - pats: b_pats, - }, - ) - | ( - Pattern::Extractor { - id: a, - pats: a_pats, - }, - Pattern::Extractor { - id: b, - pats: b_pats, - }, - ) if a == b => { - debug_assert_eq!(a_pats.len(), b_pats.len()); - worklist.extend(a_pats.iter().zip(b_pats.iter())); - true - } - - // different variants of the same enum definitely do not overlap - (Pattern::Variant { .. }, Pattern::Variant { .. }) => false, - - // an extractor which does not exactly match the other pattern might overlap - (Pattern::Extractor { .. }, _) | (_, Pattern::Extractor { .. }) => true, - - // a wildcard definitely overlaps - (Pattern::Wildcard, _) | (_, Pattern::Wildcard) => true, - - // these patterns can only be paired with patterns of the same type, or - // wildcards or extractors, and all those cases are covered above - (Pattern::Int { .. } | Pattern::Const { .. } | Pattern::Variant { .. }, _) => { - unreachable!() - } - - // and-patterns don't reach here due to as_and_subpatterns - (Pattern::And { .. }, _) => unreachable!(), - }; - - if !overlap { - return false; - } - } - } - } - true -} - -/// A version of [`sema::Pattern`] with some simplifications to make overlap checking easier. -#[derive(Debug, Clone)] -enum Pattern { - /// Integer literal patterns. - Int { - value: i128, - }, - - /// Constant literal patterns, such as `$F32`. - Const { - name: Sym, - }, - - /// Enum variant constructors. - Variant { - id: TermId, - pats: Box<[Pattern]>, - }, - - /// Conjunctions of patterns. - And { - pats: Box<[Pattern]>, - }, - - /// Extractor uses (both fallible and infallible). 
- Extractor { - id: TermId, - pats: Box<[Pattern]>, - }, - - Wildcard, -} - -impl Pattern { - /// Create a [`Pattern`] from a [`sema::Pattern`]. The major differences between these two - /// representations are as follows: - /// 1. Variable bindings are removed and turned into wildcards - /// 2. Equality constraints are removed and turned into inlined versions of the patterns they - /// would have introduced equalities with - /// 3. [`sema::Pattern::Term`] instances are turned into either [`Pattern::Variant`] or - /// [`Pattern::Extractor`] cases depending on their term kind. - fn from_sema(env: &TermEnv, binds: &mut Vec<(VarId, Pattern)>, pat: &sema::Pattern) -> Self { - match pat { - sema::Pattern::BindPattern(_, id, pat) => { - let pat = Self::from_sema(env, binds, pat); - binds.push((*id, pat.clone())); - pat - } - - sema::Pattern::Var(_, id) => { - for (vid, pat) in binds.iter().rev() { - if vid == id { - // We inline equality constraints for two reasons: we specialize on the - // spine of related patterns only, so more specific information about - // individual values isn't necessarily helpful; we consider overlap - // checking to be an over-approximation of overlapping rules, so handling - // equalities ends up being best-effort. As an approximation, we use - // whatever pattern happened to be at the binding of the variable for all - // of the cases where it's used for equality. For example, in the following - // rule: - // - // > (rule (example x @ (Enum.Variant y) x) ...) - // - // we will only specialize up to `(Enum.Variant _)`, so any more specific - // runtime values of `y` won't end up helping to identify overlap. As a - // result, we rewrite the patterns in the rule to look more like the - // following, as it greatly simplifies overlap checking. - // - // > (rule (example (Enum.Variant _) (Enum.Variant _)) ...) - // - // Cases that this scheme won't handle look like the following: - // - // > (rule (example2 2 3) ...) - // > (rule (example2 x x) ...) 
- // - // As in this case we'll not make use of the information that `2` and `3` - // aren't equal to know that the rules don't overlap. One approach that we - // could take here is delaying substitution to the point where a variable - // binding has been specialized, turning the rules into the following once - // specialization had occurred for `2`: - // - // > (rule (example2 2 3) ...) - // > (rule (example2 2 2) ...) - return pat.clone(); - } - } - - binds.push((*id, Pattern::Wildcard)); - Pattern::Wildcard - } - - sema::Pattern::ConstInt(_, value) => Pattern::Int { value: *value }, - sema::Pattern::ConstPrim(_, name) => Pattern::Const { name: *name }, - - &sema::Pattern::Term(_, id, ref pats) => { - let pats = pats - .iter() - .map(|pat| Pattern::from_sema(env, binds, pat)) - .collect(); - - match &env.terms[id.0].kind { - TermKind::EnumVariant { .. } => Pattern::Variant { id, pats }, - TermKind::Decl { .. } => Pattern::Extractor { id, pats }, - } - } - - sema::Pattern::Wildcard(_) => Pattern::Wildcard, - - sema::Pattern::And(_, pats) => { - let pats = pats - .iter() - .map(|pat| Pattern::from_sema(env, binds, pat)) - .collect(); - Pattern::And { pats } - } - } - } - - /// If this is an and-pattern, return its subpatterns. Otherwise pretend like there's an - /// and-pattern which has this as its only subpattern, and return self as a single-element - /// slice. - fn as_and_subpatterns(&self) -> &[Pattern] { - if let Pattern::And { pats } = self { - pats - } else { - std::slice::from_ref(self) - } - } -} diff --git a/cranelift/isle/isle/src/trie_again.rs b/cranelift/isle/isle/src/trie_again.rs new file mode 100644 index 0000000000..5b32989d55 --- /dev/null +++ b/cranelift/isle/isle/src/trie_again.rs @@ -0,0 +1,676 @@ +//! A strongly-normalizing intermediate representation for ISLE rules. This representation is chosen +//! to closely reflect the operations we can implement in Rust, to make code generation easy. 
+use crate::error::{Error, Source, Span}; +use crate::lexer::Pos; +use crate::sema::{self, RuleVisitor}; +use crate::DisjointSets; +use std::collections::{hash_map::Entry, HashMap}; + +/// A field index in a tuple or an enum variant. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct TupleIndex(u8); +/// A hash-consed identifier for a binding, stored in a [RuleSet]. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct BindingId(u16); +/// A hash-consed identifier for an expression, stored in a [RuleSet]. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct ExprId(u16); + +impl BindingId { + /// Get the index of this id. + pub fn index(self) -> usize { + self.0.into() + } +} + +impl ExprId { + /// Get the index of this id. + pub fn index(self) -> usize { + self.0.into() + } +} + +/// Expressions construct new values. Rust pattern matching can only destructure existing values, +/// not call functions or construct new values. So `if-let` and external extractor invocations need +/// to interrupt pattern matching in order to evaluate a suitable expression. These expressions are +/// also used when evaluating the right-hand side of a rule. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum Expr { + /// A binding from some sequence of pattern matches, used as an expression. + Binding { + /// Which binding site is being used as an expression? + source: BindingId, + }, + /// Evaluates to the given integer literal. + ConstInt { + /// The constant value. + val: i128, + }, + /// Evaluates to the given primitive Rust value. + ConstPrim { + /// The constant value. + val: sema::Sym, + }, + /// One of the arguments to the top-level function. + Argument { + /// Which of the function's arguments is this? + index: TupleIndex, + }, + /// The result of calling an external extractor. + Extractor { + /// Which extractor should be called? 
+ term: sema::TermId, + /// What expression should be passed to the extractor? + parameter: ExprId, + }, + /// The result of calling an external constructor. + Constructor { + /// Which constructor should be called? + term: sema::TermId, + /// What expressions should be passed to the constructor? + parameters: Box<[ExprId]>, + }, + /// The result of constructing an enum variant. + Variant { + /// Which enum type should be constructed? + ty: sema::TypeId, + /// Which variant of that enum should be constructed? + variant: sema::VariantId, + /// What expressions should be provided for this variant's fields? + fields: Box<[ExprId]>, + }, +} + +/// Binding sites are the result of Rust pattern matching. This is the dual of an expression: while +/// expressions build up values, bindings take values apart. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Binding { + /// A match begins at the result of some expression that produces a Rust value. + Expr { + /// Which expression is being matched? + constructor: ExprId, + }, + /// After some sequence of matches, we'll match one of the previous bindings against an enum + /// variant and produce a new binding from one of its fields. There must be a matching + /// [Constraint] for each `source`/`variant` pair that appears in a binding. + Variant { + /// Which binding is being matched? + source: BindingId, + /// Which enum variant are we pulling binding sites from? This is somewhat redundant with + /// information in a corresponding [Constraint]. However, it must be here so that different + /// enum variants aren't hash-consed into the same binding site. + variant: sema::VariantId, + /// Which field of this enum variant are we projecting out? Although ISLE uses named fields, + /// we track them by index for constant-time comparisons. The [sema::TypeEnv] can be used to + /// get the field names. 
+ field: TupleIndex, + }, + /// After some sequence of matches, we'll match one of the previous bindings against + /// `Option::Some` and produce a new binding from its contents. (This currently only happens + /// with external extractors.) + Some { + /// Which binding is being matched? + source: BindingId, + }, + /// After some sequence of matches, we'll match one of the previous bindings against a tuple and + /// produce a new binding from one of its fields. (This currently only happens with external + /// extractors.) + Tuple { + /// Which binding is being matched? + source: BindingId, + /// Which tuple field are we projecting out? + field: TupleIndex, + }, +} + +/// Pattern matches which can fail. Some binding sites are the result of successfully matching a +/// constraint. A rule applies constraints to binding sites to determine whether the rule matches. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Constraint { + /// The value must match this enum variant. + Variant { + /// Which enum type is being matched? This is implied by the binding where the constraint is + /// applied, but recorded here for convenience. + ty: sema::TypeId, + /// Which enum variant must this binding site match to satisfy the rule? + variant: sema::VariantId, + /// Number of fields in this variant of this enum. This is recorded in the constraint for + /// convenience, to avoid needing to look up the variant in a [sema::TypeEnv]. + fields: TupleIndex, + }, + /// The value must equal this integer literal. + ConstInt { + /// The constant value. + val: i128, + }, + /// The value must equal this Rust primitive value. + ConstPrim { + /// The constant value. + val: sema::Sym, + }, + /// The value must be an `Option::Some`, from a fallible extractor. + Some, +} + +/// A term-rewriting rule. All [BindingId]s and [ExprId]s are only meaningful in the context of the +/// [RuleSet] that contains this rule. 
+#[derive(Debug, Default)] +pub struct Rule { + /// Where was this rule defined? + pub pos: Pos, + /// All of these bindings must match for this rule to apply. Note that within a single rule, if + /// a binding site must match two different constants, then the rule can never match. + constraints: HashMap, + /// Sets of bindings which must be equal for this rule to match. + pub equals: DisjointSets, + /// If other rules apply along with this one, the one with the highest numeric priority is + /// evaluated. If multiple applicable rules have the same priority, that's an overlap error. + pub prio: i64, + /// If this rule applies, the top-level term should evaluate to this expression. + pub result: ExprId, +} + +/// Records whether a given pair of rules can both match on some input. +pub enum Overlap { + /// There is no input on which this pair of rules can both match. + No, + /// There is at least one input on which this pair of rules can both match. + Yes { + /// True if every input accepted by one rule is also accepted by the other. This does not + /// indicate which rule is more general and in fact the rules could match exactly the same + /// set of inputs. You can work out which by comparing the number of constraints in both + /// rules: The more general rule has fewer constraints. + subset: bool, + }, +} + +/// A collection of [Rule]s, along with hash-consed [Binding]s and [Expr]s for all of them. +#[derive(Debug, Default)] +pub struct RuleSet { + /// The [Rule]s for a single [sema::Term]. + pub rules: Vec, + /// The bindings identified by [BindingId]s within rules. + pub bindings: Vec, + /// The expressions identified by [ExprId]s within rules. + pub exprs: Vec, +} + +/// Construct a [RuleSet] for each term in `termenv` that has rules. 
+pub fn build( + termenv: &sema::TermEnv, + tyenv: &sema::TypeEnv, +) -> (Vec<(sema::TermId, RuleSet)>, Vec) { + let mut errors = Vec::new(); + let mut term = HashMap::new(); + for rule in termenv.rules.iter() { + term.entry(rule.lhs.root_term().unwrap()) + .or_insert_with(RuleSetBuilder::default) + .add_rule(rule, termenv, tyenv, &mut errors); + } + + // The `term` hash map may return terms in any order. Sort them to ensure that we produce the + // same output every time when given the same ISLE source. Rules are added to terms in `RuleId` + // order, so it's not necessary to sort within a `RuleSet`. + let mut result: Vec<_> = term + .into_iter() + .map(|(term, builder)| (term, builder.rules)) + .collect(); + result.sort_unstable_by_key(|(term, _)| *term); + + (result, errors) +} + +impl Rule { + /// Returns whether a given pair of rules can both match on some input, and if so, whether + /// either matches a subset of the other's inputs. If this function returns `No`, then the two + /// rules definitely do not overlap. However, it may return `Yes` in cases where the rules can't + /// overlap in practice, or where this analysis is not yet precise enough to decide. + pub fn may_overlap(&self, other: &Rule) -> Overlap { + // Two rules can't overlap if, for some binding site in the intersection of their + // constraints, the rules have different constraints: an input can't possibly match both + // rules then. If the rules do overlap, and one has a subset of the constraints of the + // other, then the less-constrained rule matches every input that the more-constrained rule + // matches, and possibly more. We test for both conditions at once, with the observation + // that if the intersection of two sets is equal to the smaller set, then it's a subset. So + // the outer loop needs to go over the rule with fewer constraints in order to correctly + // identify if it's a subset of the other rule. Also, that way around is faster. 
+ let (small, big) = if self.constraints.len() <= other.constraints.len() { + (self, other) + } else { + (other, self) + }; + + // TODO: nonlinear constraints complicate the subset check + // For the purpose of overlap checking, equality constraints act like other constraints, in + // that they can cause rules to not overlap. However, because we don't have a concrete + // pattern to compare, the analysis to prove that is complicated. For now, we approximate + // the result. If `small` has any of these nonlinear constraints, conservatively report that + // it is not a subset of `big`. + let mut subset = small.equals.is_empty(); + + for (binding, a) in small.constraints.iter() { + if let Some(b) = big.constraints.get(binding) { + if a != b { + // If any binding site is constrained differently by both rules then there is + // no input where both rules can match. + return Overlap::No; + } + // Otherwise both are constrained in the same way at this binding site. That doesn't + // rule out any possibilities for what inputs the rules accept. + } else { + // The `big` rule's inputs are a subset of the `small` rule's inputs if every + // constraint in `small` is exactly matched in `big`. But we found a counterexample. + subset = false; + } + } + Overlap::Yes { subset } + } + + /// Returns the constraint that the given binding site must satisfy for this rule to match, if + /// there is one. 
+ pub fn get_constraint(&self, source: BindingId) -> Option { + self.constraints.get(&source).copied() + } + + fn set_constraint( + &mut self, + source: BindingId, + constraint: Constraint, + ) -> Result<(), UnreachableError> { + match self.constraints.entry(source) { + Entry::Occupied(entry) => { + if entry.get() != &constraint { + return Err(UnreachableError { + pos: self.pos, + constraint_a: *entry.get(), + constraint_b: constraint, + }); + } + } + Entry::Vacant(entry) => { + entry.insert(constraint); + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct UnreachableError { + pos: Pos, + constraint_a: Constraint, + constraint_b: Constraint, +} + +#[derive(Debug, Default)] +struct RuleSetBuilder { + current_rule: Rule, + binding_map: HashMap, + expr_map: HashMap, + unreachable: Vec, + rules: RuleSet, +} + +impl RuleSetBuilder { + fn add_rule( + &mut self, + rule: &sema::Rule, + termenv: &sema::TermEnv, + tyenv: &sema::TypeEnv, + errors: &mut Vec, + ) { + self.current_rule.pos = rule.pos; + self.current_rule.prio = rule.prio; + self.current_rule.result = rule.visit(self, termenv); + self.normalize_equivalence_classes(); + let rule = std::mem::take(&mut self.current_rule); + + if self.unreachable.is_empty() { + self.rules.rules.push(rule); + } else { + // If this rule can never match, drop it so it doesn't affect overlap checking. + errors.extend(self.unreachable.drain(..).map(|err| { + let src = Source::new( + tyenv.filenames[err.pos.file].clone(), + tyenv.file_texts[err.pos.file].clone(), + ); + Error::UnreachableError { + msg: format!( + "rule requires binding to match both {:?} and {:?}", + err.constraint_a, err.constraint_b + ), + src, + span: Span::new_single(err.pos), + } + })) + } + } + + /// Establish the invariant that a binding site can have a concrete constraint in `constraints`, + /// or an equality constraint in `equals`, but not both. 
This is useful because overlap checking
    /// is most effective on concrete constraints, and also because it exposes more rule structure
    /// for codegen.
    ///
    /// If a binding site is constrained and also required to be equal to another binding site, then
    /// copy the constraint and push the equality inside it. For example:
    /// - `(term x @ 2 x)` is rewritten to `(term 2 2)`
    /// - `(term x @ (T.A _ _) x)` is rewritten to `(term (T.A y z) (T.A y z))`
    ///
    /// In the latter case, note that every field of `T.A` has been replaced with a fresh variable
    /// and each of the copies are set equal.
    ///
    /// If several binding sites are supposed to be equal but they each have conflicting constraints
    /// then this rule is unreachable. For example, `(term x @ 2 (and x 3))` requires both arguments
    /// to be equal but also requires them to match both 2 and 3, which can't happen for any input.
    ///
    /// We could do this incrementally, while building the rule. The implementation is nearly
    /// identical but, having tried both ways, it's slightly easier to think about this as a
    /// separate pass. Also, batching up this work should be slightly faster if there are multiple
    /// binding sites set equal to each other.
    fn normalize_equivalence_classes(&mut self) {
        // First, find all the constraints that need to be copied to other binding sites in their
        // respective equivalence classes. Note: do not remove these constraints here! Yes, we'll
        // put them back later, but we rely on still having them around so that
        // `set_constraint_or_error` can detect conflicting constraints.
        let mut deferred_constraints = Vec::new();
        for (&binding, &constraint) in self.current_rule.constraints.iter() {
            if let Some(root) = self.current_rule.equals.find_mut(binding) {
                deferred_constraints.push((root, constraint));
            }
        }

        // Pick one constraint and propagate it through its equivalence class. If there are no
        // errors then it doesn't matter what order we do this in, because that means that any
        // redundant constraints on an equivalence class were equal. We can write equal values into
        // the constraint map in any order and get the same result. If there were errors, we aren't
        // going to generate code from this rule, so order only affects how conflicts are reported.
        while let Some((current, constraint)) = deferred_constraints.pop() {
            // Remove the entire equivalence class and instead add copies of this constraint to
            // every binding site in the class. If there are constraints on other binding sites in
            // this class, then when we try to copy this constraint to those binding sites,
            // `set_constraint_or_error` will check that the constraints are equal and record an
            // appropriate error otherwise.
            //
            // Later, we'll re-visit those other binding sites because they're still in
            // `deferred_constraints`, but `set` will be empty because we already deleted the
            // equivalence class the first time we encountered it.
            let set = self.current_rule.equals.remove_set_of(current);
            match (constraint, set.split_first()) {
                // If the equivalence class was empty we don't have to do anything.
                (_, None) => continue,

                // If we removed an equivalence class with an enum variant constraint, make the
                // fields of the variant equal instead. Create a binding for every field of every
                // member of `set`. Arbitrarily pick one to set all the others equal to. If there
                // are existing constraints on the new fields, copy those around the new equivalence
                // classes too.
                (
                    Constraint::Variant {
                        fields, variant, ..
                    },
                    Some((&base, rest)),
                ) => {
                    let base_fields =
                        self.field_bindings(base, fields, variant, &mut deferred_constraints);
                    for &binding in rest {
                        for (&x, &y) in self
                            .field_bindings(binding, fields, variant, &mut deferred_constraints)
                            .iter()
                            .zip(base_fields.iter())
                        {
                            self.current_rule.equals.merge(x, y);
                        }
                    }
                }

                // These constraints don't introduce new binding sites.
                (Constraint::ConstInt { .. } | Constraint::ConstPrim { .. }, _) => {}

                // Currently, `Some` constraints are only introduced implicitly during the
                // translation from `sema`, so there's no way to set the corresponding binding
                // sites equal to each other. Instead, any equality constraints get applied on
                // the results of matching `Some()` or tuple patterns.
                (Constraint::Some, _) => unreachable!(),
            }

            for binding in set {
                self.set_constraint_or_error(binding, constraint);
            }
        }
    }

    /// Create (or reuse) a [Binding::Variant] projection for every field of the given enum
    /// `variant`, rooted at `binding`, returning the field bindings in field order.
    ///
    /// If any of the new field bindings already carries a concrete constraint, that constraint is
    /// pushed onto `deferred_constraints` so `normalize_equivalence_classes` can propagate it
    /// through the equivalence classes it is about to build from these fields.
    fn field_bindings(
        &mut self,
        binding: BindingId,
        fields: TupleIndex,
        variant: sema::VariantId,
        deferred_constraints: &mut Vec<(BindingId, Constraint)>,
    ) -> Box<[BindingId]> {
        (0..fields.0)
            .map(TupleIndex)
            .map(move |field| {
                let binding = self.dedup_binding(Binding::Variant {
                    source: binding,
                    variant,
                    field,
                });
                // We've just added an equality constraint to a binding site that may not have had
                // one already. If that binding site already had a concrete constraint, then we need
                // to "recursively" propagate that constraint through the new equivalence class too.
                if let Some(constraint) = self.current_rule.get_constraint(binding) {
                    deferred_constraints.push((binding, constraint));
                }
                binding
            })
            .collect()
    }

    /// Return the canonical [BindingId] for `binding`, interning it into `self.rules.bindings`
    /// the first time it is seen so that structurally equal bindings share one id.
    fn dedup_binding(&mut self, binding: Binding) -> BindingId {
        if let Some(binding) = self.binding_map.get(&binding) {
            *binding
        } else {
            let id = BindingId(self.rules.bindings.len().try_into().unwrap());
            self.rules.bindings.push(binding.clone());
            self.binding_map.insert(binding, id);
            id
        }
    }

    /// Return the canonical [ExprId] for `expr`, interning it into `self.rules.exprs` the first
    /// time it is seen so that structurally equal expressions share one id.
    fn dedup_expr(&mut self, expr: Expr) -> ExprId {
        if let Some(expr) = self.expr_map.get(&expr) {
            *expr
        } else {
            let id = ExprId(self.rules.exprs.len().try_into().unwrap());
            self.rules.exprs.push(expr.clone());
            self.expr_map.insert(expr, id);
            id
        }
    }

    /// Intern `input` and constrain it with `constraint`, returning the interned [BindingId].
    /// Conflicting constraints are recorded as errors rather than panicking.
    fn set_constraint(&mut self, input: Binding, constraint: Constraint) -> BindingId {
        let input = self.dedup_binding(input);
        self.set_constraint_or_error(input, constraint);
        input
    }

    /// Apply `constraint` to `input` in the current rule. If the rule rejects the constraint
    /// (e.g. it conflicts with one already placed on this binding site), record the resulting
    /// error in `self.unreachable` instead of propagating it.
    fn set_constraint_or_error(&mut self, input: BindingId, constraint: Constraint) {
        if let Err(e) = self.current_rule.set_constraint(input, constraint) {
            self.unreachable.push(e);
        }
    }
}

impl sema::PatternVisitor for RuleSetBuilder {
    /// The "identifier" this visitor uses for binding sites is a [Binding], not a [BindingId].
    /// Either choice would work but this approach avoids adding bindings to the [RuleSet] if they
    /// are never used in any rule.
    type PatternId = Binding;

    fn add_match_equal(&mut self, a: Binding, b: Binding, _ty: sema::TypeId) {
        let a = self.dedup_binding(a);
        let b = self.dedup_binding(b);
        // If both bindings represent the same binding site, they're implicitly equal.
        if a != b {
            self.current_rule.equals.merge(a, b);
        }
    }

    fn add_match_int(&mut self, input: Binding, _ty: sema::TypeId, val: i128) {
        self.set_constraint(input, Constraint::ConstInt { val });
    }

    fn add_match_prim(&mut self, input: Binding, _ty: sema::TypeId, val: sema::Sym) {
        self.set_constraint(input, Constraint::ConstPrim { val });
    }

    fn add_match_variant(
        &mut self,
        input: Binding,
        input_ty: sema::TypeId,
        arg_tys: &[sema::TypeId],
        variant: sema::VariantId,
    ) -> Vec<Binding> {
        let fields = TupleIndex(arg_tys.len().try_into().unwrap());
        let source = self.set_constraint(
            input,
            Constraint::Variant {
                fields,
                ty: input_ty,
                variant,
            },
        );
        // One binding per field of the matched variant, all rooted at the constrained source.
        (0..fields.0)
            .map(TupleIndex)
            .map(|field| Binding::Variant {
                source,
                variant,
                field,
            })
            .collect()
    }

    fn add_extract(
        &mut self,
        input: Binding,
        _input_ty: sema::TypeId,
        output_tys: Vec<sema::TypeId>,
        term: sema::TermId,
        infallible: bool,
        _multi: bool,
    ) -> Vec<Binding> {
        // ISLE treats external extractors as patterns, but in this representation they're
        // expressions, because Rust doesn't support calling functions during pattern matching. To
        // glue the two representations together we have to introduce suitable adapter nodes.
        let input = self.pattern_as_expr(input);
        let input = self.dedup_expr(Expr::Extractor {
            term,
            parameter: input,
        });
        let input = self.expr_as_pattern(input);

        // If the extractor is fallible, build a pattern and constraint for `Some`
        let source = if infallible {
            input
        } else {
            let source = self.set_constraint(input, Constraint::Some);
            Binding::Some { source }
        };

        // If the extractor has multiple outputs, create a separate binding for each
        match output_tys.len().try_into().unwrap() {
            0 => vec![],
            1 => vec![source],
            outputs => {
                let source = self.dedup_binding(source);
                (0..outputs)
                    .map(TupleIndex)
                    .map(|field| Binding::Tuple { source, field })
                    .collect()
            }
        }
    }
}

impl sema::ExprVisitor for RuleSetBuilder {
    /// Unlike the `PatternVisitor` implementation, we use [ExprId] to identify intermediate
    /// expressions, not [Expr]. Visited expressions are always used so we might as well deduplicate
    /// them eagerly.
    type ExprId = ExprId;

    fn add_const_int(&mut self, _ty: sema::TypeId, val: i128) -> ExprId {
        self.dedup_expr(Expr::ConstInt { val })
    }

    fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> ExprId {
        self.dedup_expr(Expr::ConstPrim { val })
    }

    fn add_create_variant(
        &mut self,
        inputs: Vec<(ExprId, sema::TypeId)>,
        ty: sema::TypeId,
        variant: sema::VariantId,
    ) -> ExprId {
        self.dedup_expr(Expr::Variant {
            ty,
            variant,
            fields: inputs.into_iter().map(|(expr, _)| expr).collect(),
        })
    }

    fn add_construct(
        &mut self,
        inputs: Vec<(ExprId, sema::TypeId)>,
        _ty: sema::TypeId,
        term: sema::TermId,
        _infallible: bool,
        _multi: bool,
    ) -> ExprId {
        self.dedup_expr(Expr::Constructor {
            term,
            parameters: inputs.into_iter().map(|(expr, _)| expr).collect(),
        })
    }
}

impl sema::RuleVisitor for RuleSetBuilder {
    type PatternVisitor = Self;
    type ExprVisitor = Self;
    type Expr = ExprId;

    fn add_arg(&mut self, index: usize, _ty: sema::TypeId) -> Binding {
        // Arguments don't need to be pattern-matched to reference them, so they're expressions
        let index = TupleIndex(index.try_into().unwrap());
        let expr = self.dedup_expr(Expr::Argument { index });
        Binding::Expr { constructor: expr }
    }

    // NOTE(review): the `F: FnOnce(&mut Self)` bound was stripped by the extraction that
    // produced this chunk; it is reconstructed here from the `visitor(self)` call — confirm
    // against the `sema::RuleVisitor` trait definition.
    fn add_pattern<F: FnOnce(&mut Self)>(&mut self, visitor: F) {
        visitor(self)
    }

    fn add_expr<F>(&mut self, visitor: F) -> ExprId
    where
        F: FnOnce(&mut Self) -> sema::VisitedExpr<Self>,
    {
        visitor(self).value
    }

    fn expr_as_pattern(&mut self, expr: ExprId) -> Binding {
        if let &Expr::Binding { source: binding } = &self.rules.exprs[expr.index()] {
            // Short-circuit wrapping a binding around an expr from another binding
            self.rules.bindings[binding.index()]
        } else {
            Binding::Expr { constructor: expr }
        }
    }

    fn pattern_as_expr(&mut self, pattern: Binding) -> ExprId {
        if let Binding::Expr { constructor } = pattern {
            // Short-circuit wrapping an expr around a binding from another expr
            constructor
        } else {
            let binding = self.dedup_binding(pattern);
            self.dedup_expr(Expr::Binding { source: binding })
        }
    }
}