From e5d76db97a11ba580df0b93984695a1368f1e10a Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Sat, 4 Sep 2021 17:01:56 -0700 Subject: [PATCH] WIP. --- cranelift/isle/src/codegen.rs | 391 ++++++++++++++++++++++++++++++++-- cranelift/isle/src/ir.rs | 44 ++-- cranelift/isle/src/sema.rs | 14 +- 3 files changed, 417 insertions(+), 32 deletions(-) diff --git a/cranelift/isle/src/codegen.rs b/cranelift/isle/src/codegen.rs index e97ef8b6ce..55f530f9c3 100644 --- a/cranelift/isle/src/codegen.rs +++ b/cranelift/isle/src/codegen.rs @@ -1,9 +1,9 @@ //! Generate Rust code from a series of Sequences. use crate::error::Error; -use crate::ir::{lower_rule, ExprSequence, PatternInst, PatternSequence}; -use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; -use std::collections::HashMap; +use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; +use crate::sema::{RuleId, Term, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; +use std::collections::{HashMap, HashSet}; use std::fmt::Write; /// One "input symbol" for the decision tree that handles matching on @@ -495,6 +495,11 @@ pub struct Codegen<'a> { functions_by_output: HashMap, } +#[derive(Clone, Debug, Default)] +struct BodyContext { + borrowed_values: HashSet, +} + impl<'a> Codegen<'a> { pub fn compile(typeenv: &'a TypeEnv, termenv: &'a TermEnv) -> Result, Error> { let mut builder = TermFunctionsBuilder::new(typeenv, termenv); @@ -591,19 +596,70 @@ impl<'a> Codegen<'a> { } } - fn type_name(&self, typeid: TypeId, by_ref: bool) -> String { + fn extractor_name_and_infallible(&self, term: TermId) -> (String, bool) { + let termdata = &self.termenv.terms[term.index()]; + match &termdata.kind { + &TermKind::EnumVariant { .. } => panic!("using enum variant as extractor"), + &TermKind::Regular { + extractor: Some((sym, infallible)), + .. + } => (self.typeenv.syms[sym.index()].clone(), infallible), + &TermKind::Regular { + extractor: None, .. + } => ( + format!("extractor_{}", self.typeenv.syms[termdata.name.index()]), + false, + ), + } + } + + fn type_name(&self, typeid: TypeId, by_ref: Option<&str>) -> String { match &self.typeenv.types[typeid.index()] { &Type::Primitive(_, sym) => self.typeenv.syms[sym.index()].clone(), &Type::Enum { name, .. } => { - let r = if by_ref { "&" } else { "" }; + let r = by_ref.unwrap_or(""); format!("{}{}", r, self.typeenv.syms[name.index()]) } } } + fn value_name(&self, value: &Value) -> String { + match value { + &Value::Pattern { inst, output } => format!("pattern{}_{}", inst.index(), output), + &Value::Expr { inst, output } => format!("expr{}_{}", inst.index(), output), + } + } + + fn value_by_ref(&self, value: &Value, ctx: &BodyContext) -> String { + let raw_name = self.value_name(value); + let name_is_ref = ctx.borrowed_values.contains(value); + if name_is_ref { + raw_name + } else { + format!("&{}", raw_name) + } + } + + fn value_by_val(&self, value: &Value, ctx: &BodyContext) -> String { + let raw_name = self.value_name(value); + let name_is_ref = ctx.borrowed_values.contains(value); + if name_is_ref { + format!("{}.clone()", raw_name) + } else { + raw_name + } + } + + fn define_val(&self, value: &Value, ctx: &mut BodyContext, is_ref: bool) { + if is_ref { + ctx.borrowed_values.insert(value.clone()); + } + } + fn generate_internal_term_constructors(&self, code: &mut dyn Write) -> Result<(), Error> { for (&termid, trie) in &self.functions_by_input { let termdata = &self.termenv.terms[termid.index()]; + // Skip terms that are enum variants or that have external constructors. match &termdata.kind { &TermKind::EnumVariant { .. } => continue, @@ -618,7 +674,11 @@ impl<'a> Codegen<'a> { .iter() .enumerate() .map(|(i, &arg_ty)| { - format!("arg{}: {}", i, self.type_name(arg_ty, /* by_ref = */ true)) + format!( + "arg{}: {}", + i, + self.type_name(arg_ty, /* by_ref = */ Some("&")) + ) }) .collect::>(); writeln!( @@ -631,10 +691,11 @@ impl<'a> Codegen<'a> { "fn {}(ctx: &mut C, {}) -> Option<{}> {{", func_name, args.join(", "), - self.type_name(termdata.ret_ty, /* by_ref = */ false) + self.type_name(termdata.ret_ty, /* by_ref = */ None) )?; - self.generate_body(code, termid, trie)?; + let mut body_ctx = Default::default(); + self.generate_body(code, /* depth = */ 0, trie, " ", &mut body_ctx)?; writeln!(code, "}}")?; } @@ -642,16 +703,322 @@ impl<'a> Codegen<'a> { Ok(()) } - fn generate_internal_term_extractors(&self, _code: &mut dyn Write) -> Result<(), Error> { + fn generate_internal_term_extractors(&self, code: &mut dyn Write) -> Result<(), Error> { + for (&termid, trie) in &self.functions_by_output { + let termdata = &self.termenv.terms[termid.index()]; + + // Skip terms that are enum variants or that have external extractors. + match &termdata.kind { + &TermKind::EnumVariant { .. } => continue, + &TermKind::Regular { extractor, .. } if extractor.is_some() => continue, + _ => {} + } + + // Get the name of the term and build up the signature. + let (func_name, _) = self.extractor_name_and_infallible(termid); + let arg = format!( + "arg: {}", + self.type_name(termdata.ret_ty, /* by_ref = */ Some("&")) + ); + let ret_tuple_tys = termdata + .arg_tys + .iter() + .map(|ty| { + self.type_name(*ty, /* by_ref = */ None) + }) + .collect::>(); + + writeln!( + code, + "\n// Generated as internal extractor for term {}.", + self.typeenv.syms[termdata.name.index()], + )?; + writeln!( + code, + "fn {}<'a, C>(ctx: &mut C, {}) -> Option<({})> {{", + func_name, + arg, + ret_tuple_tys.join(", "), + )?; + + let mut body_ctx = Default::default(); + self.generate_extractor_header(code, termdata, &mut body_ctx)?; + self.generate_body(code, /* depth = */ 0, trie, " ", &mut body_ctx)?; + writeln!(code, " }}")?; + writeln!(code, "}}")?; + } + + Ok(()) + } + + fn generate_extractor_header( + &self, + code: &mut dyn Write, + termdata: &Term, + ctx: &mut BodyContext, + ) -> Result<(), Error> { + writeln!(code, " {{")?; + todo!(); + Ok(()) + } + + fn generate_expr_inst( + &self, + code: &mut dyn Write, + id: InstId, + inst: &ExprInst, + indent: &str, + ctx: &mut BodyContext, + ) -> Result<(), Error> { + match inst { + &ExprInst::ConstInt { ty, val } => { + let value = Value::Expr { + inst: id, + output: 0, + }; + let name = self.value_name(&value); + let ty = self.type_name(ty, /* by_ref = */ None); + self.define_val(&value, ctx, /* is_ref = */ false); + writeln!(code, "{}let {}: {} = {};", indent, name, ty, val)?; + } + &ExprInst::CreateVariant { + ref inputs, + ty, + variant, + } => { + let variantinfo = match &self.typeenv.types[ty.index()] { + &Type::Primitive(..) => panic!("CreateVariant with primitive type"), + &Type::Enum { ref variants, .. } => &variants[variant.index()], + }; + let mut input_fields = vec![]; + for ((input_value, _), field) in inputs.iter().zip(variantinfo.fields.iter()) { + let field_name = &self.typeenv.syms[field.name.index()]; + let value_expr = self.value_by_val(input_value, ctx); + input_fields.push(format!("{}: {}", field_name, value_expr)); + } + + let output = Value::Expr { + inst: id, + output: 0, + }; + let outputname = self.value_name(&output); + let full_variant_name = format!( + "{}::{}", + self.type_name(ty, None), + self.typeenv.syms[variantinfo.name.index()] + ); + writeln!( + code, + "{}let {} = {} {{", + indent, outputname, full_variant_name + )?; + for input_field in input_fields { + writeln!(code, "{} {},", indent, input_field)?; + } + writeln!(code, "{}}};", indent)?; + self.define_val(&output, ctx, /* is_ref = */ false); + } + &ExprInst::Construct { + ref inputs, term, .. + } => { + let mut input_exprs = vec![]; + for (input_value, _) in inputs { + let value_expr = self.value_by_val(input_value, ctx); + input_exprs.push(value_expr); + } + + let output = Value::Expr { + inst: id, + output: 0, + }; + let outputname = self.value_name(&output); + let ctor_name = self.constructor_name(term); + writeln!( + code, + "{}let {} = {}(ctx, {});", + indent, + outputname, + ctor_name, + input_exprs.join(", "), + )?; + self.define_val(&output, ctx, /* is_ref = */ false); + } + &ExprInst::Return { ref value, .. } => { + let value_expr = self.value_by_val(value, ctx); + writeln!(code, "{}return Some({});", indent, value_expr)?; + } + } + + Ok(()) + } + + fn generate_pattern_inst( + &self, + code: &mut dyn Write, + id: InstId, + inst: &PatternInst, + indent: &str, + ctx: &mut BodyContext, + ) -> Result<(), Error> { + match inst { + &PatternInst::Arg { index, .. } => { + let output = Value::Expr { + inst: id, + output: 0, + }; + let outputname = self.value_name(&output); + writeln!(code, "{}let {} = arg{};", indent, outputname, index)?; + writeln!(code, "{}{{", indent)?; + } + &PatternInst::MatchEqual { ref a, ref b, .. } => { + let a = self.value_by_ref(a, ctx); + let b = self.value_by_ref(b, ctx); + writeln!(code, "{}if {} == {} {{", indent, a, b)?; + } + &PatternInst::MatchInt { + ref input, int_val, .. + } => { + let input = self.value_by_val(input, ctx); + writeln!(code, "{}if {} == {} {{", indent, input, int_val)?; + } + &PatternInst::MatchVariant { + ref input, + input_ty, + variant, + ref arg_tys, + } => { + let input = self.value_by_ref(input, ctx); + let variants = match &self.typeenv.types[input_ty.index()] { + &Type::Primitive(..) => panic!("primitive type input to MatchVariant"), + &Type::Enum { ref variants, .. } => variants, + }; + let ty_name = self.type_name(input_ty, /* is_ref = */ Some("&")); + let variant = &variants[variant.index()]; + let variantname = &self.typeenv.syms[variant.name.index()]; + let args = arg_tys + .iter() + .enumerate() + .map(|(i, ty)| { + let value = Value::Pattern { + inst: id, + output: i, + }; + let valuename = self.value_name(&value); + match &self.typeenv.types[ty.index()] { + &Type::Primitive(..) => { + self.define_val(&value, ctx, /* is_ref = */ false); + valuename + } + &Type::Enum { .. } => { + self.define_val(&value, ctx, /* is_ref = */ true); + format!("ref {}", valuename) + } + } + }) + .collect::>(); + writeln!( + code, + "{}if let {}::{} {{ {} }} = {} {{", + indent, + ty_name, + variantname, + args.join(", "), + input + )?; + } + &PatternInst::Extract { + ref input, + input_ty, + ref arg_tys, + term, + } => { + let input = self.value_by_ref(input, ctx); + let (etor_name, infallible) = self.extractor_name_and_infallible(term); + + let args = arg_tys + .iter() + .enumerate() + .map(|(i, ty)| { + let value = Value::Pattern { + inst: id, + output: i, + }; + self.define_val(&value, ctx, /* is_ref = */ false); + self.value_name(&value) + }) + .collect::>(); + + if infallible { + writeln!( + code, + "{}let Some(({})) = {}(ctx, {});", + indent, + args.join(", "), + etor_name, + input + )?; + writeln!(code, "{}{{", indent)?; + } else { + writeln!( + code, + "{}if let Some(({})) = {}(ctx, {}) {{", + indent, + args.join(", "), + etor_name, + input + )?; + } + } + } + Ok(()) } fn generate_body( &self, - _code: &mut dyn Write, - _termid: TermId, - _trie: &TrieNode, + code: &mut dyn Write, + depth: usize, + trie: &TrieNode, + indent: &str, + ctx: &mut BodyContext, ) -> Result<(), Error> { + match trie { + &TrieNode::Empty => {} + + &TrieNode::Leaf { ref output, .. } => { + // If this is a leaf node, generate the ExprSequence and return. + for (id, inst) in output.insts.iter().enumerate() { + let id = InstId(id); + self.generate_expr_inst(code, id, inst, indent, ctx)?; + } + } + + &TrieNode::Decision { ref edges } => { + let subindent = format!("{} ", indent); + // if this is a decision node, generate each match op + // in turn (in priority order). + for &TrieEdge { + ref symbol, + ref node, + .. + } in edges + { + match symbol { + &TrieSymbol::EndOfMatch => { + self.generate_body(code, depth + 1, node, &subindent, ctx)?; + } + &TrieSymbol::Match { ref op } => { + let id = InstId(depth); + self.generate_pattern_inst(code, id, op, &subindent, ctx)?; + self.generate_body(code, depth + 1, node, &subindent, ctx)?; + writeln!(code, "{}}}", subindent)?; + } + } + } + } + } + + writeln!(code, "{}return None;", indent)?; Ok(()) } } diff --git a/cranelift/isle/src/ir.rs b/cranelift/isle/src/ir.rs index cde82e748d..b1f3057619 100644 --- a/cranelift/isle/src/ir.rs +++ b/cranelift/isle/src/ir.rs @@ -17,8 +17,9 @@ pub enum Value { /// A single Pattern instruction. #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum PatternInst { - /// Get the input root-term value. - Arg { ty: TypeId }, + /// Get the Nth input argument, which corresponds to the Nth field + /// of the root term. + Arg { index: usize, ty: TypeId }, /// Match a value as equal to another value. Produces no values. MatchEqual { a: Value, b: Value, ty: TypeId }, @@ -118,9 +119,9 @@ impl PatternSequence { id } - fn add_arg(&mut self, ty: TypeId) -> Value { + fn add_arg(&mut self, index: usize, ty: TypeId) -> Value { let inst = InstId(self.insts.len()); - self.add_inst(PatternInst::Arg { ty }); + self.add_inst(PatternInst::Arg { index, ty }); Value::Pattern { inst, output: 0 } } @@ -183,7 +184,9 @@ impl PatternSequence { /// this pattern, if any. fn gen_pattern( &mut self, - input: Value, + // If `input` is `None`, then this is the root pattern, and is + // implicitly an extraction with the N args as results. + input: Option, typeenv: &TypeEnv, termenv: &TermEnv, pat: &Pattern, @@ -193,8 +196,9 @@ impl PatternSequence { &Pattern::BindPattern(_ty, var, ref subpat) => { // Bind the appropriate variable and recurse. assert!(!vars.contains_key(&var)); - vars.insert(var, (None, input)); // bind first, so subpat can use it - let root_term = self.gen_pattern(input, typeenv, termenv, &*subpat, vars); + vars.insert(var, (None, input.unwrap())); // bind first, so subpat can use it + let root_term = + self.gen_pattern(input, typeenv, termenv, &*subpat, vars); vars.get_mut(&var).unwrap().0 = root_term; root_term } @@ -204,30 +208,40 @@ impl PatternSequence { .get(&var) .cloned() .expect("Variable should already be bound"); - self.add_match_equal(input, var_val, ty); + self.add_match_equal(input.unwrap(), var_val, ty); var_val_term } &Pattern::ConstInt(ty, value) => { // Assert that the value matches the constant integer. - self.add_match_int(input, ty, value); + self.add_match_int(input.unwrap(), ty, value); None } + &Pattern::Term(_, term, ref args) if input.is_none() => { + let termdata = &termenv.terms[term.index()]; + let arg_tys = &termdata.arg_tys[..]; + for (i, subpat) in args.iter().enumerate() { + let value = self.add_arg(i, arg_tys[i]); + self.gen_pattern(Some(value), typeenv, termenv, subpat, vars); + } + Some(term) + } &Pattern::Term(ty, term, ref args) => { // Determine whether the term has an external extractor or not. let termdata = &termenv.terms[term.index()]; let arg_tys = &termdata.arg_tys[..]; match &termdata.kind { &TermKind::EnumVariant { variant } => { - let arg_values = self.add_match_variant(input, ty, arg_tys, variant); + let arg_values = + self.add_match_variant(input.unwrap(), ty, arg_tys, variant); for (subpat, value) in args.iter().zip(arg_values.into_iter()) { - self.gen_pattern(value, typeenv, termenv, subpat, vars); + self.gen_pattern(Some(value), typeenv, termenv, subpat, vars); } None } &TermKind::Regular { .. } => { - let arg_values = self.add_extract(input, ty, arg_tys, term); + let arg_values = self.add_extract(input.unwrap(), ty, arg_tys, term); for (subpat, value) in args.iter().zip(arg_values.into_iter()) { - self.gen_pattern(value, typeenv, termenv, subpat, vars); + self.gen_pattern(Some(value), typeenv, termenv, subpat, vars); } Some(term) } @@ -341,10 +355,8 @@ pub fn lower_rule( // Lower the pattern, starting from the root input value. let ruledata = &termenv.rules[rule.index()]; - let input_ty = ruledata.lhs.ty(); - let input = pattern_seq.add_arg(input_ty); let mut vars = HashMap::new(); - let lhs_root_term = pattern_seq.gen_pattern(input, tyenv, termenv, &ruledata.lhs, &mut vars); + let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars); // Lower the expression, making use of the bound variables // from the pattern. diff --git a/cranelift/isle/src/sema.rs b/cranelift/isle/src/sema.rs index 2ff6f3cc1c..b9c3d2eeac 100644 --- a/cranelift/isle/src/sema.rs +++ b/cranelift/isle/src/sema.rs @@ -772,6 +772,7 @@ impl TermEnv { mod test { use super::*; use crate::ast::Ident; + use crate::lexer::Lexer; use crate::parser::Parser; #[test] @@ -780,14 +781,16 @@ mod test { (type u32 (primitive u32)) (type A extern (enum (B (f1 u32) (f2 u32)) (C (f1 u32)))) "; - let ast = Parser::new("file.isle", text) + let ast = Parser::new(Lexer::from_str(text, "file.isle")) .parse_defs() .expect("should parse"); let tyenv = TypeEnv::from_ast(&ast).expect("should not have type-definition errors"); let sym_a = tyenv.intern(&Ident("A".to_string())).unwrap(); - let sym_b = tyenv.intern(&Ident("A.B".to_string())).unwrap(); - let sym_c = tyenv.intern(&Ident("A.C".to_string())).unwrap(); + let sym_b = tyenv.intern(&Ident("B".to_string())).unwrap(); + let sym_c = tyenv.intern(&Ident("C".to_string())).unwrap(); + let sym_a_b = tyenv.intern(&Ident("A.B".to_string())).unwrap(); + let sym_a_c = tyenv.intern(&Ident("A.C".to_string())).unwrap(); let sym_u32 = tyenv.intern(&Ident("u32".to_string())).unwrap(); let sym_f1 = tyenv.intern(&Ident("f1".to_string())).unwrap(); let sym_f2 = tyenv.intern(&Ident("f2".to_string())).unwrap(); @@ -806,6 +809,7 @@ mod test { variants: vec![ Variant { name: sym_b, + fullname: sym_a_b, id: VariantId(0), fields: vec![ Field { @@ -822,6 +826,7 @@ mod test { }, Variant { name: sym_c, + fullname: sym_a_c, id: VariantId(1), fields: vec![Field { name: sym_f1, @@ -831,6 +836,7 @@ mod test { }, ], pos: Pos { + file: 0, offset: 58, line: 3, col: 18, @@ -862,7 +868,7 @@ mod test { (rule -1 (T3 _) (A.C 3)) "; - let ast = Parser::new("file.isle", text) + let ast = Parser::new(Lexer::from_str(text, "file.isle")) .parse_defs() .expect("should parse"); let mut tyenv = TypeEnv::from_ast(&ast).expect("should not have type-definition errors");