diff --git a/cranelift/isle/src/codegen.rs b/cranelift/isle/src/codegen.rs index 19c03a4692..cfc4f5eb77 100644 --- a/cranelift/isle/src/codegen.rs +++ b/cranelift/isle/src/codegen.rs @@ -1,8 +1,8 @@ //! Generate Rust code from a series of Sequences. +use crate::error::Error; use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; -use crate::{error::Error, ir::reverse_rule}; use std::collections::{HashMap, HashSet}; use std::fmt::Write; @@ -448,29 +448,40 @@ impl<'a> TermFunctionsBuilder<'a> { let rule = RuleId(rule); let prio = self.termenv.rules[rule.index()].prio.unwrap_or(0); - let (lhs_root, pattern, rhs_root, expr) = lower_rule(self.typeenv, self.termenv, rule); - log::trace!( - "build:\n- rule {:?}\n- lhs_root {:?} rhs_root {:?}\n- pattern {:?}\n- expr {:?}", - self.termenv.rules[rule.index()], - lhs_root, - rhs_root, - pattern, - expr - ); - if let Some(input_root_term) = lhs_root { + if let Some((pattern, expr, lhs_root)) = lower_rule( + self.typeenv, + self.termenv, + rule, + /* forward_dir = */ true, + ) { + log::trace!( + "build:\n- rule {:?}\n- fwd pattern {:?}\n- fwd expr {:?}", + self.termenv.rules[rule.index()], + pattern, + expr + ); self.builders_by_input - .entry(input_root_term) - .or_insert_with(|| TermFunctionBuilder::new(input_root_term)) + .entry(lhs_root) + .or_insert_with(|| TermFunctionBuilder::new(lhs_root)) .add_rule(prio, pattern.clone(), expr.clone()); } - if let Some(output_root_term) = rhs_root { - if let Some((reverse_pattern, reverse_expr)) = reverse_rule(&pattern, &expr) { - self.builders_by_output - .entry(output_root_term) - .or_insert_with(|| TermFunctionBuilder::new(output_root_term)) - .add_rule(prio, reverse_pattern, reverse_expr); - } + if let Some((pattern, expr, rhs_root)) = lower_rule( + self.typeenv, + self.termenv, + rule, + /* forward_dir = */ false, + ) { + log::trace!( + "build:\n- rule {:?}\n- rev pattern {:?}\n- rev expr {:?}", + self.termenv.rules[rule.index()], + pattern, + expr + ); + self.builders_by_output + .entry(rhs_root) + .or_insert_with(|| TermFunctionBuilder::new(rhs_root)) + .add_rule(prio, pattern, expr); } } } diff --git a/cranelift/isle/src/ir.rs b/cranelift/isle/src/ir.rs index d12c9bdd2c..4250f32320 100644 --- a/cranelift/isle/src/ir.rs +++ b/cranelift/isle/src/ir.rs @@ -72,7 +72,11 @@ pub enum ExprInst { }, /// Set the Nth return value. Produces no values. - Return { index: usize, ty: TypeId, value: Value }, + Return { + index: usize, + ty: TypeId, + value: Value, + }, } impl ExprInst { @@ -294,7 +298,11 @@ impl ExprSequence { } fn add_return(&mut self, ty: TypeId, value: Value) { - self.add_inst(ExprInst::Return { index: 0, ty, value }); + self.add_inst(ExprInst::Return { + index: 0, + ty, + value, + }); } fn add_multi_return(&mut self, index: usize, ty: TypeId, value: Value) { @@ -304,40 +312,55 @@ impl ExprSequence { /// Creates a sequence of ExprInsts to generate the given /// expression value. Returns the value ID as well as the root /// term ID, if any. + /// + /// If `gen_final_construct` is false and the value is a + /// constructor call, this returns the arguments instead. This is + /// used when codegen'ing extractors for internal terms. fn gen_expr( &mut self, typeenv: &TypeEnv, termenv: &TermEnv, expr: &Expr, vars: &HashMap, Value)>, - ) -> (Option, Value) { + gen_final_construct: bool, + ) -> (Option, Vec) { match expr { - &Expr::ConstInt(ty, val) => (None, self.add_const_int(ty, val)), + &Expr::ConstInt(ty, val) => (None, vec![self.add_const_int(ty, val)]), &Expr::Let(_ty, ref bindings, ref subexpr) => { let mut vars = vars.clone(); for &(var, _var_ty, ref var_expr) in bindings { let (var_value_term, var_value) = - self.gen_expr(typeenv, termenv, &*var_expr, &vars); + self.gen_expr(typeenv, termenv, &*var_expr, &vars, false); + let var_value = var_value[0]; vars.insert(var, (var_value_term, var_value)); } - self.gen_expr(typeenv, termenv, &*subexpr, &vars) + self.gen_expr(typeenv, termenv, &*subexpr, &vars, gen_final_construct) + } + &Expr::Var(_ty, var_id) => { + let (root_term, value) = vars.get(&var_id).cloned().unwrap(); + (root_term, vec![value]) } - &Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(), &Expr::Term(ty, term, ref arg_exprs) => { let termdata = &termenv.terms[term.index()]; let mut arg_values_tys = vec![]; for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) { - arg_values_tys - .push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars).1, arg_ty)); + arg_values_tys.push(( + self.gen_expr(typeenv, termenv, &*arg_expr, &vars, false).1[0], + arg_ty, + )); } match &termdata.kind { &TermKind::EnumVariant { variant } => ( None, - self.add_create_variant(&arg_values_tys[..], ty, variant), + vec![self.add_create_variant(&arg_values_tys[..], ty, variant)], + ), + &TermKind::Regular { .. } if !gen_final_construct => ( + Some(termdata.id), + arg_values_tys.into_iter().map(|(val, _ty)| val).collect(), ), &TermKind::Regular { .. } => ( Some(termdata.id), - self.add_construct(&arg_values_tys[..], ty, term), + vec![self.add_construct(&arg_values_tys[..], ty, term)], ), } } @@ -350,12 +373,8 @@ pub fn lower_rule( tyenv: &TypeEnv, termenv: &TermEnv, rule: RuleId, -) -> ( - Option, - PatternSequence, - Option, - ExprSequence, -) { + is_forward_dir: bool, +) -> Option<(PatternSequence, ExprSequence, TermId)> { let mut pattern_seq: PatternSequence = Default::default(); let mut expr_seq: ExprSequence = Default::default(); expr_seq.pos = termenv.rules[rule.index()].pos; @@ -363,113 +382,83 @@ pub fn lower_rule( // Lower the pattern, starting from the root input value. let ruledata = &termenv.rules[rule.index()]; let mut vars = HashMap::new(); - let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars); - // Lower the expression, making use of the bound variables - // from the pattern. - let (rhs_root_term, rhs_root) = expr_seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars); - // Return the root RHS value. - let output_ty = ruledata.rhs.ty(); - expr_seq.add_return(output_ty, rhs_root); + if is_forward_dir { + let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars); + let root_term = match lhs_root_term { + Some(t) => t, + None => { + return None; + } + }; - (lhs_root_term, pattern_seq, rhs_root_term, expr_seq) + // Lower the expression, making use of the bound variables + // from the pattern. + let (_, rhs_root_vals) = expr_seq.gen_expr( + tyenv, + termenv, + &ruledata.rhs, + &vars, + /* final_construct = */ true, + ); + // Return the root RHS value. + let output_ty = ruledata.rhs.ty(); + assert_eq!(rhs_root_vals.len(), 1); + expr_seq.add_return(output_ty, rhs_root_vals[0]); + Some((pattern_seq, expr_seq, root_term)) + } else { + let arg = pattern_seq.add_arg(0, ruledata.lhs.ty()); + let _ = pattern_seq.gen_pattern(Some(arg), tyenv, termenv, &ruledata.lhs, &mut vars); + let (rhs_root_term, rhs_root_vals) = expr_seq.gen_expr( + tyenv, + termenv, + &ruledata.rhs, + &vars, + /* final_construct = */ false, + ); + + let root_term = match rhs_root_term { + Some(t) => t, + None => { + return None; + } + }; + let termdata = &termenv.terms[root_term.index()]; + for (i, (val, ty)) in rhs_root_vals + .into_iter() + .zip(termdata.arg_tys.iter()) + .enumerate() + { + expr_seq.add_multi_return(i, *ty, val); + } + + Some((pattern_seq, expr_seq, root_term)) + } } -/// Reverse a sequence to form an extractor from a constructor. -pub fn reverse_rule( - orig_pat: &PatternSequence, - orig_expr: &ExprSequence, -) -> Option<( - PatternSequence, - ExprSequence, - )> -{ - let mut pattern_seq = PatternSequence::default(); - let mut expr_seq = ExprSequence::default(); - expr_seq.pos = orig_expr.pos; - - let mut value_map = HashMap::new(); - - for (id, inst) in orig_expr.insts.iter().enumerate().rev() { - let id = InstId(id); - match inst { - &ExprInst::Return { index, ty, value } => { - let new_value = pattern_seq.add_arg(index, ty); - value_map.insert(value, new_value); - } - &ExprInst::Construct { ref inputs, ty, term } => { - let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::>(); - let input_ty = ty; - // Input to the Extract is the output of the Construct. - let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone(); - let outputs = pattern_seq.add_extract(input, input_ty, &arg_tys[..], term); - for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) { - value_map.insert(*input, output); - } - } - &ExprInst::CreateVariant { ref inputs, ty, variant } => { - let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::>(); - let input_ty = ty; - // Input to the MatchVariant is the output of the CreateVariant. - let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone(); - let outputs = pattern_seq.add_match_variant(input, input_ty, &arg_tys[..], variant); - for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) { - value_map.insert(*input, output); - } - } - &ExprInst::ConstInt { ty, val } => { - let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone(); - pattern_seq.add_match_int(input, ty, val); - } +/// Trim the final Construct and Return ops in an ExprSequence in +/// order to allow the extractor to be codegen'd. +pub fn trim_expr_for_extractor(mut expr: ExprSequence) -> ExprSequence { + let ret_inst = expr.insts.pop().unwrap(); + let retval = match ret_inst { + ExprInst::Return { value, .. } => value, + _ => panic!("Last instruction is not a return"), + }; + assert_eq!( + retval, + Value::Expr { + inst: InstId(expr.insts.len() - 1), + output: 0 } + ); + let construct_inst = expr.insts.pop().unwrap(); + let inputs = match construct_inst { + ExprInst::Construct { inputs, .. } => inputs, + _ => panic!("Returned value is not a construct call"), + }; + for (i, (value, ty)) in inputs.into_iter().enumerate() { + expr.add_multi_return(i, ty, value); } - for (id, inst) in orig_pat.insts.iter().enumerate().rev() { - let id = InstId(id); - match inst { - &PatternInst::Extract { input, input_ty, ref arg_tys, term } => { - let mut inputs = vec![]; - for i in 0..arg_tys.len() { - let value = Value::Pattern { inst: id, output: i }; - let new_value = value_map.get(&value)?.clone(); - inputs.push((new_value, arg_tys[i])); - } - let output = expr_seq.add_construct(&inputs[..], input_ty, term); - value_map.insert(input, output); - - } - &PatternInst::MatchVariant { input, input_ty, ref arg_tys, variant } => { - let mut inputs = vec![]; - for i in 0..arg_tys.len() { - let value = Value::Pattern { inst: id, output: i }; - let new_value = value_map.get(&value)?.clone(); - inputs.push((new_value, arg_tys[i])); - } - let output = expr_seq.add_create_variant(&inputs[..], input_ty, variant); - value_map.insert(input, output); - } - &PatternInst::MatchEqual { a, b, .. } => { - if let Some(new_a) = value_map.get(&a).cloned() { - if !value_map.contains_key(&b) { - value_map.insert(b, new_a); - } - } else if let Some(new_b) = value_map.get(&b).cloned() { - if !value_map.contains_key(&a) { - value_map.insert(a, new_b); - } - } - } - &PatternInst::MatchInt { input, ty, int_val } => { - let output = expr_seq.add_const_int(ty, int_val); - value_map.insert(input, output); - } - &PatternInst::Arg { index, ty } => { - let value = Value::Pattern { inst: id, output: 0 }; - let new_value = value_map.get(&value)?.clone(); - expr_seq.add_multi_return(index, ty, new_value); - } - } - } - - Some((pattern_seq, expr_seq)) + expr }