From be1140e80aa4d76fc7bf553059280088f0be8b97 Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Sat, 4 Sep 2021 18:45:03 -0700 Subject: [PATCH] WIP. --- cranelift/isle/examples/test.isle | 9 +++ cranelift/isle/src/codegen.rs | 77 ++++++++++++--------- cranelift/isle/src/ir.rs | 109 +++++++++++++++++++++++++++++- 3 files changed, 159 insertions(+), 36 deletions(-) diff --git a/cranelift/isle/examples/test.isle b/cranelift/isle/examples/test.isle index 1ea1c3ce98..39125a94d1 100644 --- a/cranelift/isle/examples/test.isle +++ b/cranelift/isle/examples/test.isle @@ -10,3 +10,12 @@ (rule (Lower (A.A1 sub @ (Input (A.A2 42)))) (B.B2 sub)) + +(decl Extractor (B) A) +(rule + (A.A2 x) + (Extractor (B.B1 x))) + +(rule + (Lower (Extractor b)) + b) diff --git a/cranelift/isle/src/codegen.rs b/cranelift/isle/src/codegen.rs index 001c1f36b9..19c03a4692 100644 --- a/cranelift/isle/src/codegen.rs +++ b/cranelift/isle/src/codegen.rs @@ -1,8 +1,8 @@ //! Generate Rust code from a series of Sequences. -use crate::error::Error; use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; -use crate::sema::{RuleId, Term, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; +use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; +use crate::{error::Error, ir::reverse_rule}; use std::collections::{HashMap, HashSet}; use std::fmt::Write; @@ -463,11 +463,14 @@ impl<'a> TermFunctionsBuilder<'a> { .or_insert_with(|| TermFunctionBuilder::new(input_root_term)) .add_rule(prio, pattern.clone(), expr.clone()); } + if let Some(output_root_term) = rhs_root { - self.builders_by_output - .entry(output_root_term) - .or_insert_with(|| TermFunctionBuilder::new(output_root_term)) - .add_rule(prio, pattern, expr); + if let Some((reverse_pattern, reverse_expr)) = reverse_rule(&pattern, &expr) { + self.builders_by_output + .entry(output_root_term) + .or_insert_with(|| TermFunctionBuilder::new(output_root_term)) + .add_rule(prio, reverse_pattern, reverse_expr); + } } } } @@ -498,6 +501,8 @@ pub struct Codegen<'a> { #[derive(Clone, Debug, Default)] struct BodyContext { borrowed_values: HashSet, + expected_return_vals: usize, + tuple_return: bool, } impl<'a> Codegen<'a> { @@ -694,7 +699,8 @@ impl<'a> Codegen<'a> { self.type_name(termdata.ret_ty, /* by_ref = */ None) )?; - let mut body_ctx = Default::default(); + let mut body_ctx: BodyContext = Default::default(); + body_ctx.expected_return_vals = 1; self.generate_body(code, /* depth = */ 0, trie, " ", &mut body_ctx)?; writeln!(code, "}}")?; @@ -717,7 +723,7 @@ impl<'a> Codegen<'a> { // Get the name of the term and build up the signature. let (func_name, _) = self.extractor_name_and_infallible(termid); let arg = format!( - "arg: {}", + "arg0: {}", self.type_name(termdata.ret_ty, /* by_ref = */ Some("&")) ); let ret_tuple_tys = termdata @@ -735,14 +741,15 @@ impl<'a> Codegen<'a> { )?; writeln!( code, - "fn {}<'a, C>(ctx: &mut C, {}) -> Option<({})> {{", + "fn {}(ctx: &mut C, {}) -> Option<({},)> {{", func_name, arg, ret_tuple_tys.join(", "), )?; - let mut body_ctx = Default::default(); - self.generate_extractor_header(code, termdata, &mut body_ctx)?; + let mut body_ctx: BodyContext = Default::default(); + body_ctx.expected_return_vals = ret_tuple_tys.len(); + body_ctx.tuple_return = true; self.generate_body(code, /* depth = */ 0, trie, " ", &mut body_ctx)?; writeln!(code, " }}")?; writeln!(code, "}}")?; @@ -751,17 +758,6 @@ impl<'a> Codegen<'a> { Ok(()) } - fn generate_extractor_header( - &self, - code: &mut dyn Write, - termdata: &Term, - ctx: &mut BodyContext, - ) -> Result<(), Error> { - writeln!(code, " {{")?; - todo!(); - Ok(()) - } - fn generate_expr_inst( &self, code: &mut dyn Write, @@ -769,6 +765,7 @@ impl<'a> Codegen<'a> { inst: &ExprInst, indent: &str, ctx: &mut BodyContext, + returns: &mut Vec<(usize, String)>, ) -> Result<(), Error> { match inst { &ExprInst::ConstInt { ty, val } => { @@ -843,9 +840,11 @@ impl<'a> Codegen<'a> { )?; self.define_val(&output, ctx, /* is_ref = */ false); } - &ExprInst::Return { ref value, .. } => { + &ExprInst::Return { + index, ref value, .. + } => { let value_expr = self.value_by_val(value, ctx); - writeln!(code, "{}return Some({});", indent, value_expr)?; + returns.push((index, value_expr)); } } @@ -936,9 +935,9 @@ impl<'a> Codegen<'a> { } &PatternInst::Extract { ref input, - input_ty, ref arg_tys, term, + .. } => { let input = self.value_by_ref(input, ctx); let (etor_name, infallible) = self.extractor_name_and_infallible(term); @@ -946,7 +945,7 @@ impl<'a> Codegen<'a> { let args = arg_tys .iter() .enumerate() - .map(|(i, ty)| { + .map(|(i, _ty)| { let value = Value::Pattern { inst: id, output: i, @@ -959,7 +958,7 @@ impl<'a> Codegen<'a> { if infallible { writeln!( code, - "{}let Some(({})) = {}(ctx, {});", + "{}let Some(({},)) = {}(ctx, {});", indent, args.join(", "), etor_name, @@ -969,7 +968,7 @@ impl<'a> Codegen<'a> { } else { writeln!( code, - "{}if let Some(({})) = {}(ctx, {}) {{", + "{}if let Some(({},)) = {}(ctx, {}) {{", indent, args.join(", "), etor_name, @@ -1002,14 +1001,26 @@ impl<'a> Codegen<'a> { output.pos.pretty_print_line(&self.typeenv.filenames[..]) )?; // If this is a leaf node, generate the ExprSequence and return. + let mut returns = vec![]; for (id, inst) in output.insts.iter().enumerate() { let id = InstId(id); - self.generate_expr_inst(code, id, inst, indent, ctx)?; - if let &ExprInst::Return { .. } = inst { - returned = true; - break; - } + self.generate_expr_inst(code, id, inst, indent, ctx, &mut returns)?; } + + assert_eq!(returns.len(), ctx.expected_return_vals); + returns.sort_by_key(|(index, _)| *index); + if ctx.tuple_return { + let return_values = returns + .into_iter() + .map(|(_, expr)| expr) + .collect::>() + .join(", "); + writeln!(code, "{}return Some(({},));", indent, return_values)?; + } else { + writeln!(code, "{}return Some({});", indent, returns[0].1)?; + } + + returned = true; } &TrieNode::Decision { ref edges } => { diff --git a/cranelift/isle/src/ir.rs b/cranelift/isle/src/ir.rs index 33641c4de2..d12c9bdd2c 100644 --- a/cranelift/isle/src/ir.rs +++ b/cranelift/isle/src/ir.rs @@ -71,8 +71,8 @@ pub enum ExprInst { term: TermId, }, - /// Set the return value. Produces no values. - Return { ty: TypeId, value: Value }, + /// Set the Nth return value. Produces no values. + Return { index: usize, ty: TypeId, value: Value }, } impl ExprInst { @@ -294,7 +294,11 @@ impl ExprSequence { } fn add_return(&mut self, ty: TypeId, value: Value) { - self.add_inst(ExprInst::Return { ty, value }); + self.add_inst(ExprInst::Return { index: 0, ty, value }); + } + + fn add_multi_return(&mut self, index: usize, ty: TypeId, value: Value) { + self.add_inst(ExprInst::Return { index, ty, value }); } /// Creates a sequence of ExprInsts to generate the given @@ -370,3 +374,102 @@ pub fn lower_rule( (lhs_root_term, pattern_seq, rhs_root_term, expr_seq) } + +/// Reverse a sequence to form an extractor from a constructor. +pub fn reverse_rule( + orig_pat: &PatternSequence, + orig_expr: &ExprSequence, +) -> Option<( + PatternSequence, + ExprSequence, + )> +{ + let mut pattern_seq = PatternSequence::default(); + let mut expr_seq = ExprSequence::default(); + expr_seq.pos = orig_expr.pos; + + let mut value_map = HashMap::new(); + + for (id, inst) in orig_expr.insts.iter().enumerate().rev() { + let id = InstId(id); + match inst { + &ExprInst::Return { index, ty, value } => { + let new_value = pattern_seq.add_arg(index, ty); + value_map.insert(value, new_value); + } + &ExprInst::Construct { ref inputs, ty, term } => { + let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::>(); + let input_ty = ty; + // Input to the Extract is the output of the Construct. + let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone(); + let outputs = pattern_seq.add_extract(input, input_ty, &arg_tys[..], term); + for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) { + value_map.insert(*input, output); + } + } + &ExprInst::CreateVariant { ref inputs, ty, variant } => { + let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::>(); + let input_ty = ty; + // Input to the MatchVariant is the output of the CreateVariant. + let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone(); + let outputs = pattern_seq.add_match_variant(input, input_ty, &arg_tys[..], variant); + for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) { + value_map.insert(*input, output); + } + } + &ExprInst::ConstInt { ty, val } => { + let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone(); + pattern_seq.add_match_int(input, ty, val); + } + } + } + + for (id, inst) in orig_pat.insts.iter().enumerate().rev() { + let id = InstId(id); + match inst { + &PatternInst::Extract { input, input_ty, ref arg_tys, term } => { + let mut inputs = vec![]; + for i in 0..arg_tys.len() { + let value = Value::Pattern { inst: id, output: i }; + let new_value = value_map.get(&value)?.clone(); + inputs.push((new_value, arg_tys[i])); + } + let output = expr_seq.add_construct(&inputs[..], input_ty, term); + value_map.insert(input, output); + + } + &PatternInst::MatchVariant { input, input_ty, ref arg_tys, variant } => { + let mut inputs = vec![]; + for i in 0..arg_tys.len() { + let value = Value::Pattern { inst: id, output: i }; + let new_value = value_map.get(&value)?.clone(); + inputs.push((new_value, arg_tys[i])); + } + let output = expr_seq.add_create_variant(&inputs[..], input_ty, variant); + value_map.insert(input, output); + } + &PatternInst::MatchEqual { a, b, .. } => { + if let Some(new_a) = value_map.get(&a).cloned() { + if !value_map.contains_key(&b) { + value_map.insert(b, new_a); + } + } else if let Some(new_b) = value_map.get(&b).cloned() { + if !value_map.contains_key(&a) { + value_map.insert(a, new_b); + } + } + } + &PatternInst::MatchInt { input, ty, int_val } => { + let output = expr_seq.add_const_int(ty, int_val); + value_map.insert(input, output); + } + &PatternInst::Arg { index, ty } => { + let value = Value::Pattern { inst: id, output: 0 }; + let new_value = value_map.get(&value)?.clone(); + expr_seq.add_multi_return(index, ty, new_value); + } + } + } + + Some((pattern_seq, expr_seq)) +}