diff --git a/cranelift/isle/TODO b/cranelift/isle/TODO index 86bf62c88a..b6e0acff40 100644 --- a/cranelift/isle/TODO +++ b/cranelift/isle/TODO @@ -1,5 +1,8 @@ -- inputs to external extractors? "polarity" of args? -- "extractor macros" rather than full rule reversal? (rule ...) and (pattern ...)? +- Optimizations + - Infallible patterns; optimize away control flow when possible. + - Don't do the closure-wrapping thing for expressions inside of patterns. + - Document semantics carefully, especially wrt extractors. + - Build out an initial set of bindings for Cranelift LowerCtx with extractors for instruction info. diff --git a/cranelift/isle/isle_examples/test.isle b/cranelift/isle/isle_examples/test.isle index 39125a94d1..09d5ea92af 100644 --- a/cranelift/isle/isle_examples/test.isle +++ b/cranelift/isle/isle_examples/test.isle @@ -3,7 +3,7 @@ (type B (enum (B1 (x u32)) (B2 (x u32)))) (decl Input (A) u32) -(extractor Input get_input) ;; fn get_input(ctx: &mut C, ret: u32) -> Option<(A,)> +(extern extractor Input get_input) ;; fn get_input(ctx: &mut C, ret: u32) -> Option<(A,)> (decl Lower (A) B) @@ -12,10 +12,10 @@ (B.B2 sub)) (decl Extractor (B) A) -(rule - (A.A2 x) - (Extractor (B.B1 x))) +(extractor + (Extractor x) + (A.A2 x)) (rule (Lower (Extractor b)) - b) + (B.B1 b)) diff --git a/cranelift/isle/isle_examples/test3.isle b/cranelift/isle/isle_examples/test3.isle index 99d537b0e2..df4c7337cd 100644 --- a/cranelift/isle/isle_examples/test3.isle +++ b/cranelift/isle/isle_examples/test3.isle @@ -5,29 +5,29 @@ Store)) (type Inst (primitive Inst)) +(type InstInput (primitive InstInput)) (type Reg (primitive Reg)) (type u32 (primitive u32)) (decl Op (Opcode) Inst) -(extractor Op get_opcode) +(extern extractor Op get_opcode) -(decl InputToReg (Inst u32) Reg) -(constructor InputToReg put_input_in_reg) +(decl InstInput (InstInput u32) Inst) +(extern extractor InstInput get_inst_input (out in)) + +(decl Producer (Inst) InstInput) +(extern extractor Producer get_input_producer) + +(decl UseInput (InstInput) Reg) +(extern constructor UseInput put_input_in_reg) (type MachInst (enum (Add (a Reg) (b Reg)) + (Add3 (a Reg) (b Reg) (c Reg)) (Sub (a Reg) (b Reg)))) (decl Lower (Inst) MachInst) -;; These can be made nicer by defining some extractors -- see below. -(rule - (Lower inst @ (Op (Opcode.Iadd))) - (MachInst.Add (InputToReg inst 0) (InputToReg inst 1))) -(rule - (Lower inst @ (Op (Opcode.Isub))) - (MachInst.Sub (InputToReg inst 0) (InputToReg inst 1))) - ;; Extractors that give syntax sugar for (Iadd ra rb), etc. ;; ;; Note that this is somewhat simplistic: it directly connects inputs to @@ -39,19 +39,28 @@ ;; we are dealing (at the semantic level) with pure value equivalences of ;; "terms", not arbitrary side-effecting calls. 
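For orientation, the `(extern extractor ...)` and `(extern constructor ...)` declarations above would surface on the generated `Context` trait as fallible methods shaped roughly like the following. This is only a sketch based on the new `to_sig`/`generate_trait_sig` logic later in this patch: the extracted term value is the first argument, input-polarity args (the `(out in)` annotation) become extra Rust arguments, and output-polarity args come back in the `Option` tuple.

    pub trait Context {
        fn get_opcode(&mut self, arg0: Inst) -> Option<(Opcode,)>;
        // (out in) polarity: the u32 input index is passed *into* the extractor.
        fn get_inst_input(&mut self, arg0: Inst, arg1: u32) -> Option<(InstInput,)>;
        fn get_input_producer(&mut self, arg0: InstInput) -> Option<(Inst,)>;
        fn put_input_in_reg(&mut self, arg0: InstInput) -> Option<(Reg,)>;
    }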
-(decl Iadd (Reg Reg) Inst) -(decl Isub (Reg Reg) Inst) -(rule - inst @ (Op Opcode.Iadd) - (Iadd (InputToReg inst 0) (InputToReg inst 1))) -(rule - inst @ (Op Opcode.Isub) - (Isub (InputToReg inst 0) (InputToReg inst 1))) +(decl Iadd (InstInput InstInput) Inst) +(decl Isub (InstInput InstInput) Inst) +(extractor + (Iadd a b) + (and + (Op (Opcode.Iadd)) + (InstInput a <0) + (InstInput b <1))) +(extractor + (Isub a b) + (and + (Op (Opcode.Isub)) + (InstInput a <0) + (InstInput b <1))) ;; Now the nice syntax-sugar that "end-user" backend authors can write: (rule - (Lower (Iadd ra rb)) - (MachInst.Add ra rb)) + (Lower (Iadd ra rb)) + (MachInst.Add (UseInput ra) (UseInput rb))) (rule - (Lower (Isub ra rb)) - (MachInst.Sub ra rb)) + (Lower (Iadd (Producer (Iadd ra rb)) rc)) + (MachInst.Add3 (UseInput ra) (UseInput rb) (UseInput rc))) +(rule + (Lower (Isub ra rb)) + (MachInst.Sub (UseInput ra) (UseInput rb))) \ No newline at end of file diff --git a/cranelift/isle/isle_examples/test4.isle b/cranelift/isle/isle_examples/test4.isle index 085d2bee3b..a899122bbb 100644 --- a/cranelift/isle/isle_examples/test4.isle +++ b/cranelift/isle/isle_examples/test4.isle @@ -3,8 +3,8 @@ (decl Ext1 (u32) A) (decl Ext2 (u32) A) -(extractor Ext1 ext1) -(extractor Ext2 ext2) +(extern extractor Ext1 ext1) +(extern extractor Ext2 ext2) (decl Lower (A) A) diff --git a/cranelift/isle/src/ast.rs b/cranelift/isle/src/ast.rs index e216ca9b82..8742aaa166 100644 --- a/cranelift/isle/src/ast.rs +++ b/cranelift/isle/src/ast.rs @@ -12,6 +12,7 @@ pub struct Defs { pub enum Def { Type(Type), Rule(Rule), + Extractor(Extractor), Decl(Decl), Extern(Extern), } @@ -69,6 +70,16 @@ pub struct Rule { pub prio: Option, } +/// An extractor macro: (A x y) becomes (B x _ y ...). Expanded during +/// ast-to-sema pass. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Extractor { + pub term: Ident, + pub args: Vec, + pub template: Pattern, + pub pos: Pos, +} + /// A pattern: the left-hand side of a rule. #[derive(Clone, PartialEq, Eq, Debug)] pub enum Pattern { @@ -80,13 +91,40 @@ pub enum Pattern { /// An operator that matches a constant integer value. ConstInt { val: i64 }, /// An application of a type variant or term. - Term { sym: Ident, args: Vec }, + Term { + sym: Ident, + args: Vec, + }, /// An operator that matches anything. Wildcard, /// N sub-patterns that must all match. And { subpats: Vec }, } +impl Pattern { + pub fn root_term(&self) -> Option<&Ident> { + match self { + &Pattern::BindPattern { ref subpat, .. } => subpat.root_term(), + &Pattern::Term { ref sym, .. } => Some(sym), + _ => None, + } + } +} + +/// A pattern in a term argument. Adds "evaluated expression" to kinds +/// of patterns in addition to all options in `Pattern`. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum TermArgPattern { + /// A regular pattern that must match the existing value in the term's argument. + Pattern(Pattern), + /// An expression that is evaluated during the match phase and can + /// be given into an extractor. This is essentially a limited form + /// of unification or bidirectional argument flow (a la Prolog): + /// we can pass an arg *into* an extractor rather than getting the + /// arg *out of* it. + Expr(Expr), +} + /// An expression: the right-hand side of a rule. /// /// Note that this *almost* looks like a core Lisp or lambda calculus, @@ -124,8 +162,15 @@ pub enum Extern { func: Ident, /// The position of this decl. pos: Pos, - /// Whether this extractor is infallible (always matches). 
- infallible: bool, + /// Poliarity of args: whether values are inputs or outputs to + /// the external extractor function. This is a sort of + /// statically-defined approximation to Prolog-style + /// unification; we allow for the same flexible directionality + /// but fix it at DSL-definition time. By default, every arg + /// is an *output* from the extractor (and the 'retval", or + /// more precisely the term value that we are extracting, is + /// an "input"). + arg_polarity: Option>, }, /// An external constructor: `(constructor Term rustfunc)` form. Constructor { @@ -137,3 +182,13 @@ pub enum Extern { pos: Pos, }, } + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ArgPolarity { + /// An arg that must be given an Expr in the pattern and passes + /// data *to* the extractor op. + Input, + /// An arg that must be given a regular pattern (not Expr) and + /// receives data *from* the extractor op. + Output, +} diff --git a/cranelift/isle/src/codegen.rs b/cranelift/isle/src/codegen.rs index 77a8ba0895..3e29005926 100644 --- a/cranelift/isle/src/codegen.rs +++ b/cranelift/isle/src/codegen.rs @@ -1,8 +1,8 @@ //! Generate Rust code from a series of Sequences. -use crate::error::Error; use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; -use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId, Variant}; +use crate::sema::{RuleId, TermEnv, TermId, Type, TypeEnv, TypeId, Variant}; +use crate::{error::Error, sema::ExternalSig}; use std::collections::{HashMap, HashSet}; use std::fmt::Write; @@ -452,8 +452,7 @@ impl TermFunctionBuilder { struct TermFunctionsBuilder<'a> { typeenv: &'a TypeEnv, termenv: &'a TermEnv, - builders_by_input: HashMap, - builders_by_output: HashMap, + builders_by_term: HashMap, } impl<'a> TermFunctionsBuilder<'a> { @@ -461,8 +460,7 @@ impl<'a> TermFunctionsBuilder<'a> { log::trace!("typeenv: {:?}", typeenv); log::trace!("termenv: {:?}", termenv); Self { - builders_by_input: HashMap::new(), - builders_by_output: HashMap::new(), + builders_by_term: HashMap::new(), typeenv, termenv, } @@ -473,56 +471,29 @@ impl<'a> TermFunctionsBuilder<'a> { let rule = RuleId(rule); let prio = self.termenv.rules[rule.index()].prio.unwrap_or(0); - if let Some((pattern, expr, lhs_root)) = lower_rule( - self.typeenv, - self.termenv, - rule, - /* forward_dir = */ true, - ) { - log::trace!( - "build:\n- rule {:?}\n- fwd pattern {:?}\n- fwd expr {:?}", - self.termenv.rules[rule.index()], - pattern, - expr - ); - self.builders_by_input - .entry(lhs_root) - .or_insert_with(|| TermFunctionBuilder::new(lhs_root)) - .add_rule(prio, pattern.clone(), expr.clone()); - } + let (pattern, expr) = lower_rule(self.typeenv, self.termenv, rule); + let root_term = self.termenv.rules[rule.index()].lhs.root_term().unwrap(); - if let Some((pattern, expr, rhs_root)) = lower_rule( - self.typeenv, - self.termenv, - rule, - /* forward_dir = */ false, - ) { - log::trace!( - "build:\n- rule {:?}\n- rev pattern {:?}\n- rev expr {:?}", - self.termenv.rules[rule.index()], - pattern, - expr - ); - self.builders_by_output - .entry(rhs_root) - .or_insert_with(|| TermFunctionBuilder::new(rhs_root)) - .add_rule(prio, pattern, expr); - } + log::trace!( + "build:\n- rule {:?}\n- pattern {:?}\n- expr {:?}", + self.termenv.rules[rule.index()], + pattern, + expr + ); + self.builders_by_term + .entry(root_term) + .or_insert_with(|| TermFunctionBuilder::new(root_term)) + .add_rule(prio, pattern.clone(), expr.clone()); } } - fn finalize(self) -> (HashMap, HashMap) { 
- let functions_by_input = self - .builders_by_input + fn finalize(self) -> HashMap { + let functions_by_term = self + .builders_by_term .into_iter() .map(|(term, builder)| (term, builder.trie)) .collect::>(); - let functions_by_output = self - .builders_by_output - .into_iter() - .map(|(term, builder)| (term, builder.trie)) - .collect::>(); - (functions_by_input, functions_by_output) + functions_by_term } } @@ -530,15 +501,13 @@ impl<'a> TermFunctionsBuilder<'a> { pub struct Codegen<'a> { typeenv: &'a TypeEnv, termenv: &'a TermEnv, - functions_by_input: HashMap, - functions_by_output: HashMap, + functions_by_term: HashMap, } #[derive(Clone, Debug, Default)] struct BodyContext { - borrowed_values: HashSet, - expected_return_vals: usize, - tuple_return: bool, + /// For each value: (is_ref, ty). + values: HashMap, } impl<'a> Codegen<'a> { @@ -546,12 +515,11 @@ impl<'a> Codegen<'a> { let mut builder = TermFunctionsBuilder::new(typeenv, termenv); builder.build(); log::trace!("builder: {:?}", builder); - let (functions_by_input, functions_by_output) = builder.finalize(); + let functions_by_term = builder.finalize(); Ok(Codegen { typeenv, termenv, - functions_by_input, - functions_by_output, + functions_by_term, }) } @@ -562,7 +530,6 @@ impl<'a> Codegen<'a> { self.generate_ctx_trait(&mut code)?; self.generate_internal_types(&mut code)?; self.generate_internal_term_constructors(&mut code)?; - self.generate_internal_term_extractors(&mut code)?; Ok(code) } @@ -580,7 +547,11 @@ impl<'a> Codegen<'a> { writeln!( code, - "\n#![allow(dead_code, unreachable_code, unused_imports, unused_variables, non_snake_case)]" + "\n#![allow(dead_code, unreachable_code, unreachable_patterns)]" + )?; + writeln!( + code, + "#![allow(unused_imports, unused_variables, non_snake_case)]" )?; writeln!(code, "\nuse super::*; // Pulls in all external types.")?; @@ -588,6 +559,32 @@ impl<'a> Codegen<'a> { Ok(()) } + fn generate_trait_sig( + &self, + code: &mut dyn Write, + indent: &str, + sig: &ExternalSig, + ) -> Result<(), Error> { + writeln!( + code, + "{}fn {}(&mut self, {}) -> Option<({},)>;", + indent, + sig.func_name, + sig.arg_tys + .iter() + .enumerate() + .map(|(i, &ty)| format!("arg{}: {}", i, self.type_name(ty, /* by_ref = */ true))) + .collect::>() + .join(", "), + sig.ret_tys + .iter() + .map(|&ty| self.type_name(ty, /* by_ref = */ false)) + .collect::>() + .join(", ") + )?; + Ok(()) + } + fn generate_ctx_trait(&self, code: &mut dyn Write) -> Result<(), Error> { writeln!(code, "")?; writeln!( @@ -604,74 +601,9 @@ impl<'a> Codegen<'a> { )?; writeln!(code, "pub trait Context {{")?; for term in &self.termenv.terms { - if let &TermKind::Regular { - extractor, - constructor, - .. - } = &term.kind - { - if let Some((etor_name, infallible)) = extractor { - let etor_name = &self.typeenv.syms[etor_name.index()]; - let arg_is_prim = match &self.typeenv.types[term.ret_ty.index()] { - &Type::Primitive(..) 
=> true, - _ => false, - }; - let arg = format!( - "arg0: {}", - self.type_name( - term.ret_ty, - /* by_ref = */ if arg_is_prim { None } else { Some("&") } - ), - ); - let ret_tuple_tys = term - .arg_tys - .iter() - .map(|ty| { - self.type_name(*ty, /* by_ref = */ None) - }) - .collect::>(); - if infallible { - writeln!( - code, - " fn {}(&mut self, {}) -> ({},);", - etor_name, - arg, - ret_tuple_tys.join(", ") - )?; - } else { - writeln!( - code, - " fn {}(&mut self, {}) -> Option<({},)>;", - etor_name, - arg, - ret_tuple_tys.join(", ") - )?; - } - } - - if let Some(ctor_name) = constructor { - let ctor_name = &self.typeenv.syms[ctor_name.index()]; - let args = term - .arg_tys - .iter() - .enumerate() - .map(|(i, &arg_ty)| { - format!( - "arg{}: {}", - i, - self.type_name(arg_ty, /* by_ref = */ Some("&")) - ) - }) - .collect::>(); - let ret = self.type_name(term.ret_ty, /* by_ref = */ None); - writeln!( - code, - " fn {}(&mut self, {}) -> Option<{}>;", - ctor_name, - args.join(", "), - ret, - )?; - } + if term.is_external() { + let ext_sig = term.to_sig(self.typeenv).unwrap(); + self.generate_trait_sig(code, " ", &ext_sig)?; } } writeln!(code, "}}")?; @@ -721,44 +653,11 @@ impl<'a> Codegen<'a> { Ok(()) } - fn constructor_name(&self, term: TermId) -> String { - let termdata = &self.termenv.terms[term.index()]; - match &termdata.kind { - &TermKind::EnumVariant { .. } => panic!("using enum variant as constructor"), - &TermKind::Regular { - constructor: Some(sym), - .. - } => format!("C::{}", self.typeenv.syms[sym.index()]), - &TermKind::Regular { - constructor: None, .. - } => { - format!("constructor_{}", self.typeenv.syms[termdata.name.index()]) - } - } - } - - fn extractor_name_and_infallible(&self, term: TermId) -> (String, bool) { - let termdata = &self.termenv.terms[term.index()]; - match &termdata.kind { - &TermKind::EnumVariant { .. } => panic!("using enum variant as extractor"), - &TermKind::Regular { - extractor: Some((sym, infallible)), - .. - } => (format!("C::{}", self.typeenv.syms[sym.index()]), infallible), - &TermKind::Regular { - extractor: None, .. - } => ( - format!("extractor_{}", self.typeenv.syms[termdata.name.index()]), - false, - ), - } - } - - fn type_name(&self, typeid: TypeId, by_ref: Option<&str>) -> String { + fn type_name(&self, typeid: TypeId, by_ref: bool) -> String { match &self.typeenv.types[typeid.index()] { &Type::Primitive(_, sym) => self.typeenv.syms[sym.index()].clone(), &Type::Enum { name, .. 
} => { - let r = by_ref.unwrap_or(""); + let r = if by_ref { "&" } else { "" }; format!("{}{}", r, self.typeenv.syms[name.index()]) } } @@ -771,10 +670,24 @@ impl<'a> Codegen<'a> { } } + fn ty_prim(&self, ty: TypeId) -> bool { + self.typeenv.types[ty.index()].is_prim() + } + + fn value_binder(&self, value: &Value, is_ref: bool, ty: TypeId) -> String { + let prim = self.ty_prim(ty); + if prim || !is_ref { + format!("{}", self.value_name(value)) + } else { + format!("ref {}", self.value_name(value)) + } + } + fn value_by_ref(&self, value: &Value, ctx: &BodyContext) -> String { let raw_name = self.value_name(value); - let name_is_ref = ctx.borrowed_values.contains(value); - if name_is_ref { + let &(is_ref, ty) = ctx.values.get(value).unwrap(); + let prim = self.ty_prim(ty); + if is_ref || prim { raw_name } else { format!("&{}", raw_name) @@ -783,50 +696,41 @@ impl<'a> Codegen<'a> { fn value_by_val(&self, value: &Value, ctx: &BodyContext) -> String { let raw_name = self.value_name(value); - let name_is_ref = ctx.borrowed_values.contains(value); - if name_is_ref { + let &(is_ref, _) = ctx.values.get(value).unwrap(); + if is_ref { format!("{}.clone()", raw_name) } else { raw_name } } - fn define_val(&self, value: &Value, ctx: &mut BodyContext, is_ref: bool) { - if is_ref { - ctx.borrowed_values.insert(value.clone()); - } + fn define_val(&self, value: &Value, ctx: &mut BodyContext, is_ref: bool, ty: TypeId) { + let is_ref = !self.ty_prim(ty) && is_ref; + ctx.values.insert(value.clone(), (is_ref, ty)); } fn generate_internal_term_constructors(&self, code: &mut dyn Write) -> Result<(), Error> { - for (&termid, trie) in &self.functions_by_input { + for (&termid, trie) in &self.functions_by_term { let termdata = &self.termenv.terms[termid.index()]; // Skip terms that are enum variants or that have external // constructors/extractors. - match &termdata.kind { - &TermKind::EnumVariant { .. } => continue, - &TermKind::Regular { - constructor, - extractor, - .. - } if constructor.is_some() || extractor.is_some() => continue, - _ => {} + if !termdata.is_constructor() || termdata.is_external() { + continue; } - // Get the name of the term and build up the signature. - let func_name = self.constructor_name(termid); - let args = termdata + let sig = termdata.to_sig(self.typeenv).unwrap(); + + let args = sig .arg_tys .iter() .enumerate() - .map(|(i, &arg_ty)| { - format!( - "arg{}: {}", - i, - self.type_name(arg_ty, /* by_ref = */ Some("&")) - ) - }) - .collect::>(); + .map(|(i, &ty)| format!("arg{}: {}", i, self.type_name(ty, true))) + .collect::>() + .join(", "); + assert_eq!(sig.ret_tys.len(), 1); + let ret = self.type_name(sig.ret_tys[0], false); + writeln!( code, "\n// Generated as internal constructor for term {}.", @@ -835,13 +739,10 @@ impl<'a> Codegen<'a> { writeln!( code, "pub fn {}(ctx: &mut C, {}) -> Option<{}> {{", - func_name, - args.join(", "), - self.type_name(termdata.ret_ty, /* by_ref = */ None) + sig.func_name, args, ret, )?; let mut body_ctx: BodyContext = Default::default(); - body_ctx.expected_return_vals = 1; let returned = self.generate_body(code, /* depth = */ 0, trie, " ", &mut body_ctx)?; if !returned { @@ -854,69 +755,6 @@ impl<'a> Codegen<'a> { Ok(()) } - fn generate_internal_term_extractors(&self, code: &mut dyn Write) -> Result<(), Error> { - for (&termid, trie) in &self.functions_by_output { - let termdata = &self.termenv.terms[termid.index()]; - - // Skip terms that are enum variants or that have external extractors. - match &termdata.kind { - &TermKind::EnumVariant { .. 
} => continue, - &TermKind::Regular { - constructor, - extractor, - .. - } if constructor.is_some() || extractor.is_some() => continue, - _ => {} - } - - // Get the name of the term and build up the signature. - let (func_name, _) = self.extractor_name_and_infallible(termid); - let arg_is_prim = match &self.typeenv.types[termdata.ret_ty.index()] { - &Type::Primitive(..) => true, - _ => false, - }; - let arg = format!( - "arg0: {}", - self.type_name( - termdata.ret_ty, - /* by_ref = */ if arg_is_prim { None } else { Some("&") } - ), - ); - let ret_tuple_tys = termdata - .arg_tys - .iter() - .map(|ty| { - self.type_name(*ty, /* by_ref = */ None) - }) - .collect::>(); - - writeln!( - code, - "\n// Generated as internal extractor for term {}.", - self.typeenv.syms[termdata.name.index()], - )?; - writeln!( - code, - "pub fn {}(ctx: &mut C, {}) -> Option<({},)> {{", - func_name, - arg, - ret_tuple_tys.join(", "), - )?; - - let mut body_ctx: BodyContext = Default::default(); - body_ctx.expected_return_vals = ret_tuple_tys.len(); - body_ctx.tuple_return = true; - let returned = - self.generate_body(code, /* depth = */ 0, trie, " ", &mut body_ctx)?; - if !returned { - writeln!(code, " return None;")?; - } - writeln!(code, "}}")?; - } - - Ok(()) - } - fn generate_expr_inst( &self, code: &mut dyn Write, @@ -926,15 +764,16 @@ impl<'a> Codegen<'a> { ctx: &mut BodyContext, returns: &mut Vec<(usize, String)>, ) -> Result<(), Error> { + log::trace!("generate_expr_inst: {:?}", inst); match inst { &ExprInst::ConstInt { ty, val } => { let value = Value::Expr { inst: id, output: 0, }; + self.define_val(&value, ctx, /* is_ref = */ false, ty); let name = self.value_name(&value); - let ty = self.type_name(ty, /* by_ref = */ None); - self.define_val(&value, ctx, /* is_ref = */ false); + let ty = self.type_name(ty, /* by_ref = */ false); writeln!(code, "{}let {}: {} = {};", indent, name, ty, val)?; } &ExprInst::CreateVariant { @@ -960,7 +799,7 @@ impl<'a> Codegen<'a> { let outputname = self.value_name(&output); let full_variant_name = format!( "{}::{}", - self.type_name(ty, None), + self.type_name(ty, false), self.typeenv.syms[variantinfo.name.index()] ); if input_fields.is_empty() { @@ -980,7 +819,7 @@ impl<'a> Codegen<'a> { } writeln!(code, "{}}};", indent)?; } - self.define_val(&output, ctx, /* is_ref = */ false); + self.define_val(&output, ctx, /* is_ref = */ false, ty); } &ExprInst::Construct { ref inputs, term, .. @@ -996,16 +835,18 @@ impl<'a> Codegen<'a> { output: 0, }; let outputname = self.value_name(&output); - let ctor_name = self.constructor_name(term); + let termdata = &self.termenv.terms[term.index()]; + let sig = termdata.to_sig(self.typeenv).unwrap(); + assert_eq!(input_exprs.len(), sig.arg_tys.len()); writeln!( code, "{}let {} = {}(ctx, {});", indent, outputname, - ctor_name, + sig.full_name, input_exprs.join(", "), )?; - self.define_val(&output, ctx, /* is_ref = */ false); + self.define_val(&output, ctx, /* is_ref = */ false, termdata.ret_ty); } &ExprInst::Return { index, ref value, .. @@ -1029,23 +870,15 @@ impl<'a> Codegen<'a> { .iter() .zip(variant.fields.iter()) .enumerate() - .map(|(i, (ty, field))| { + .map(|(i, (&ty, field))| { let value = Value::Pattern { inst: id, output: i, }; - let valuename = self.value_name(&value); + let valuename = self.value_binder(&value, /* is_ref = */ true, ty); let fieldname = &self.typeenv.syms[field.name.index()]; - match &self.typeenv.types[ty.index()] { - &Type::Primitive(..) 
=> { - self.define_val(&value, ctx, /* is_ref = */ false); - format!("{}: {}", fieldname, valuename) - } - &Type::Enum { .. } => { - self.define_val(&value, ctx, /* is_ref = */ true); - format!("{}: ref {}", fieldname, valuename) - } - } + self.define_val(&value, ctx, /* is_ref = */ false, field.ty); + format!("{}: {}", fieldname, valuename) }) .collect::>() } @@ -1080,6 +913,7 @@ impl<'a> Codegen<'a> { }, ctx, is_ref, + ty, ); Ok(true) } @@ -1107,7 +941,7 @@ impl<'a> Codegen<'a> { &Type::Primitive(..) => panic!("primitive type input to MatchVariant"), &Type::Enum { ref variants, .. } => variants, }; - let ty_name = self.type_name(input_ty, /* is_ref = */ Some("&")); + let ty_name = self.type_name(input_ty, /* is_ref = */ true); let variant = &variants[variant.index()]; let variantname = &self.typeenv.syms[variant.name.index()]; let args = self.match_variant_binders(variant, &arg_tys[..], id, ctx); @@ -1124,57 +958,70 @@ impl<'a> Codegen<'a> { Ok(false) } &PatternInst::Extract { - ref input, - input_ty, - ref arg_tys, + ref inputs, + ref output_tys, term, .. } => { - let input_ty_prim = match &self.typeenv.types[input_ty.index()] { - &Type::Primitive(..) => true, - _ => false, - }; - let input = if input_ty_prim { - self.value_by_val(input, ctx) - } else { - self.value_by_ref(input, ctx) - }; - let (etor_name, infallible) = self.extractor_name_and_infallible(term); + let termdata = &self.termenv.terms[term.index()]; + let sig = termdata.to_sig(self.typeenv).unwrap(); - let args = arg_tys + let input_values = inputs + .iter() + .map(|input| self.value_by_ref(input, ctx)) + .collect::>(); + let output_binders = output_tys .iter() .enumerate() - .map(|(i, _ty)| { - let value = Value::Pattern { + .map(|(i, &ty)| { + let output_val = Value::Pattern { inst: id, output: i, }; - self.define_val(&value, ctx, /* is_ref = */ false); - self.value_name(&value) + self.define_val(&output_val, ctx, /* is_ref = */ false, ty); + self.value_binder(&output_val, /* is_ref = */ false, ty) }) .collect::>(); - if infallible { - writeln!( - code, - "{}let Some(({},)) = {}(ctx, {});", - indent, - args.join(", "), - etor_name, - input - )?; - writeln!(code, "{}{{", indent)?; - } else { - writeln!( - code, - "{}if let Some(({},)) = {}(ctx, {}) {{", - indent, - args.join(", "), - etor_name, - input - )?; + writeln!( + code, + "{}if let Some(({},)) = {}(ctx, {}) {{", + indent, + output_binders.join(", "), + sig.full_name, + input_values.join(", "), + )?; + + Ok(false) + } + &PatternInst::Expr { ref seq, output_ty, .. 
} => { + let closure_name = format!("closure{}", id.index()); + writeln!(code, "{}let {} = || {{", indent, closure_name)?; + let subindent = format!("{} ", indent); + let mut subctx = ctx.clone(); + let mut returns = vec![]; + for (id, inst) in seq.insts.iter().enumerate() { + let id = InstId(id); + self.generate_expr_inst(code, id, inst, &subindent, &mut subctx, &mut returns)?; } - Ok(infallible) + assert_eq!(returns.len(), 1); + writeln!(code, "{}return Some({});", subindent, returns[0].1)?; + writeln!(code, "{}}};", indent)?; + + let output = Value::Pattern { + inst: id, + output: 0, + }; + writeln!( + code, + "{}if let Some({}) = {}() {{", + indent, + self.value_binder(&output, /* is_ref = */ false, output_ty), + closure_name + )?; + self.define_val(&output, ctx, /* is_ref = */ false, output_ty); + + Ok(false) } } } @@ -1206,18 +1053,8 @@ impl<'a> Codegen<'a> { self.generate_expr_inst(code, id, inst, indent, ctx, &mut returns)?; } - assert_eq!(returns.len(), ctx.expected_return_vals); - returns.sort_by_key(|(index, _)| *index); - if ctx.tuple_return { - let return_values = returns - .into_iter() - .map(|(_, expr)| expr) - .collect::>() - .join(", "); - writeln!(code, "{}return Some(({},));", indent, return_values)?; - } else { - writeln!(code, "{}return Some({});", indent, returns[0].1)?; - } + assert_eq!(returns.len(), 1); + writeln!(code, "{}return Some({});", indent, returns[0].1)?; returned = true; } diff --git a/cranelift/isle/src/ir.rs b/cranelift/isle/src/ir.rs index 7951c1175e..19802b4f1e 100644 --- a/cranelift/isle/src/ir.rs +++ b/cranelift/isle/src/ir.rs @@ -41,14 +41,25 @@ pub enum PatternInst { variant: VariantId, }, - /// Invoke an extractor, taking the given value as input and - /// producing `|arg_tys|` values as output. + /// Invoke an extractor, taking the given values as input (the + /// first is the value to extract, the other are the + /// `Input`-polarity extractor args) and producing an output valu + /// efor each `Output`-polarity extractor arg. Extract { - input: Value, - input_ty: TypeId, - arg_tys: Vec, + inputs: Vec, + input_tys: Vec, + output_tys: Vec, term: TermId, }, + + /// Evaluate an expression and provide the given value as the + /// result of this match instruction. The expression has access to + /// the pattern-values up to this point in the sequence. + Expr { + seq: ExprSequence, + output: Value, + output_ty: TypeId, + }, } /// A single Expr instruction. @@ -110,7 +121,7 @@ pub struct PatternSequence { /// A linear sequence of instructions that produce a new value from /// the right-hand side of a rule, given bindings that come from a /// `Pattern` derived from the left-hand side. -#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] pub struct ExprSequence { /// Instruction sequence for expression. InstId indexes into this /// sequence for `Value::Expr` values. 
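To make the two new pattern instructions concrete: for a use like `(InstInput a <0)`, the `<0` argument is lowered to a `PatternInst::Expr` (evaluated via the closure wrapping flagged in the TODO at the top of this patch), and the extraction itself becomes a `PatternInst::Extract`. The generated match code would look roughly like this fragment (value names are illustrative; `inst` stands for the already-bound `Inst` term value):

    // PatternInst::Expr: evaluate the `<0` input argument as a closure.
    let closure1 = || {
        let v: u32 = 0;
        return Some(v);
    };
    if let Some(idx) = closure1() {
        // PatternInst::Extract: the term value plus input-polarity args go in;
        // output-polarity args come back in the Option tuple.
        if let Some((a,)) = C::get_inst_input(ctx, inst, idx) {
            // ... continue matching / build the right-hand side with `a` bound ...
        }
    }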
@@ -119,6 +130,21 @@ pub struct ExprSequence { pub pos: Pos, } +#[derive(Clone, Copy, Debug)] +enum ValueOrArgs { + Value(Value), + ImplicitTermFromArgs(TermId), +} + +impl ValueOrArgs { + fn to_value(&self) -> Option { + match self { + &ValueOrArgs::Value(v) => Some(v), + _ => None, + } + } +} + impl PatternSequence { fn add_inst(&mut self, inst: PatternInst) -> InstId { let id = InstId(self.insts.len()); @@ -165,104 +191,198 @@ impl PatternSequence { fn add_extract( &mut self, - input: Value, - input_ty: TypeId, - arg_tys: &[TypeId], + inputs: Vec, + input_tys: Vec, + output_tys: Vec, term: TermId, ) -> Vec { let inst = InstId(self.insts.len()); let mut outs = vec![]; - for (i, _arg_ty) in arg_tys.iter().enumerate() { + for i in 0..output_tys.len() { let val = Value::Pattern { inst, output: i }; outs.push(val); } - let arg_tys = arg_tys.iter().cloned().collect(); + let output_tys = output_tys.iter().cloned().collect(); self.add_inst(PatternInst::Extract { - input, - input_ty, - arg_tys, + inputs, + input_tys, + output_tys, term, }); outs } + fn add_expr_seq(&mut self, seq: ExprSequence, output: Value, output_ty: TypeId) -> Value { + let inst = self.add_inst(PatternInst::Expr { + seq, + output, + output_ty, + }); + + // Create values for all outputs. + Value::Pattern { inst, output: 0 } + } + /// Generate PatternInsts to match the given (sub)pattern. Works - /// recursively down the AST. Returns the root term matched by - /// this pattern, if any. + /// recursively down the AST. fn gen_pattern( &mut self, - // If `input` is `None`, then this is the root pattern, and is - // implicitly an extraction with the N args as results. - input: Option, + input: ValueOrArgs, typeenv: &TypeEnv, termenv: &TermEnv, pat: &Pattern, - vars: &mut HashMap, Value)>, - ) -> Option { + vars: &mut HashMap, + ) { match pat { &Pattern::BindPattern(_ty, var, ref subpat) => { // Bind the appropriate variable and recurse. assert!(!vars.contains_key(&var)); - vars.insert(var, (None, input.unwrap())); // bind first, so subpat can use it + if let Some(v) = input.to_value() { + vars.insert(var, v); + } let root_term = self.gen_pattern(input, typeenv, termenv, &*subpat, vars); - vars.get_mut(&var).unwrap().0 = root_term; root_term } &Pattern::Var(ty, var) => { // Assert that the value matches the existing bound var. - let (var_val_term, var_val) = vars + let var_val = vars .get(&var) .cloned() .expect("Variable should already be bound"); - self.add_match_equal(input.unwrap(), var_val, ty); - var_val_term + let input_val = input + .to_value() + .expect("Cannot match an =var pattern against root term"); + self.add_match_equal(input_val, var_val, ty); } &Pattern::ConstInt(ty, value) => { // Assert that the value matches the constant integer. - self.add_match_int(input.unwrap(), ty, value); - None - } - &Pattern::Term(_, term, ref args) if input.is_none() => { - let termdata = &termenv.terms[term.index()]; - let arg_tys = &termdata.arg_tys[..]; - for (i, subpat) in args.iter().enumerate() { - let value = self.add_arg(i, arg_tys[i]); - self.gen_pattern(Some(value), typeenv, termenv, subpat, vars); - } - Some(term) + let input_val = input + .to_value() + .expect("Cannot match an =var pattern against root term"); + self.add_match_int(input_val, ty, value); } &Pattern::Term(ty, term, ref args) => { - // Determine whether the term has an external extractor or not. 
- let termdata = &termenv.terms[term.index()]; - let arg_tys = &termdata.arg_tys[..]; - match &termdata.kind { - &TermKind::EnumVariant { variant } => { - let arg_values = - self.add_match_variant(input.unwrap(), ty, arg_tys, variant); - for (subpat, value) in args.iter().zip(arg_values.into_iter()) { - self.gen_pattern(Some(value), typeenv, termenv, subpat, vars); + match input { + ValueOrArgs::ImplicitTermFromArgs(termid) => { + assert_eq!( + termid, term, + "Cannot match a different term against root pattern" + ); + let termdata = &termenv.terms[term.index()]; + let arg_tys = &termdata.arg_tys[..]; + for (i, subpat) in args.iter().enumerate() { + let value = self.add_arg(i, arg_tys[i]); + let subpat = match subpat { + &TermArgPattern::Expr(..) => { + panic!("Should have been caught in typechecking") + } + &TermArgPattern::Pattern(ref pat) => pat, + }; + self.gen_pattern( + ValueOrArgs::Value(value), + typeenv, + termenv, + subpat, + vars, + ); } - None } - &TermKind::Regular { .. } => { - let arg_values = self.add_extract(input.unwrap(), ty, arg_tys, term); - for (subpat, value) in args.iter().zip(arg_values.into_iter()) { - self.gen_pattern(Some(value), typeenv, termenv, subpat, vars); + ValueOrArgs::Value(input) => { + // Determine whether the term has an external extractor or not. + let termdata = &termenv.terms[term.index()]; + let arg_tys = &termdata.arg_tys[..]; + match &termdata.kind { + &TermKind::Declared => { + panic!("Pattern invocation of undefined term body"); + } + &TermKind::EnumVariant { variant } => { + let arg_values = + self.add_match_variant(input, ty, arg_tys, variant); + for (subpat, value) in args.iter().zip(arg_values.into_iter()) { + let subpat = match subpat { + &TermArgPattern::Pattern(ref pat) => pat, + _ => unreachable!("Should have been caught by sema"), + }; + self.gen_pattern( + ValueOrArgs::Value(value), + typeenv, + termenv, + subpat, + vars, + ); + } + } + &TermKind::InternalConstructor + | &TermKind::ExternalConstructor { .. } => { + panic!("Should not invoke constructor in pattern"); + } + &TermKind::InternalExtractor { .. } => { + panic!("Should have been expanded away"); + } + &TermKind::ExternalExtractor { + ref arg_polarity, .. + } => { + // Evaluate all `input` args. + let mut inputs = vec![]; + let mut input_tys = vec![]; + let mut output_tys = vec![]; + let mut output_pats = vec![]; + inputs.push(input); + input_tys.push(termdata.ret_ty); + for (arg, pol) in args.iter().zip(arg_polarity.iter()) { + match pol { + &ArgPolarity::Input => { + let expr = match arg { + &TermArgPattern::Expr(ref expr) => expr, + _ => panic!( + "Should have been caught by typechecking" + ), + }; + let mut seq = ExprSequence::default(); + let value = seq.gen_expr(typeenv, termenv, expr, vars); + seq.add_return(expr.ty(), value); + let value = self.add_expr_seq(seq, value, expr.ty()); + inputs.push(value); + input_tys.push(expr.ty()); + } + &ArgPolarity::Output => { + let pat = match arg { + &TermArgPattern::Pattern(ref pat) => pat, + _ => panic!( + "Should have been caught by typechecking" + ), + }; + output_tys.push(pat.ty()); + output_pats.push(pat); + } + } + } + + // Invoke the extractor. 
+ let arg_values = + self.add_extract(inputs, input_tys, output_tys, term); + + for (pat, &val) in output_pats.iter().zip(arg_values.iter()) { + self.gen_pattern( + ValueOrArgs::Value(val), + typeenv, + termenv, + pat, + vars, + ); + } + } } - Some(term) } } } &Pattern::And(_ty, ref children) => { - let input = input.unwrap(); for child in children { - self.gen_pattern(Some(input), typeenv, termenv, child, vars); + self.gen_pattern(input, typeenv, termenv, child, vars); } - None } &Pattern::Wildcard(_ty) => { // Nothing! - None } } } @@ -319,63 +439,40 @@ impl ExprSequence { /// Creates a sequence of ExprInsts to generate the given /// expression value. Returns the value ID as well as the root /// term ID, if any. - /// - /// If `gen_final_construct` is false and the value is a - /// constructor call, this returns the arguments instead. This is - /// used when codegen'ing extractors for internal terms. fn gen_expr( &mut self, typeenv: &TypeEnv, termenv: &TermEnv, expr: &Expr, - vars: &HashMap, Value)>, - gen_final_construct: bool, - ) -> (Option, Vec) { - log::trace!( - "gen_expr: expr {:?} gen_final_construct {}", - expr, - gen_final_construct - ); + vars: &HashMap, + ) -> Value { + log::trace!("gen_expr: expr {:?}", expr); match expr { - &Expr::ConstInt(ty, val) => (None, vec![self.add_const_int(ty, val)]), + &Expr::ConstInt(ty, val) => self.add_const_int(ty, val), &Expr::Let(_ty, ref bindings, ref subexpr) => { let mut vars = vars.clone(); for &(var, _var_ty, ref var_expr) in bindings { - let (var_value_term, var_value) = - self.gen_expr(typeenv, termenv, &*var_expr, &vars, true); - let var_value = var_value[0]; - vars.insert(var, (var_value_term, var_value)); + let var_value = self.gen_expr(typeenv, termenv, &*var_expr, &vars); + vars.insert(var, var_value); } - self.gen_expr(typeenv, termenv, &*subexpr, &vars, gen_final_construct) - } - &Expr::Var(_ty, var_id) => { - let (root_term, value) = vars.get(&var_id).cloned().unwrap(); - (root_term, vec![value]) + self.gen_expr(typeenv, termenv, &*subexpr, &vars) } + &Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(), &Expr::Term(ty, term, ref arg_exprs) => { let termdata = &termenv.terms[term.index()]; let mut arg_values_tys = vec![]; - log::trace!("Term gen_expr term {}", term.index()); for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) { - log::trace!("generating for arg_expr {:?}", arg_expr); - arg_values_tys.push(( - self.gen_expr(typeenv, termenv, &*arg_expr, &vars, true).1[0], - arg_ty, - )); + arg_values_tys + .push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars), arg_ty)); } match &termdata.kind { - &TermKind::EnumVariant { variant } => ( - None, - vec![self.add_create_variant(&arg_values_tys[..], ty, variant)], - ), - &TermKind::Regular { .. } if !gen_final_construct => ( - Some(termdata.id), - arg_values_tys.into_iter().map(|(val, _ty)| val).collect(), - ), - &TermKind::Regular { .. } => ( - Some(termdata.id), - vec![self.add_construct(&arg_values_tys[..], ty, term)], - ), + &TermKind::EnumVariant { variant } => { + self.add_create_variant(&arg_values_tys[..], ty, variant) + } + &TermKind::InternalConstructor | &TermKind::ExternalConstructor { .. 
} => { + self.add_construct(&arg_values_tys[..], ty, term) + } + _ => panic!("Should have been caught by typechecking"), } } } @@ -387,114 +484,34 @@ pub fn lower_rule( tyenv: &TypeEnv, termenv: &TermEnv, rule: RuleId, - is_forward_dir: bool, -) -> Option<(PatternSequence, ExprSequence, TermId)> { +) -> (PatternSequence, ExprSequence) { let mut pattern_seq: PatternSequence = Default::default(); let mut expr_seq: ExprSequence = Default::default(); expr_seq.pos = termenv.rules[rule.index()].pos; - // Lower the pattern, starting from the root input value. let ruledata = &termenv.rules[rule.index()]; let mut vars = HashMap::new(); + let root_term = ruledata + .lhs + .root_term() + .expect("Pattern must have a term at the root"); - log::trace!( - "lower_rule: ruledata {:?} forward {}", - ruledata, - is_forward_dir + log::trace!("lower_rule: ruledata {:?}", ruledata,); + + // Lower the pattern, starting from the root input value. + pattern_seq.gen_pattern( + ValueOrArgs::ImplicitTermFromArgs(root_term), + tyenv, + termenv, + &ruledata.lhs, + &mut vars, ); - if is_forward_dir { - let can_do_forward = match &ruledata.lhs { - &Pattern::Term(..) => true, - _ => false, - }; - if !can_do_forward { - return None; - } - - let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars); - let root_term = match lhs_root_term { - Some(t) => t, - None => { - return None; - } - }; - - // Lower the expression, making use of the bound variables - // from the pattern. - let (_, rhs_root_vals) = expr_seq.gen_expr( - tyenv, - termenv, - &ruledata.rhs, - &vars, - /* final_construct = */ true, - ); - // Return the root RHS value. - let output_ty = ruledata.rhs.ty(); - assert_eq!(rhs_root_vals.len(), 1); - expr_seq.add_return(output_ty, rhs_root_vals[0]); - Some((pattern_seq, expr_seq, root_term)) - } else { - let can_reverse = match &ruledata.rhs { - &Expr::Term(..) => true, - _ => false, - }; - if !can_reverse { - return None; - } - - let arg = pattern_seq.add_arg(0, ruledata.lhs.ty()); - let _ = pattern_seq.gen_pattern(Some(arg), tyenv, termenv, &ruledata.lhs, &mut vars); - let (rhs_root_term, rhs_root_vals) = expr_seq.gen_expr( - tyenv, - termenv, - &ruledata.rhs, - &vars, - /* final_construct = */ false, - ); - - let root_term = match rhs_root_term { - Some(t) => t, - None => { - return None; - } - }; - let termdata = &termenv.terms[root_term.index()]; - for (i, (val, ty)) in rhs_root_vals - .into_iter() - .zip(termdata.arg_tys.iter()) - .enumerate() - { - expr_seq.add_multi_return(i, *ty, val); - } - - Some((pattern_seq, expr_seq, root_term)) - } -} - -/// Trim the final Construct and Return ops in an ExprSequence in -/// order to allow the extractor to be codegen'd. -pub fn trim_expr_for_extractor(mut expr: ExprSequence) -> ExprSequence { - let ret_inst = expr.insts.pop().unwrap(); - let retval = match ret_inst { - ExprInst::Return { value, .. } => value, - _ => panic!("Last instruction is not a return"), - }; - assert_eq!( - retval, - Value::Expr { - inst: InstId(expr.insts.len() - 1), - output: 0 - } - ); - let construct_inst = expr.insts.pop().unwrap(); - let inputs = match construct_inst { - ExprInst::Construct { inputs, .. } => inputs, - _ => panic!("Returned value is not a construct call"), - }; - for (i, (value, ty)) in inputs.into_iter().enumerate() { - expr.add_multi_return(i, ty, value); - } - - expr + // Lower the expression, making use of the bound variables + // from the pattern. 
+ let rhs_root_val = expr_seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars); + // Return the root RHS value. + let output_ty = ruledata.rhs.ty(); + expr_seq.add_return(output_ty, rhs_root_val); + (pattern_seq, expr_seq) } diff --git a/cranelift/isle/src/lexer.rs b/cranelift/isle/src/lexer.rs index 5c85f5a730..261dc9c910 100644 --- a/cranelift/isle/src/lexer.rs +++ b/cranelift/isle/src/lexer.rs @@ -18,7 +18,7 @@ enum LexerInput<'a> { File { content: String, filename: String }, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash, PartialOrd, Ord)] pub struct Pos { pub file: usize, pub offset: usize, @@ -41,6 +41,8 @@ pub enum Token { RParen, Symbol(String), Int(i64), + At, + Lt, } impl<'a> Lexer<'a> { @@ -133,7 +135,7 @@ impl<'a> Lexer<'a> { } fn is_sym_other_char(c: u8) -> bool { match c { - b'(' | b')' | b';' => false, + b'(' | b')' | b';' | b'@' | b'<' => false, c if c.is_ascii_whitespace() => false, _ => true, } @@ -168,6 +170,14 @@ impl<'a> Lexer<'a> { self.advance_pos(); Some((char_pos, Token::RParen)) } + b'@' => { + self.advance_pos(); + Some((char_pos, Token::At)) + } + b'<' => { + self.advance_pos(); + Some((char_pos, Token::Lt)) + } c if is_sym_first_char(c) => { let start = self.pos.offset; let start_pos = self.pos; diff --git a/cranelift/isle/src/parser.rs b/cranelift/isle/src/parser.rs index e5265adc76..7ec4fdc0ea 100644 --- a/cranelift/isle/src/parser.rs +++ b/cranelift/isle/src/parser.rs @@ -53,6 +53,12 @@ impl<'a> Parser<'a> { fn is_rparen(&self) -> bool { self.is(|tok| *tok == Token::RParen) } + fn is_at(&self) -> bool { + self.is(|tok| *tok == Token::At) + } + fn is_lt(&self) -> bool { + self.is(|tok| *tok == Token::Lt) + } fn is_sym(&self) -> bool { self.is(|tok| tok.is_sym()) } @@ -72,6 +78,12 @@ impl<'a> Parser<'a> { fn rparen(&mut self) -> ParseResult<()> { self.take(|tok| *tok == Token::RParen).map(|_| ()) } + fn at(&mut self) -> ParseResult<()> { + self.take(|tok| *tok == Token::At).map(|_| ()) + } + fn lt(&mut self) -> ParseResult<()> { + self.take(|tok| *tok == Token::Lt).map(|_| ()) + } fn symbol(&mut self) -> ParseResult { match self.take(|tok| tok.is_sym())? { @@ -103,10 +115,10 @@ impl<'a> Parser<'a> { let pos = self.pos(); let def = match &self.symbol()?[..] 
{ "type" => Def::Type(self.parse_type()?), - "rule" => Def::Rule(self.parse_rule()?), "decl" => Def::Decl(self.parse_decl()?), - "constructor" => Def::Extern(self.parse_ctor()?), - "extractor" => Def::Extern(self.parse_etor()?), + "rule" => Def::Rule(self.parse_rule()?), + "extractor" => Def::Extractor(self.parse_etor()?), + "extern" => Def::Extern(self.parse_extern()?), s => { return Err(self.error(pos.unwrap(), format!("Unexpected identifier: {}", s))); } @@ -231,32 +243,72 @@ impl<'a> Parser<'a> { }) } - fn parse_ctor(&mut self) -> ParseResult { + fn parse_extern(&mut self) -> ParseResult { let pos = self.pos(); - let term = self.parse_ident()?; - let func = self.parse_ident()?; - Ok(Extern::Constructor { - term, - func, - pos: pos.unwrap(), - }) + if self.is_sym_str("constructor") { + self.symbol()?; + let term = self.parse_ident()?; + let func = self.parse_ident()?; + Ok(Extern::Constructor { + term, + func, + pos: pos.unwrap(), + }) + } else if self.is_sym_str("extractor") { + self.symbol()?; + let term = self.parse_ident()?; + let func = self.parse_ident()?; + let arg_polarity = if self.is_lparen() { + let mut pol = vec![]; + self.lparen()?; + while !self.is_rparen() { + if self.is_sym_str("in") { + self.symbol()?; + pol.push(ArgPolarity::Input); + } else if self.is_sym_str("out") { + self.symbol()?; + pol.push(ArgPolarity::Output); + } else { + return Err( + self.error(pos.unwrap(), "Invalid argument polarity".to_string()) + ); + } + } + self.rparen()?; + Some(pol) + } else { + None + }; + Ok(Extern::Extractor { + term, + func, + pos: pos.unwrap(), + arg_polarity, + }) + } else { + Err(self.error( + pos.unwrap(), + "Invalid extern: must be (extern constructor ...) or (extern extractor ...)" + .to_string(), + )) + } } - fn parse_etor(&mut self) -> ParseResult { + fn parse_etor(&mut self) -> ParseResult { let pos = self.pos(); - let infallible = if self.is_sym_str("infallible") { - self.symbol()?; - true - } else { - false - }; + self.lparen()?; let term = self.parse_ident()?; - let func = self.parse_ident()?; - Ok(Extern::Extractor { + let mut args = vec![]; + while !self.is_rparen() { + args.push(self.parse_ident()?); + } + self.rparen()?; + let template = self.parse_pattern()?; + Ok(Extractor { term, - func, + args, + template, pos: pos.unwrap(), - infallible, }) } @@ -292,8 +344,8 @@ impl<'a> Parser<'a> { Ok(Pattern::Var { var }) } else { let var = self.str_to_ident(pos.unwrap(), &s)?; - if self.is_sym_str("@") { - self.symbol()?; + if self.is_at() { + self.at()?; let subpat = Box::new(self.parse_pattern()?); Ok(Pattern::BindPattern { var, subpat }) } else { @@ -317,7 +369,7 @@ impl<'a> Parser<'a> { let sym = self.parse_ident()?; let mut args = vec![]; while !self.is_rparen() { - args.push(self.parse_pattern()?); + args.push(self.parse_pattern_term_arg()?); } self.rparen()?; Ok(Pattern::Term { sym, args }) @@ -327,6 +379,15 @@ impl<'a> Parser<'a> { } } + fn parse_pattern_term_arg(&mut self) -> ParseResult { + if self.is_lt() { + self.lt()?; + Ok(TermArgPattern::Expr(self.parse_expr()?)) + } else { + Ok(TermArgPattern::Pattern(self.parse_pattern()?)) + } + } + fn parse_expr(&mut self) -> ParseResult { let pos = self.pos(); if self.is_lparen() { diff --git a/cranelift/isle/src/sema.rs b/cranelift/isle/src/sema.rs index f3d9049180..72aa779c69 100644 --- a/cranelift/isle/src/sema.rs +++ b/cranelift/isle/src/sema.rs @@ -55,6 +55,13 @@ impl Type { Self::Primitive(_, name) | Self::Enum { name, .. 
} => &tyenv.syms[name.index()], } } + + pub fn is_prim(&self) -> bool { + match self { + &Type::Primitive(..) => true, + _ => false, + } + } } #[derive(Clone, Debug, PartialEq, Eq)] @@ -96,15 +103,120 @@ pub enum TermKind { /// `A1`. variant: VariantId, }, - Regular { - // Producer and consumer rules are catalogued separately after - // building Sequences. Here we just record whether an - // extractor and/or constructor is known. - /// Extractor func and `infallible` flag. - extractor: Option<(Sym, bool)>, - /// Constructor func. - constructor: Option, + /// A term with "internal" rules that work in the forward + /// direction. Becomes a compiled Rust function in the generated + /// code. + InternalConstructor, + /// A term that defines an "extractor macro" in the LHS of a + /// pattern. Its arguments take patterns and are simply + /// substituted with the given patterns when used. + InternalExtractor { + args: Vec, + template: ast::Pattern, }, + /// A term defined solely by an external extractor function. + ExternalExtractor { + /// Extractor func. + name: Sym, + /// Which arguments of the extractor are inputs and which are outputs? + arg_polarity: Vec, + }, + /// A term defined solely by an external constructor function. + ExternalConstructor { + /// Constructor func. + name: Sym, + }, + /// Declared but no body or externs associated (yet). + Declared, +} + +pub use crate::ast::ArgPolarity; + +#[derive(Clone, Debug)] +pub struct ExternalSig { + pub func_name: String, + pub full_name: String, + pub arg_tys: Vec, + pub ret_tys: Vec, +} + +impl Term { + pub fn ty(&self) -> TypeId { + self.ret_ty + } + + pub fn to_variant(&self) -> Option { + match &self.kind { + &TermKind::EnumVariant { variant } => Some(variant), + _ => None, + } + } + + pub fn is_constructor(&self) -> bool { + match &self.kind { + &TermKind::InternalConstructor { .. } | &TermKind::ExternalConstructor { .. } => true, + _ => false, + } + } + + pub fn is_extractor(&self) -> bool { + match &self.kind { + &TermKind::InternalExtractor { .. } | &TermKind::ExternalExtractor { .. } => true, + _ => false, + } + } + + pub fn is_external(&self) -> bool { + match &self.kind { + &TermKind::ExternalExtractor { .. } | &TermKind::ExternalConstructor { .. } => true, + _ => false, + } + } + + pub fn to_sig(&self, tyenv: &TypeEnv) -> Option { + match &self.kind { + &TermKind::ExternalConstructor { name } => Some(ExternalSig { + func_name: tyenv.syms[name.index()].clone(), + full_name: format!("C::{}", tyenv.syms[name.index()]), + arg_tys: self.arg_tys.clone(), + ret_tys: vec![self.ret_ty], + }), + &TermKind::ExternalExtractor { + name, + ref arg_polarity, + } => { + let mut arg_tys = vec![]; + let mut ret_tys = vec![]; + arg_tys.push(self.ret_ty); + for (&arg, polarity) in self.arg_tys.iter().zip(arg_polarity.iter()) { + match polarity { + &ArgPolarity::Input => { + arg_tys.push(arg); + } + &ArgPolarity::Output => { + ret_tys.push(arg); + } + } + } + Some(ExternalSig { + func_name: tyenv.syms[name.index()].clone(), + full_name: format!("C::{}", tyenv.syms[name.index()]), + arg_tys, + ret_tys, + }) + } + &TermKind::InternalConstructor { .. 
} => { + let name = format!("constructor_{}", tyenv.syms[self.name.index()]); + Some(ExternalSig { + func_name: name.clone(), + full_name: name, + arg_tys: self.arg_tys.clone(), + ret_tys: vec![self.ret_ty], + }) + } + _ => None, + } + } } #[derive(Clone, Debug)] @@ -121,11 +233,17 @@ pub enum Pattern { BindPattern(TypeId, VarId, Box), Var(TypeId, VarId), ConstInt(TypeId, i64), - Term(TypeId, TermId, Vec), + Term(TypeId, TermId, Vec), Wildcard(TypeId), And(TypeId, Vec), } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TermArgPattern { + Pattern(Pattern), + Expr(Expr), +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Expr { Term(TypeId, TermId, Vec), @@ -145,6 +263,14 @@ impl Pattern { &Self::And(t, ..) => t, } } + + pub fn root_term(&self) -> Option { + match self { + &Pattern::Term(_, term, _) => Some(term), + &Pattern::BindPattern(_, _, ref subpat) => subpat.root_term(), + _ => None, + } + } } impl Expr { @@ -295,13 +421,13 @@ impl TypeEnv { } } -#[derive(Clone)] +#[derive(Clone, Debug)] struct Bindings { next_var: usize, vars: Vec, } -#[derive(Clone)] +#[derive(Clone, Debug)] struct BoundVar { name: Sym, id: VarId, @@ -318,6 +444,8 @@ impl TermEnv { env.collect_term_sigs(tyenv, defs)?; env.collect_enum_variant_terms(tyenv)?; + env.collect_constructors(tyenv, defs)?; + env.collect_extractor_templates(tyenv, defs)?; env.collect_rules(tyenv, defs)?; Ok(env) @@ -361,10 +489,7 @@ impl TermEnv { name, arg_tys, ret_ty, - kind: TermKind::Regular { - extractor: None, - constructor: None, - }, + kind: TermKind::Declared, }); } _ => {} @@ -415,6 +540,87 @@ impl TermEnv { Ok(()) } + fn collect_constructors(&mut self, tyenv: &mut TypeEnv, defs: &ast::Defs) -> SemaResult<()> { + for def in &defs.defs { + match def { + &ast::Def::Rule(ref rule) => { + let pos = rule.pos; + let term = match rule.pattern.root_term() { + Some(t) => t, + None => { + return Err(tyenv.error( + pos, + "Rule does not have a term at the LHS root".to_string(), + )); + } + }; + let sym = tyenv.intern_mut(&term); + let term = match self.term_map.get(&sym) { + Some(&tid) => tid, + None => { + return Err( + tyenv.error(pos, "Rule LHS root term is not defined".to_string()) + ); + } + }; + let termdata = &mut self.terms[term.index()]; + match &termdata.kind { + &TermKind::Declared => { + termdata.kind = TermKind::InternalConstructor; + } + &TermKind::InternalConstructor => { + // OK, no error; multiple rules can apply to one internal constructor term. 
+ } + _ => { + return Err(tyenv.error(pos, "Rule LHS root term is incorrect kind; cannot be internal constructor".to_string())); + } + } + } + _ => {} + } + } + Ok(()) + } + + fn collect_extractor_templates( + &mut self, + tyenv: &mut TypeEnv, + defs: &ast::Defs, + ) -> SemaResult<()> { + for def in &defs.defs { + match def { + &ast::Def::Extractor(ref ext) => { + let sym = tyenv.intern_mut(&ext.term); + let term = self.term_map.get(&sym).ok_or_else(|| { + tyenv.error( + ext.pos, + "Extractor macro body definition on a non-existent term".to_string(), + ) + })?; + let termdata = &mut self.terms[term.index()]; + match &termdata.kind { + &TermKind::Declared => { + termdata.kind = TermKind::InternalExtractor { + args: ext.args.clone(), + template: ext.template.clone(), + }; + } + _ => { + return Err(tyenv.error( + ext.pos, + "Extractor macro body defined on term of incorrect kind" + .to_string(), + )); + } + } + } + _ => {} + } + } + + Ok(()) + } + fn collect_rules(&mut self, tyenv: &mut TypeEnv, defs: &ast::Defs) -> SemaResult<()> { for def in &defs.defs { match def { @@ -431,9 +637,11 @@ impl TermEnv { &rule.pattern, None, &mut bindings, + None, )?; let rhs = self.translate_expr(tyenv, rule.pos, &rule.expr, ty, &mut bindings)?; + let rid = RuleId(self.rules.len()); self.rules.push(Rule { id: rid, @@ -459,35 +667,27 @@ impl TermEnv { )) } }; - match &mut self.terms[term_id.index()].kind { - &mut TermKind::EnumVariant { .. } => { + let termdata = &mut self.terms[term_id.index()]; + match &termdata.kind { + &TermKind::Declared => { + termdata.kind = TermKind::ExternalConstructor { name: func_sym }; + } + _ => { return Err(tyenv.error( pos, - format!("Constructor defined on enum type '{}'", term.0), + format!( + "Constructor defined on term of improper type '{}'", + term.0 + ), )); } - &mut TermKind::Regular { - ref mut constructor, - .. - } => { - if constructor.is_some() { - return Err(tyenv.error( - pos, - format!( - "Constructor defined more than once on term '{}'", - term.0 - ), - )); - } - *constructor = Some(func_sym); - } } } &ast::Def::Extern(ast::Extern::Extractor { ref term, ref func, pos, - infallible, + ref arg_polarity, }) => { let term_sym = tyenv.intern_mut(term); let func_sym = tyenv.intern_mut(func); @@ -500,27 +700,31 @@ impl TermEnv { )) } }; - match &mut self.terms[term_id.index()].kind { - &mut TermKind::EnumVariant { .. } => { + + let termdata = &mut self.terms[term_id.index()]; + + let arg_polarity = if let Some(pol) = arg_polarity.as_ref() { + if pol.len() != termdata.arg_tys.len() { + return Err(tyenv.error(pos, "Incorrect number of argument-polarity directions in extractor definition".to_string())); + } + pol.clone() + } else { + vec![ArgPolarity::Output; termdata.arg_tys.len()] + }; + + match &termdata.kind { + &TermKind::Declared => { + termdata.kind = TermKind::ExternalExtractor { + name: func_sym, + arg_polarity, + }; + } + _ => { return Err(tyenv.error( pos, - format!("Extractor defined on enum type '{}'", term.0), + format!("Extractor defined on term of improper type '{}'", term.0), )); } - &mut TermKind::Regular { - ref mut extractor, .. 
- } => { - if extractor.is_some() { - return Err(tyenv.error( - pos, - format!( - "Extractor defined more than once on term '{}'", - term.0 - ), - )); - } - *extractor = Some((func_sym, infallible)); - } } } _ => {} @@ -537,7 +741,10 @@ impl TermEnv { pat: &ast::Pattern, expected_ty: Option, bindings: &mut Bindings, + macro_args: Option<&HashMap>, ) -> SemaResult<(Pattern, TypeId)> { + log::trace!("translate_pattern: {:?}", pat); + log::trace!("translate_pattern: bindings = {:?}", bindings); match pat { // TODO: flag on primitive type decl indicating it's an integer type? &ast::Pattern::ConstInt { val } => { @@ -556,8 +763,14 @@ impl TermEnv { let mut expected_ty = expected_ty; let mut children = vec![]; for subpat in subpats { - let (subpat, ty) = - self.translate_pattern(tyenv, pos, &*subpat, expected_ty, bindings)?; + let (subpat, ty) = self.translate_pattern( + tyenv, + pos, + &*subpat, + expected_ty, + bindings, + macro_args, + )?; expected_ty = expected_ty.or(Some(ty)); children.push(subpat); } @@ -571,9 +784,29 @@ impl TermEnv { ref var, ref subpat, } => { + // Handle macro-arg substitution. + if macro_args.is_some() && &**subpat == &ast::Pattern::Wildcard { + if let Some(macro_ast) = macro_args.as_ref().unwrap().get(var) { + return self.translate_pattern( + tyenv, + pos, + macro_ast, + expected_ty, + bindings, + macro_args, + ); + } + } + // Do the subpattern first so we can resolve the type for sure. - let (subpat, ty) = - self.translate_pattern(tyenv, pos, &*subpat, expected_ty, bindings)?; + let (subpat, ty) = self.translate_pattern( + tyenv, + pos, + &*subpat, + expected_ty, + bindings, + macro_args, + )?; let name = tyenv.intern_mut(var); if bindings.vars.iter().any(|bv| bv.name == name) { @@ -644,12 +877,85 @@ impl TermEnv { )); } + let termdata = &self.terms[tid.index()]; + + match &termdata.kind { + &TermKind::EnumVariant { .. } => { + for arg in args { + if let &ast::TermArgPattern::Expr(..) = arg { + return Err(tyenv.error(pos, format!("Term in pattern '{}' cannot have an injected expr, because it is an enum variant", sym.0))); + } + } + } + &TermKind::ExternalExtractor { + ref arg_polarity, .. + } => { + for (arg, pol) in args.iter().zip(arg_polarity.iter()) { + match (arg, pol) { + (&ast::TermArgPattern::Expr(..), &ArgPolarity::Input) => {} + (&ast::TermArgPattern::Expr(..), &ArgPolarity::Output) => { + return Err(tyenv.error( + pos, + "Expression used for output-polarity extractor arg" + .to_string(), + )); + } + (_, &ArgPolarity::Output) => {} + (_, &ArgPolarity::Input) => { + return Err(tyenv.error(pos, "Non-expression used in pattern but expression required for input-polarity extractor arg".to_string())); + } + } + } + } + &TermKind::InternalExtractor { + args: ref template_args, + ref template, + } => { + // Expand the extractor macro! We create a map + // from macro args to AST pattern trees and + // then evaluate the template with these + // substitutions. 
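                        // For instance, given the macro defined in
                        // isle_examples/test.isle:
                        //     (extractor (Extractor x) (A.A2 x))
                        // a use such as (Lower (Extractor b)) is rewritten
                        // here into (Lower (A.A2 b)) and then translated as
                        // an ordinary pattern.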
+ let mut arg_map = HashMap::new(); + for (template_arg, sub_ast) in template_args.iter().zip(args.iter()) { + let sub_ast = match sub_ast { + &ast::TermArgPattern::Pattern(ref pat) => pat.clone(), + &ast::TermArgPattern::Expr(_) => { + return Err(tyenv.error(pos, "Cannot expand an extractor macro with an expression in a macro argument".to_string())); + } + }; + arg_map.insert(template_arg.clone(), sub_ast.clone()); + } + log::trace!("internal extractor map = {:?}", arg_map); + return self.translate_pattern( + tyenv, + pos, + template, + expected_ty, + bindings, + Some(&arg_map), + ); + } + &TermKind::ExternalConstructor { .. } | &TermKind::InternalConstructor => { + // OK. + } + &TermKind::Declared => { + return Err(tyenv + .error(pos, format!("Declared but undefined term '{}' used", sym.0))); + } + } + // Resolve subpatterns. let mut subpats = vec![]; for (i, arg) in args.iter().enumerate() { let arg_ty = self.terms[tid.index()].arg_tys[i]; - let (subpat, _) = - self.translate_pattern(tyenv, pos, arg, Some(arg_ty), bindings)?; + let (subpat, _) = self.translate_pattern_term_arg( + tyenv, + pos, + arg, + Some(arg_ty), + bindings, + macro_args, + )?; subpats.push(subpat); } @@ -658,6 +964,35 @@ impl TermEnv { } } + fn translate_pattern_term_arg( + &self, + tyenv: &mut TypeEnv, + pos: Pos, + pat: &ast::TermArgPattern, + expected_ty: Option, + bindings: &mut Bindings, + macro_args: Option<&HashMap>, + ) -> SemaResult<(TermArgPattern, TypeId)> { + match pat { + &ast::TermArgPattern::Pattern(ref pat) => { + let (subpat, ty) = + self.translate_pattern(tyenv, pos, pat, expected_ty, bindings, macro_args)?; + Ok((TermArgPattern::Pattern(subpat), ty)) + } + &ast::TermArgPattern::Expr(ref expr) => { + if expected_ty.is_none() { + return Err(tyenv.error( + pos, + "Expression in pattern must have expected type".to_string(), + )); + } + let ty = expected_ty.unwrap(); + let expr = self.translate_expr(tyenv, pos, expr, expected_ty.unwrap(), bindings)?; + Ok((TermArgPattern::Expr(expr), ty)) + } + } + } + fn translate_expr( &self, tyenv: &mut TypeEnv, @@ -867,33 +1202,4 @@ mod test { ] ); } - - #[test] - fn build_rules() { - let text = r" - (type u32 (primitive u32)) - (type A extern (enum (B (f1 u32) (f2 u32)) (C (f1 u32)))) - - (decl T1 (A) u32) - (decl T2 (A A) A) - (decl T3 (u32) A) - - (constructor T1 t1_ctor) - (extractor T2 t2_etor) - - (rule - (T1 _) 1) - (rule - (T2 x =x) (T3 42)) - (rule - (T3 1) (A.C 2)) - (rule -1 - (T3 _) (A.C 3)) - "; - let ast = Parser::new(Lexer::from_str(text, "file.isle")) - .parse_defs() - .expect("should parse"); - let mut tyenv = TypeEnv::from_ast(&ast).expect("should not have type-definition errors"); - let _ = TermEnv::from_ast(&mut tyenv, &ast).expect("could not typecheck rules"); - } }
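With the per-kind `TermKind` variants and the single `builders_by_term` map, every term that appears at the LHS root of some rule gets exactly one generated Rust function merging all of its rules. For `Lower` in isle_examples/test3.isle that function would be shaped roughly as below; this is a sketch only (the `<C: Context>` bound and the trie body follow the `generate_internal_term_constructors`/`generate_body` logic above, and the real body contains the merged match arms rather than a bare `return None`).

    // Generated as internal constructor for term Lower.
    pub fn constructor_Lower<C: Context>(ctx: &mut C, arg0: Inst) -> Option<MachInst> {
        // Decision-trie body produced by generate_body(): the merged
        // (rule (Lower ...) ...) cases, tried in priority order, each
        // successful arm ending in `return Some(...)`.
        return None;
    }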