diff --git a/cranelift/isle/isle/src/ir.rs b/cranelift/isle/isle/src/ir.rs index d7c53307cf..9266ebe742 100644 --- a/cranelift/isle/isle/src/ir.rs +++ b/cranelift/isle/isle/src/ir.rs @@ -3,7 +3,6 @@ use crate::lexer::Pos; use crate::log; use crate::sema::*; -use crate::StableMap; declare_id!( /// The id of an instruction in a `PatternSequence`. @@ -239,33 +238,70 @@ impl ExprSequence { } } -#[derive(Clone, Copy, Debug)] -enum ValueOrArgs { - Value(Value), - ImplicitTermFromArgs(TermId), -} - -impl ValueOrArgs { - fn to_value(&self) -> Option { - match self { - &ValueOrArgs::Value(v) => Some(v), - _ => None, - } - } -} - impl PatternSequence { fn add_inst(&mut self, inst: PatternInst) -> InstId { let id = InstId(self.insts.len()); self.insts.push(inst); id } +} + +/// Used as an intermediate representation of expressions in the [RuleVisitor] implementation for +/// [PatternSequence]. +pub struct ReturnExpr { + seq: ExprSequence, + output: Value, + output_ty: TypeId, +} + +impl RuleVisitor for PatternSequence { + type PatternVisitor = Self; + type ExprVisitor = ExprSequence; + type Expr = ReturnExpr; fn add_arg(&mut self, index: usize, ty: TypeId) -> Value { let inst = self.add_inst(PatternInst::Arg { index, ty }); Value::Pattern { inst, output: 0 } } + fn add_pattern(&mut self, visitor: F) { + visitor(self) + } + + fn add_expr(&mut self, visitor: F) -> ReturnExpr + where + F: FnOnce(&mut ExprSequence) -> VisitedExpr, + { + let mut expr = ExprSequence::default(); + let VisitedExpr { ty, value } = visitor(&mut expr); + let index = 0; + expr.add_inst(ExprInst::Return { index, ty, value }); + ReturnExpr { + seq: expr, + output: value, + output_ty: ty, + } + } + + fn expr_as_pattern(&mut self, expr: ReturnExpr) -> Value { + let inst = self.add_inst(PatternInst::Expr { + seq: expr.seq, + output: expr.output, + output_ty: expr.output_ty, + }); + + // Create values for all outputs. + Value::Pattern { inst, output: 0 } + } + + fn pattern_as_expr(&mut self, pattern: Value) -> Value { + pattern + } +} + +impl PatternVisitor for PatternSequence { + type PatternId = Value; + fn add_match_equal(&mut self, a: Value, b: Value, ty: TypeId) { self.add_inst(PatternInst::MatchEqual { a, b, ty }); } @@ -300,8 +336,8 @@ impl PatternSequence { fn add_extract( &mut self, - inputs: Vec, - input_tys: Vec, + input: Value, + input_ty: TypeId, output_tys: Vec, term: TermId, infallible: bool, @@ -309,8 +345,8 @@ impl PatternSequence { ) -> Vec { let outputs = output_tys.len(); let inst = self.add_inst(PatternInst::Extract { - inputs, - input_tys, + inputs: vec![input], + input_tys: vec![input_ty], output_tys, term, infallible, @@ -320,131 +356,6 @@ impl PatternSequence { .map(|output| Value::Pattern { inst, output }) .collect() } - - fn add_expr_seq(&mut self, seq: ExprSequence, output: Value, output_ty: TypeId) -> Value { - let inst = self.add_inst(PatternInst::Expr { - seq, - output, - output_ty, - }); - - // Create values for all outputs. - Value::Pattern { inst, output: 0 } - } - - /// Generate PatternInsts to match the given (sub)pattern. Works - /// recursively down the AST. - fn gen_pattern( - &mut self, - input: ValueOrArgs, - termenv: &TermEnv, - pat: &Pattern, - vars: &mut StableMap, - ) { - match pat { - &Pattern::BindPattern(_ty, var, ref subpat) => { - // Bind the appropriate variable and recurse. - assert!(!vars.contains_key(&var)); - if let Some(v) = input.to_value() { - vars.insert(var, v); - } - self.gen_pattern(input, termenv, subpat, vars); - } - &Pattern::Var(ty, var) => { - // Assert that the value matches the existing bound var. - let var_val = vars - .get(&var) - .cloned() - .expect("Variable should already be bound"); - let input_val = input - .to_value() - .expect("Cannot match an =var pattern against root term"); - self.add_match_equal(input_val, var_val, ty); - } - &Pattern::ConstInt(ty, value) => { - // Assert that the value matches the constant integer. - let input_val = input - .to_value() - .expect("Cannot match an integer pattern against root term"); - self.add_match_int(input_val, ty, value); - } - &Pattern::ConstPrim(ty, value) => { - let input_val = input - .to_value() - .expect("Cannot match a constant-primitive pattern against root term"); - self.add_match_prim(input_val, ty, value); - } - &Pattern::Term(ty, term, ref args) => { - match input { - ValueOrArgs::ImplicitTermFromArgs(termid) => { - assert_eq!( - termid, term, - "Cannot match a different term against root pattern" - ); - let termdata = &termenv.terms[term.index()]; - let arg_tys = &termdata.arg_tys[..]; - for (i, subpat) in args.iter().enumerate() { - let value = self.add_arg(i, arg_tys[i]); - self.gen_pattern(ValueOrArgs::Value(value), termenv, subpat, vars); - } - } - ValueOrArgs::Value(input) => { - // Determine whether the term has an external extractor or not. - let termdata = &termenv.terms[term.index()]; - let arg_values = match &termdata.kind { - TermKind::EnumVariant { variant } => { - self.add_match_variant(input, ty, &termdata.arg_tys, *variant) - } - TermKind::Decl { - extractor_kind: None, - .. - } => { - panic!("Pattern invocation of undefined term body") - } - TermKind::Decl { - extractor_kind: Some(ExtractorKind::InternalExtractor { .. }), - .. - } => { - panic!("Should have been expanded away") - } - TermKind::Decl { - multi, - extractor_kind: - Some(ExtractorKind::ExternalExtractor { infallible, .. }), - .. - } => { - // Evaluate all `input` args. - let inputs = vec![input]; - let input_tys = vec![termdata.ret_ty]; - let output_tys = args.iter().map(|arg| arg.ty()).collect(); - - // Invoke the extractor. - self.add_extract( - inputs, - input_tys, - output_tys, - term, - *infallible && !*multi, - *multi, - ) - } - }; - for (pat, val) in args.iter().zip(arg_values) { - self.gen_pattern(ValueOrArgs::Value(val), termenv, pat, vars); - } - } - } - } - &Pattern::And(_ty, ref children) => { - for child in children { - self.gen_pattern(input, termenv, child, vars); - } - } - &Pattern::Wildcard(_ty) => { - // Nothing! - } - } - } } impl ExprSequence { @@ -453,6 +364,10 @@ impl ExprSequence { self.insts.push(inst); id } +} + +impl ExprVisitor for ExprSequence { + type ExprId = Value; fn add_const_int(&mut self, ty: TypeId, val: i128) -> Value { let inst = self.add_inst(ExprInst::ConstInt { ty, val }); @@ -495,134 +410,15 @@ impl ExprSequence { }); Value::Expr { inst, output: 0 } } - - fn add_return(&mut self, ty: TypeId, value: Value) { - self.add_inst(ExprInst::Return { - index: 0, - ty, - value, - }); - } - - /// Creates a sequence of ExprInsts to generate the given - /// expression value. Returns the value ID as well as the root - /// term ID, if any. - fn gen_expr( - &mut self, - termenv: &TermEnv, - expr: &Expr, - vars: &StableMap, - ) -> Value { - log!("gen_expr: expr {:?}", expr); - match expr { - &Expr::ConstInt(ty, val) => self.add_const_int(ty, val), - &Expr::ConstPrim(ty, val) => self.add_const_prim(ty, val), - &Expr::Let { - ty: _ty, - ref bindings, - ref body, - } => { - let mut vars = vars.clone(); - for &(var, _var_ty, ref var_expr) in bindings { - let var_value = self.gen_expr(termenv, var_expr, &vars); - vars.insert(var, var_value); - } - self.gen_expr(termenv, body, &vars) - } - &Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(), - &Expr::Term(ty, term, ref arg_exprs) => { - let termdata = &termenv.terms[term.index()]; - let arg_values_tys = arg_exprs - .iter() - .map(|arg_expr| self.gen_expr(termenv, arg_expr, vars)) - .zip(termdata.arg_tys.iter().copied()) - .collect(); - match &termdata.kind { - TermKind::EnumVariant { variant } => { - self.add_create_variant(arg_values_tys, ty, *variant) - } - TermKind::Decl { - constructor_kind: Some(ConstructorKind::InternalConstructor), - multi, - .. - } => { - self.add_construct( - arg_values_tys, - ty, - term, - /* infallible = */ false, - *multi, - ) - } - TermKind::Decl { - constructor_kind: Some(ConstructorKind::ExternalConstructor { .. }), - pure, - multi, - .. - } => { - self.add_construct( - arg_values_tys, - ty, - term, - /* infallible = */ !pure, - *multi, - ) - } - TermKind::Decl { - constructor_kind: None, - .. - } => panic!("Should have been caught by typechecking"), - } - } - } - } } /// Build a sequence from a rule. pub fn lower_rule(termenv: &TermEnv, rule: RuleId) -> (PatternSequence, ExprSequence) { - let mut pattern_seq: PatternSequence = Default::default(); - let mut expr_seq: ExprSequence = Default::default(); - let ruledata = &termenv.rules[rule.index()]; + log!("lower_rule: ruledata {:?}", ruledata); + + let mut pattern_seq = PatternSequence::default(); + let mut expr_seq = ruledata.visit(&mut pattern_seq, termenv).seq; expr_seq.pos = ruledata.pos; - - let mut vars = StableMap::new(); - let root_term = ruledata - .lhs - .root_term() - .expect("Pattern must have a term at the root"); - - log!("lower_rule: ruledata {:?}", ruledata,); - - // Lower the pattern, starting from the root input value. - pattern_seq.gen_pattern( - ValueOrArgs::ImplicitTermFromArgs(root_term), - termenv, - &ruledata.lhs, - &mut vars, - ); - - // Lower the `if-let` clauses into the pattern seq, using - // `PatternInst::Expr` for the sub-exprs (right-hand sides). - for iflet in &ruledata.iflets { - let mut subexpr_seq: ExprSequence = Default::default(); - let subexpr_ret_value = subexpr_seq.gen_expr(termenv, &iflet.rhs, &mut vars); - subexpr_seq.add_return(iflet.rhs.ty(), subexpr_ret_value); - let pattern_value = - pattern_seq.add_expr_seq(subexpr_seq, subexpr_ret_value, iflet.rhs.ty()); - pattern_seq.gen_pattern( - ValueOrArgs::Value(pattern_value), - termenv, - &iflet.lhs, - &mut vars, - ); - } - - // Lower the expression, making use of the bound variables - // from the pattern. - let rhs_root_val = expr_seq.gen_expr(termenv, &ruledata.rhs, &vars); - // Return the root RHS value. - let output_ty = ruledata.rhs.ty(); - expr_seq.add_return(output_ty, rhs_root_val); (pattern_seq, expr_seq) } diff --git a/cranelift/isle/isle/src/sema.rs b/cranelift/isle/isle/src/sema.rs index 9340729d71..501b93a179 100644 --- a/cranelift/isle/isle/src/sema.rs +++ b/cranelift/isle/isle/src/sema.rs @@ -22,6 +22,7 @@ use crate::{StableMap, StableSet}; use std::collections::hash_map::Entry; use std::collections::BTreeMap; use std::collections::BTreeSet; +use std::collections::HashMap; use std::sync::Arc; declare_id!( @@ -500,6 +501,43 @@ pub enum Expr { }, } +/// Visitor interface for [Pattern]s. Visitors can assign an arbitrary identifier to each +/// subpattern, which is threaded through to subsequent calls into the visitor. +pub trait PatternVisitor { + /// The type of subpattern identifiers. + type PatternId: Copy; + + /// Match if `a` and `b` have equal values. + fn add_match_equal(&mut self, a: Self::PatternId, b: Self::PatternId, ty: TypeId); + /// Match if `input` is the given integer constant. + fn add_match_int(&mut self, input: Self::PatternId, ty: TypeId, int_val: i128); + /// Match if `input` is the given primitive constant. + fn add_match_prim(&mut self, input: Self::PatternId, ty: TypeId, val: Sym); + + /// Match if `input` is the given enum variant. Returns an identifier for each field within the + /// enum variant. The length of the return list must equal the length of `arg_tys`. + fn add_match_variant( + &mut self, + input: Self::PatternId, + input_ty: TypeId, + arg_tys: &[TypeId], + variant: VariantId, + ) -> Vec; + + /// Match if the given external extractor succeeds on `input`. Returns an identifier for each + /// return value from the external extractor. The length of the return list must equal the + /// length of `output_tys`. + fn add_extract( + &mut self, + input: Self::PatternId, + input_ty: TypeId, + output_tys: Vec, + term: TermId, + infallible: bool, + multi: bool, + ) -> Vec; +} + impl Pattern { /// Get this pattern's type. pub fn ty(&self) -> TypeId { @@ -522,6 +560,114 @@ impl Pattern { _ => None, } } + + /// Recursively visit every sub-pattern. + pub fn visit( + &self, + visitor: &mut V, + input: V::PatternId, + termenv: &TermEnv, + vars: &mut HashMap, + ) { + match self { + &Pattern::BindPattern(_ty, var, ref subpat) => { + // Bind the appropriate variable and recurse. + assert!(!vars.contains_key(&var)); + vars.insert(var, input); + subpat.visit(visitor, input, termenv, vars); + } + &Pattern::Var(ty, var) => { + // Assert that the value matches the existing bound var. + let var_val = vars + .get(&var) + .copied() + .expect("Variable should already be bound"); + visitor.add_match_equal(input, var_val, ty); + } + &Pattern::ConstInt(ty, value) => visitor.add_match_int(input, ty, value), + &Pattern::ConstPrim(ty, value) => visitor.add_match_prim(input, ty, value), + &Pattern::Term(ty, term, ref args) => { + // Determine whether the term has an external extractor or not. + let termdata = &termenv.terms[term.index()]; + let arg_values = match &termdata.kind { + TermKind::EnumVariant { variant } => { + visitor.add_match_variant(input, ty, &termdata.arg_tys, *variant) + } + TermKind::Decl { + extractor_kind: None, + .. + } => { + panic!("Pattern invocation of undefined term body") + } + TermKind::Decl { + extractor_kind: Some(ExtractorKind::InternalExtractor { .. }), + .. + } => { + panic!("Should have been expanded away") + } + TermKind::Decl { + multi, + extractor_kind: Some(ExtractorKind::ExternalExtractor { infallible, .. }), + .. + } => { + // Evaluate all `input` args. + let output_tys = args.iter().map(|arg| arg.ty()).collect(); + + // Invoke the extractor. + visitor.add_extract( + input, + termdata.ret_ty, + output_tys, + term, + *infallible && !*multi, + *multi, + ) + } + }; + for (pat, val) in args.iter().zip(arg_values) { + pat.visit(visitor, val, termenv, vars); + } + } + &Pattern::And(_ty, ref children) => { + for child in children { + child.visit(visitor, input, termenv, vars); + } + } + &Pattern::Wildcard(_ty) => { + // Nothing! + } + } + } +} + +/// Visitor interface for [Expr]s. Visitors can return an arbitrary identifier for each +/// subexpression, which is threaded through to subsequent calls into the visitor. +pub trait ExprVisitor { + /// The type of subexpression identifiers. + type ExprId: Copy; + + /// Construct a constant integer. + fn add_const_int(&mut self, ty: TypeId, val: i128) -> Self::ExprId; + /// Construct a primitive constant. + fn add_const_prim(&mut self, ty: TypeId, val: Sym) -> Self::ExprId; + + /// Construct an enum variant with the given `inputs` assigned to the variant's fields in order. + fn add_create_variant( + &mut self, + inputs: Vec<(Self::ExprId, TypeId)>, + ty: TypeId, + variant: VariantId, + ) -> Self::ExprId; + + /// Call an external constructor with the given `inputs` as arguments. + fn add_construct( + &mut self, + inputs: Vec<(Self::ExprId, TypeId)>, + ty: TypeId, + term: TermId, + infallible: bool, + multi: bool, + ) -> Self::ExprId; } impl Expr { @@ -535,6 +681,177 @@ impl Expr { &Self::Let { ty: t, .. } => t, } } + + /// Recursively visit every subexpression. + pub fn visit( + &self, + visitor: &mut V, + termenv: &TermEnv, + vars: &HashMap, + ) -> V::ExprId { + log!("Expr::visit: expr {:?}", self); + match self { + &Expr::ConstInt(ty, val) => visitor.add_const_int(ty, val), + &Expr::ConstPrim(ty, val) => visitor.add_const_prim(ty, val), + &Expr::Let { + ty: _ty, + ref bindings, + ref body, + } => { + let mut vars = vars.clone(); + for &(var, _var_ty, ref var_expr) in bindings { + let var_value = var_expr.visit(visitor, termenv, &vars); + vars.insert(var, var_value); + } + body.visit(visitor, termenv, &vars) + } + &Expr::Var(_ty, var_id) => *vars.get(&var_id).unwrap(), + &Expr::Term(ty, term, ref arg_exprs) => { + let termdata = &termenv.terms[term.index()]; + let arg_values_tys = arg_exprs + .iter() + .map(|arg_expr| arg_expr.visit(visitor, termenv, vars)) + .zip(termdata.arg_tys.iter().copied()) + .collect(); + match &termdata.kind { + TermKind::EnumVariant { variant } => { + visitor.add_create_variant(arg_values_tys, ty, *variant) + } + TermKind::Decl { + constructor_kind: Some(ConstructorKind::InternalConstructor), + multi, + .. + } => { + visitor.add_construct( + arg_values_tys, + ty, + term, + /* infallible = */ false, + *multi, + ) + } + TermKind::Decl { + constructor_kind: Some(ConstructorKind::ExternalConstructor { .. }), + pure, + multi, + .. + } => { + visitor.add_construct( + arg_values_tys, + ty, + term, + /* infallible = */ !pure, + *multi, + ) + } + TermKind::Decl { + constructor_kind: None, + .. + } => panic!("Should have been caught by typechecking"), + } + } + } + } + + fn visit_in_rule( + &self, + visitor: &mut V, + termenv: &TermEnv, + vars: &HashMap::PatternId>, + ) -> V::Expr { + let var_exprs = vars + .iter() + .map(|(&var, &val)| (var, visitor.pattern_as_expr(val))) + .collect(); + visitor.add_expr(|visitor| VisitedExpr { + ty: self.ty(), + value: self.visit(visitor, termenv, &var_exprs), + }) + } +} + +/// Information about an expression after it has been fully visited in [RuleVisitor::add_expr]. +#[derive(Clone, Copy)] +pub struct VisitedExpr { + /// The type of the top-level expression. + pub ty: TypeId, + /// The identifier returned by the visitor for the top-level expression. + pub value: V::ExprId, +} + +/// Visitor interface for [Rule]s. Visitors must be able to visit patterns by implementing +/// [PatternVisitor], and to visit expressions by providing a type that implements [ExprVisitor]. +pub trait RuleVisitor { + /// The type of pattern visitors constructed by [RuleVisitor::add_pattern]. + type PatternVisitor: PatternVisitor; + /// The type of expression visitors constructed by [RuleVisitor::add_expr]. + type ExprVisitor: ExprVisitor; + /// The type returned from [RuleVisitor::add_expr], which may be exchanged for a subpattern + /// identifier using [RuleVisitor::expr_as_pattern]. + type Expr; + + /// Visit one of the arguments to the top-level pattern. + fn add_arg( + &mut self, + index: usize, + ty: TypeId, + ) -> ::PatternId; + + /// Visit a pattern, used once for the rule's left-hand side and once for each if-let. You can + /// determine which part of the rule the pattern comes from based on whether the `PatternId` + /// passed to the first call to this visitor came from `add_arg` or `expr_as_pattern`. + fn add_pattern(&mut self, visitor: F) + where + F: FnOnce(&mut Self::PatternVisitor); + + /// Visit an expression, used once for each if-let and once for the rule's right-hand side. + fn add_expr(&mut self, visitor: F) -> Self::Expr + where + F: FnOnce(&mut Self::ExprVisitor) -> VisitedExpr; + + /// Given an expression from [RuleVisitor::add_expr], return an identifier that can be used with + /// a pattern visitor in [RuleVisitor::add_pattern]. + fn expr_as_pattern( + &mut self, + expr: Self::Expr, + ) -> ::PatternId; + + /// Given an identifier from the pattern visitor, return an identifier that can be used with + /// the expression visitor. + fn pattern_as_expr( + &mut self, + pattern: ::PatternId, + ) -> ::ExprId; +} + +impl Rule { + /// Recursively visit every pattern and expression in this rule. Returns the [RuleVisitor::Expr] + /// that was returned from [RuleVisitor::add_expr] when that function was called on the rule's + /// right-hand side. + pub fn visit(&self, visitor: &mut V, termenv: &TermEnv) -> V::Expr { + let mut vars = HashMap::new(); + + // Visit the pattern, starting from the root input value. + if let &Pattern::Term(_, term, ref args) = &self.lhs { + let termdata = &termenv.terms[term.index()]; + for (i, (subpat, &arg_ty)) in args.iter().zip(termdata.arg_tys.iter()).enumerate() { + let value = visitor.add_arg(i, arg_ty); + visitor.add_pattern(|visitor| subpat.visit(visitor, value, termenv, &mut vars)); + } + } else { + unreachable!("Pattern must have a term at the root"); + } + + // Visit the `if-let` clauses, using `V::ExprVisitor` for the sub-exprs (right-hand sides). + for iflet in self.iflets.iter() { + let subexpr = iflet.rhs.visit_in_rule(visitor, termenv, &vars); + let value = visitor.expr_as_pattern(subexpr); + visitor.add_pattern(|visitor| iflet.lhs.visit(visitor, value, termenv, &mut vars)); + } + + // Visit the rule's right-hand side, making use of the bound variables from the pattern. + self.rhs.visit_in_rule(visitor, termenv, &vars) + } } /// Given an `Option`, unwrap the inner `T` value, or `continue` if it is