Assorted ISLE changes to prep for new codegen (#5441)

* Multi-extractors should only be used in multi-terms

* ISLE int literals should be in range for their type

See #5431 and #5423.

* Make StableSet usable in public interfaces

Also implement an immutable version of DisjointSets::find_mut.

* Return analyzed terms from overlap check

If the caller wants the `trie_again::RuleSet` for a term, don't make
them recompute it.

* Expose binding lookups and sources

* Don't dedup or prune impure constructor calls

* Record int types for bindings and constraints

This means that bindings for constant integers that have the same value
but not the same type no longer hash-cons into the same binding ID.

* Track binding sites from calling multi-terms

* Implement more traits
This commit is contained in:
Jamey Sharp
2022-12-14 14:41:29 -08:00
committed by GitHub
parent be710df237
commit e03d65cca7
8 changed files with 156 additions and 49 deletions

View File

@@ -5,7 +5,7 @@
(extern extractor E1 e1_etor) (extern extractor E1 e1_etor)
(decl Rule (u32) u32) (decl multi Rule (u32) u32)
(rule 1 (Rule (E1 a idx)) (rule 1 (Rule (E1 a idx))
(if-let (A.B) a) (if-let (A.B) a)

View File

@@ -1,5 +1,9 @@
mod multi_extractor; mod multi_extractor;
use multi_extractor::ContextIter;
pub(crate) type ConstructorVec<T> = Vec<T>;
#[derive(Clone)] #[derive(Clone)]
pub enum A { pub enum A {
B, B,
@@ -40,7 +44,7 @@ impl multi_extractor::Context for Context {
fn main() { fn main() {
let mut ctx = Context; let mut ctx = Context;
let x = multi_extractor::constructor_Rule(&mut ctx, 0xf0); let x = multi_extractor::constructor_Rule(&mut ctx, 0xf0).next(&mut ctx);
let y = multi_extractor::constructor_Rule(&mut ctx, 0); let y = multi_extractor::constructor_Rule(&mut ctx, 0).next(&mut ctx);
println!("x = {:?} y = {:?}", x, y); println!("x = {:?} y = {:?}", x, y);
} }

View File

@@ -3,8 +3,8 @@
(decl partial X (i64) i64) (decl partial X (i64) i64)
(rule (X -1) -2) (rule (X -1) -2)
(rule (X -2) -3) (rule (X -2) -3)
(rule (X 0x7fff_ffff_ffff_ffff) 0x8000_0000_0000_0000) (rule (X 0x7fff_ffff_ffff_ffff) -0x8000_0000_0000_0000)
(rule (X 0xffff_ffff_ffff_fff0) 1) (rule (X -16) 1)
(type i128 (primitive i128)) (type i128 (primitive i128))

View File

@@ -398,6 +398,7 @@ impl ExprVisitor for ExprSequence {
inputs: Vec<(Value, TypeId)>, inputs: Vec<(Value, TypeId)>,
ty: TypeId, ty: TypeId,
term: TermId, term: TermId,
_pure: bool,
infallible: bool, infallible: bool,
multi: bool, multi: bool,
) -> Value { ) -> Value {

View File

@@ -25,8 +25,8 @@ macro_rules! declare_id {
/// A wrapper around a [HashSet] which prevents accidentally observing the non-deterministic /// A wrapper around a [HashSet] which prevents accidentally observing the non-deterministic
/// iteration order. /// iteration order.
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
struct StableSet<T>(HashSet<T>); pub struct StableSet<T>(HashSet<T>);
impl<T> StableSet<T> { impl<T> StableSet<T> {
fn new() -> Self { fn new() -> Self {
@@ -35,11 +35,13 @@ impl<T> StableSet<T> {
} }
impl<T: Hash + Eq> StableSet<T> { impl<T: Hash + Eq> StableSet<T> {
fn insert(&mut self, val: T) -> bool { /// Adds a value to the set. Returns whether the value was newly inserted.
pub fn insert(&mut self, val: T) -> bool {
self.0.insert(val) self.0.insert(val)
} }
fn contains(&self, val: &T) -> bool { /// Returns true if the set contains a value.
pub fn contains(&self, val: &T) -> bool {
self.0.contains(val) self.0.contains(val)
} }
} }
@@ -59,6 +61,7 @@ impl<K, V> StableMap<K, V> {
} }
} }
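// NOTE: Can't auto-derive this: `#[derive(Default)]` would add `K: Default` and `V: Default` bounds, which an empty map does not need.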
impl<K, V> Default for StableMap<K, V> { impl<K, V> Default for StableMap<K, V> {
fn default() -> Self { fn default() -> Self {
StableMap(HashMap::new()) StableMap(HashMap::new())
@@ -125,6 +128,28 @@ impl<T: Copy + std::fmt::Debug + Eq + Hash> DisjointSets<T> {
None None
} }
/// Look up the representative member of the set that contains `x`, without modifying `self`.
/// Returns `None` if `x` has not been merged with any other items using `merge`. Because this
/// method only reads the data structure, it cannot make future queries faster; prefer
/// `find_mut` when a mutable reference is available.
///
/// ```
/// let mut sets = cranelift_isle::DisjointSets::default();
/// sets.merge(1, 2);
/// sets.merge(1, 3);
/// sets.merge(2, 4);
/// assert_eq!(sets.find(3).unwrap(), sets.find(4).unwrap());
/// assert_eq!(sets.find(10), None);
/// ```
pub fn find(&self, mut x: T) -> Option<T> {
    loop {
        // An item with no entry in the parent table belongs to no set.
        let parent = self.parent.get(&x)?.0;
        // A node that is its own parent is the representative of its set.
        if parent == x {
            return Some(x);
        }
        x = parent;
    }
}
/// Merge the set containing `x` with the set containing `y`. This method takes amortized /// Merge the set containing `x` with the set containing `y`. This method takes amortized
/// constant time. /// constant time.
pub fn merge(&mut self, x: T, y: T) { pub fn merge(&mut self, x: T, y: T) {

View File

@@ -9,12 +9,15 @@ use crate::sema::{TermEnv, TermId, TermKind, TypeEnv};
use crate::trie_again; use crate::trie_again;
/// Check for overlap. /// Check for overlap.
pub fn check(tyenv: &TypeEnv, termenv: &TermEnv) -> Result<(), error::Errors> { pub fn check(
tyenv: &TypeEnv,
termenv: &TermEnv,
) -> Result<Vec<(TermId, trie_again::RuleSet)>, error::Errors> {
let (terms, mut errors) = trie_again::build(termenv); let (terms, mut errors) = trie_again::build(termenv);
errors.append(&mut check_overlaps(terms, termenv).report()); errors.append(&mut check_overlaps(&terms, termenv).report());
if errors.is_empty() { if errors.is_empty() {
Ok(()) Ok(terms)
} else { } else {
Err(error::Errors { Err(error::Errors {
errors, errors,
@@ -108,7 +111,7 @@ impl Errors {
/// Determine if any rules overlap in the input that they accept. This checks every unique pair of /// Determine if any rules overlap in the input that they accept. This checks every unique pair of
/// rules, as checking rules in aggregate tends to suffer from exponential explosion in the /// rules, as checking rules in aggregate tends to suffer from exponential explosion in the
/// presence of wildcard patterns. /// presence of wildcard patterns.
fn check_overlaps(terms: Vec<(TermId, trie_again::RuleSet)>, env: &TermEnv) -> Errors { fn check_overlaps(terms: &[(TermId, trie_again::RuleSet)], env: &TermEnv) -> Errors {
let mut errs = Errors::default(); let mut errs = Errors::default();
for (tid, ruleset) in terms { for (tid, ruleset) in terms {
let is_multi_ctor = match &env.terms[tid.index()].kind { let is_multi_ctor = match &env.terms[tid.index()].kind {

View File

@@ -707,6 +707,7 @@ pub trait ExprVisitor {
inputs: Vec<(Self::ExprId, TypeId)>, inputs: Vec<(Self::ExprId, TypeId)>,
ty: TypeId, ty: TypeId,
term: TermId, term: TermId,
pure: bool,
infallible: bool, infallible: bool,
multi: bool, multi: bool,
) -> Self::ExprId; ) -> Self::ExprId;
@@ -768,6 +769,7 @@ impl Expr {
arg_values_tys, arg_values_tys,
ty, ty,
term, term,
flags.pure,
/* infallible = */ !flags.partial, /* infallible = */ !flags.partial,
flags.multi, flags.multi,
) )
@@ -1997,6 +1999,8 @@ impl TermEnv {
termdata.check_args_count(args, tyenv, pos, sym); termdata.check_args_count(args, tyenv, pos, sym);
// TODO: check that multi-extractors are only used in terms declared `multi`
match &termdata.kind { match &termdata.kind {
TermKind::EnumVariant { .. } => {} TermKind::EnumVariant { .. } => {}
TermKind::Decl { TermKind::Decl {
@@ -2166,7 +2170,7 @@ impl TermEnv {
tyenv.report_error( tyenv.report_error(
pos, pos,
format!( format!(
"Used multi-term '{}' but this rule is not in a multi-term", "Used multi-constructor '{}' but this rule is not in a multi-term",
sym.0 sym.0
), ),
); );

View File

@@ -3,7 +3,7 @@
use crate::error::{Error, Span}; use crate::error::{Error, Span};
use crate::lexer::Pos; use crate::lexer::Pos;
use crate::sema; use crate::sema;
use crate::DisjointSets; use crate::{DisjointSets, StableSet};
use std::collections::{hash_map::Entry, HashMap}; use std::collections::{hash_map::Entry, HashMap};
/// A field index in a tuple or an enum variant. /// A field index in a tuple or an enum variant.
@@ -13,6 +13,29 @@ pub struct TupleIndex(u8);
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BindingId(u16); pub struct BindingId(u16);
impl std::convert::TryFrom<usize> for TupleIndex {
type Error = <u8 as std::convert::TryFrom<usize>>::Error;
fn try_from(value: usize) -> Result<Self, Self::Error> {
Ok(TupleIndex(value.try_into()?))
}
}
impl std::convert::TryFrom<usize> for BindingId {
type Error = <u16 as std::convert::TryFrom<usize>>::Error;
fn try_from(value: usize) -> Result<Self, Self::Error> {
Ok(BindingId(value.try_into()?))
}
}
impl TupleIndex {
    /// Return this field's position as a `usize`, suitable for indexing.
    pub fn index(self) -> usize {
        usize::from(self.0)
    }
}
impl BindingId { impl BindingId {
/// Get the index of this id. /// Get the index of this id.
pub fn index(self) -> usize { pub fn index(self) -> usize {
@@ -28,6 +51,8 @@ pub enum Binding {
ConstInt { ConstInt {
/// The constant value. /// The constant value.
val: i128, val: i128,
/// The constant's type. Unsigned types preserve the representation of `val`, not its value.
ty: sema::TypeId,
}, },
/// Evaluates to the given primitive Rust value. /// Evaluates to the given primitive Rust value.
ConstPrim { ConstPrim {
@@ -52,6 +77,14 @@ pub enum Binding {
term: sema::TermId, term: sema::TermId,
/// What expressions should be passed to the constructor? /// What expressions should be passed to the constructor?
parameters: Box<[BindingId]>, parameters: Box<[BindingId]>,
/// For impure constructors, a unique number for each use of this term. Always 0 for pure
/// constructors.
instance: u32,
},
/// The result of getting one value from a multi-constructor or multi-extractor.
Iterator {
/// Which expression produced the iterator that this consumes?
source: BindingId,
}, },
/// The result of constructing an enum variant. /// The result of constructing an enum variant.
MakeVariant { MakeVariant {
@@ -97,7 +130,7 @@ pub enum Binding {
/// Pattern matches which can fail. Some binding sites are the result of successfully matching a /// Pattern matches which can fail. Some binding sites are the result of successfully matching a
/// constraint. A rule applies constraints to binding sites to determine whether the rule matches. /// constraint. A rule applies constraints to binding sites to determine whether the rule matches.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Constraint { pub enum Constraint {
/// The value must match this enum variant. /// The value must match this enum variant.
Variant { Variant {
@@ -114,6 +147,8 @@ pub enum Constraint {
ConstInt { ConstInt {
/// The constant value. /// The constant value.
val: i128, val: i128,
/// The constant's type. Unsigned types preserve the representation of `val`, not its value.
ty: sema::TypeId,
}, },
/// The value must equal this Rust primitive value. /// The value must equal this Rust primitive value.
ConstPrim { ConstPrim {
@@ -136,14 +171,19 @@ pub struct Rule {
constraints: HashMap<BindingId, Constraint>, constraints: HashMap<BindingId, Constraint>,
/// Sets of bindings which must be equal for this rule to match. /// Sets of bindings which must be equal for this rule to match.
pub equals: DisjointSets<BindingId>, pub equals: DisjointSets<BindingId>,
/// These bindings are from multi-terms which need to be evaluated in this rule.
pub iterators: StableSet<BindingId>,
/// If other rules apply along with this one, the one with the highest numeric priority is /// If other rules apply along with this one, the one with the highest numeric priority is
/// evaluated. If multiple applicable rules have the same priority, that's an overlap error. /// evaluated. If multiple applicable rules have the same priority, that's an overlap error.
pub prio: i64, pub prio: i64,
/// If this rule applies, these side effects should be evaluated before returning.
pub impure: Vec<BindingId>,
/// If this rule applies, the top-level term should evaluate to this expression. /// If this rule applies, the top-level term should evaluate to this expression.
pub result: BindingId, pub result: BindingId,
} }
/// Records whether a given pair of rules can both match on some input. /// Records whether a given pair of rules can both match on some input.
#[derive(Debug, Eq, PartialEq)]
pub enum Overlap { pub enum Overlap {
/// There is no input on which this pair of rules can both match. /// There is no input on which this pair of rules can both match.
No, No,
@@ -164,6 +204,8 @@ pub struct RuleSet {
pub rules: Vec<Rule>, pub rules: Vec<Rule>,
/// The bindings identified by [BindingId]s within rules. /// The bindings identified by [BindingId]s within rules.
pub bindings: Vec<Binding>, pub bindings: Vec<Binding>,
/// Intern table for de-duplicating [Binding]s.
binding_map: HashMap<Binding, BindingId>,
} }
/// Construct a [RuleSet] for each term in `termenv` that has rules. /// Construct a [RuleSet] for each term in `termenv` that has rules.
@@ -188,6 +230,31 @@ pub fn build(termenv: &sema::TermEnv) -> (Vec<(sema::TermId, RuleSet)>, Vec<Erro
(result, errors) (result, errors)
} }
impl RuleSet {
    /// Look up the [BindingId] that this rule-set has assigned to `binding`. Returns `None` when
    /// the given [Binding] has no entry in the intern table.
    pub fn find_binding(&self, binding: &Binding) -> Option<BindingId> {
        let id = self.binding_map.get(binding)?;
        Some(*id)
    }
}
impl Binding {
    /// List the binding sites that this binding reads directly, all of which must be evaluated
    /// before this binding can be.
    pub fn sources(&self) -> &[BindingId] {
        match self {
            // Constants and arguments read no other binding.
            Binding::ConstInt { .. } | Binding::ConstPrim { .. } | Binding::Argument { .. } => &[],
            // These variants read a whole list of bindings.
            Binding::Constructor {
                parameters: ids, ..
            }
            | Binding::MakeVariant { fields: ids, .. } => &ids[..],
            // Every remaining variant reads exactly one binding.
            Binding::Extractor { parameter: id, .. }
            | Binding::Iterator { source: id }
            | Binding::MatchVariant { source: id, .. }
            | Binding::MatchSome { source: id }
            | Binding::MatchTuple { source: id, .. } => std::slice::from_ref(id),
        }
    }
}
impl Constraint { impl Constraint {
/// Return the nested [Binding]s from matching the given [Constraint] against the given [BindingId]. /// Return the nested [Binding]s from matching the given [Constraint] against the given [BindingId].
pub fn bindings_for(self, source: BindingId) -> Vec<Binding> { pub fn bindings_for(self, source: BindingId) -> Vec<Binding> {
@@ -305,13 +372,14 @@ struct UnreachableError {
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct RuleSetBuilder { struct RuleSetBuilder {
current_rule: Rule, current_rule: Rule,
binding_map: HashMap<Binding, BindingId>, impure_instance: u32,
unreachable: Vec<UnreachableError>, unreachable: Vec<UnreachableError>,
rules: RuleSet, rules: RuleSet,
} }
impl RuleSetBuilder { impl RuleSetBuilder {
fn add_rule(&mut self, rule: &sema::Rule, termenv: &sema::TermEnv, errors: &mut Vec<Error>) { fn add_rule(&mut self, rule: &sema::Rule, termenv: &sema::TermEnv, errors: &mut Vec<Error>) {
self.impure_instance = 0;
self.current_rule.pos = rule.pos; self.current_rule.pos = rule.pos;
self.current_rule.prio = rule.prio; self.current_rule.prio = rule.prio;
self.current_rule.result = rule.visit(self, termenv); self.current_rule.result = rule.visit(self, termenv);
@@ -412,12 +480,12 @@ impl RuleSetBuilder {
} }
fn dedup_binding(&mut self, binding: Binding) -> BindingId { fn dedup_binding(&mut self, binding: Binding) -> BindingId {
if let Some(binding) = self.binding_map.get(&binding) { if let Some(binding) = self.rules.binding_map.get(&binding) {
*binding *binding
} else { } else {
let id = BindingId(self.rules.bindings.len().try_into().unwrap()); let id = BindingId(self.rules.bindings.len().try_into().unwrap());
self.rules.bindings.push(binding.clone()); self.rules.bindings.push(binding.clone());
self.binding_map.insert(binding, id); self.rules.binding_map.insert(binding, id);
id id
} }
} }
@@ -444,8 +512,8 @@ impl sema::PatternVisitor for RuleSetBuilder {
} }
} }
fn add_match_int(&mut self, input: BindingId, _ty: sema::TypeId, val: i128) { fn add_match_int(&mut self, input: BindingId, ty: sema::TypeId, val: i128) {
let bindings = self.set_constraint(input, Constraint::ConstInt { val }); let bindings = self.set_constraint(input, Constraint::ConstInt { val, ty });
debug_assert_eq!(bindings, &[]); debug_assert_eq!(bindings, &[]);
} }
@@ -479,7 +547,7 @@ impl sema::PatternVisitor for RuleSetBuilder {
output_tys: Vec<sema::TypeId>, output_tys: Vec<sema::TypeId>,
term: sema::TermId, term: sema::TermId,
infallible: bool, infallible: bool,
_multi: bool, multi: bool,
) -> Vec<BindingId> { ) -> Vec<BindingId> {
let source = self.dedup_binding(Binding::Extractor { let source = self.dedup_binding(Binding::Extractor {
term, term,
@@ -487,7 +555,10 @@ impl sema::PatternVisitor for RuleSetBuilder {
}); });
// If the extractor is fallible, build a pattern and constraint for `Some` // If the extractor is fallible, build a pattern and constraint for `Some`
let source = if infallible { let source = if multi {
self.current_rule.iterators.insert(source);
self.dedup_binding(Binding::Iterator { source })
} else if infallible {
source source
} else { } else {
let bindings = self.set_constraint(source, Constraint::Some); let bindings = self.set_constraint(source, Constraint::Some);
@@ -510,8 +581,8 @@ impl sema::PatternVisitor for RuleSetBuilder {
impl sema::ExprVisitor for RuleSetBuilder { impl sema::ExprVisitor for RuleSetBuilder {
type ExprId = BindingId; type ExprId = BindingId;
fn add_const_int(&mut self, _ty: sema::TypeId, val: i128) -> BindingId { fn add_const_int(&mut self, ty: sema::TypeId, val: i128) -> BindingId {
self.dedup_binding(Binding::ConstInt { val }) self.dedup_binding(Binding::ConstInt { val, ty })
} }
fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> BindingId { fn add_const_prim(&mut self, _ty: sema::TypeId, val: sema::Sym) -> BindingId {
@@ -536,23 +607,40 @@ impl sema::ExprVisitor for RuleSetBuilder {
inputs: Vec<(BindingId, sema::TypeId)>, inputs: Vec<(BindingId, sema::TypeId)>,
_ty: sema::TypeId, _ty: sema::TypeId,
term: sema::TermId, term: sema::TermId,
pure: bool,
infallible: bool, infallible: bool,
_multi: bool, multi: bool,
) -> BindingId { ) -> BindingId {
let instance = if pure {
0
} else {
self.impure_instance += 1;
self.impure_instance
};
let source = self.dedup_binding(Binding::Constructor { let source = self.dedup_binding(Binding::Constructor {
term, term,
parameters: inputs.into_iter().map(|(expr, _)| expr).collect(), parameters: inputs.into_iter().map(|(expr, _)| expr).collect(),
instance,
}); });
// If the constructor is fallible, build a pattern for `Some`, but not a constraint. If the // If the constructor is fallible, build a pattern for `Some`, but not a constraint. If the
// constructor is on the right-hand side of a rule then its failure is not considered when // constructor is on the right-hand side of a rule then its failure is not considered when
// deciding which rule to evaluate. Corresponding constraints are only added if this // deciding which rule to evaluate. Corresponding constraints are only added if this
// expression is subsequently used as a pattern; see `expr_as_pattern`. // expression is subsequently used as a pattern; see `expr_as_pattern`.
if infallible { let source = if multi {
self.current_rule.iterators.insert(source);
self.dedup_binding(Binding::Iterator { source })
} else if infallible {
source source
} else { } else {
self.dedup_binding(Binding::MatchSome { source }) self.dedup_binding(Binding::MatchSome { source })
};
if !pure {
self.current_rule.impure.push(source);
} }
source
} }
} }
@@ -580,28 +668,10 @@ impl sema::RuleVisitor for RuleSetBuilder {
fn expr_as_pattern(&mut self, expr: BindingId) -> BindingId { fn expr_as_pattern(&mut self, expr: BindingId) -> BindingId {
let mut todo = vec![expr]; let mut todo = vec![expr];
while let Some(expr) = todo.pop() { while let Some(expr) = todo.pop() {
match &self.rules.bindings[expr.index()] { let expr = &self.rules.bindings[expr.index()];
Binding::ConstInt { .. } | Binding::ConstPrim { .. } | Binding::Argument { .. } => { todo.extend_from_slice(expr.sources());
} if let &Binding::MatchSome { source } = expr {
Binding::Constructor {
parameters: sources,
..
}
| Binding::MakeVariant {
fields: sources, ..
} => todo.extend_from_slice(sources),
&Binding::Extractor {
parameter: source, ..
}
| &Binding::MatchVariant { source, .. }
| &Binding::MatchTuple { source, .. } => todo.push(source),
&Binding::MatchSome { source } => {
let _ = self.set_constraint(source, Constraint::Some); let _ = self.set_constraint(source, Constraint::Some);
todo.push(source);
}
} }
} }
expr expr