Working extractor and constructor generation from rules!

This commit is contained in:
Chris Fallin
2021-09-04 20:03:21 -07:00
parent be1140e80a
commit d7efd9f219
2 changed files with 139 additions and 139 deletions

View File

@@ -1,8 +1,8 @@
//! Generate Rust code from a series of Sequences. //! Generate Rust code from a series of Sequences.
use crate::error::Error;
use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value};
use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId};
use crate::{error::Error, ir::reverse_rule};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fmt::Write; use std::fmt::Write;
@@ -448,29 +448,40 @@ impl<'a> TermFunctionsBuilder<'a> {
let rule = RuleId(rule); let rule = RuleId(rule);
let prio = self.termenv.rules[rule.index()].prio.unwrap_or(0); let prio = self.termenv.rules[rule.index()].prio.unwrap_or(0);
let (lhs_root, pattern, rhs_root, expr) = lower_rule(self.typeenv, self.termenv, rule); if let Some((pattern, expr, lhs_root)) = lower_rule(
log::trace!( self.typeenv,
"build:\n- rule {:?}\n- lhs_root {:?} rhs_root {:?}\n- pattern {:?}\n- expr {:?}", self.termenv,
self.termenv.rules[rule.index()], rule,
lhs_root, /* forward_dir = */ true,
rhs_root, ) {
pattern, log::trace!(
expr "build:\n- rule {:?}\n- fwd pattern {:?}\n- fwd expr {:?}",
); self.termenv.rules[rule.index()],
if let Some(input_root_term) = lhs_root { pattern,
expr
);
self.builders_by_input self.builders_by_input
.entry(input_root_term) .entry(lhs_root)
.or_insert_with(|| TermFunctionBuilder::new(input_root_term)) .or_insert_with(|| TermFunctionBuilder::new(lhs_root))
.add_rule(prio, pattern.clone(), expr.clone()); .add_rule(prio, pattern.clone(), expr.clone());
} }
if let Some(output_root_term) = rhs_root { if let Some((pattern, expr, rhs_root)) = lower_rule(
if let Some((reverse_pattern, reverse_expr)) = reverse_rule(&pattern, &expr) { self.typeenv,
self.builders_by_output self.termenv,
.entry(output_root_term) rule,
.or_insert_with(|| TermFunctionBuilder::new(output_root_term)) /* forward_dir = */ false,
.add_rule(prio, reverse_pattern, reverse_expr); ) {
} log::trace!(
"build:\n- rule {:?}\n- rev pattern {:?}\n- rev expr {:?}",
self.termenv.rules[rule.index()],
pattern,
expr
);
self.builders_by_output
.entry(rhs_root)
.or_insert_with(|| TermFunctionBuilder::new(rhs_root))
.add_rule(prio, pattern, expr);
} }
} }
} }

View File

@@ -72,7 +72,11 @@ pub enum ExprInst {
}, },
/// Set the Nth return value. Produces no values. /// Set the Nth return value. Produces no values.
Return { index: usize, ty: TypeId, value: Value }, Return {
index: usize,
ty: TypeId,
value: Value,
},
} }
impl ExprInst { impl ExprInst {
@@ -294,7 +298,11 @@ impl ExprSequence {
} }
fn add_return(&mut self, ty: TypeId, value: Value) { fn add_return(&mut self, ty: TypeId, value: Value) {
self.add_inst(ExprInst::Return { index: 0, ty, value }); self.add_inst(ExprInst::Return {
index: 0,
ty,
value,
});
} }
fn add_multi_return(&mut self, index: usize, ty: TypeId, value: Value) { fn add_multi_return(&mut self, index: usize, ty: TypeId, value: Value) {
@@ -304,40 +312,55 @@ impl ExprSequence {
/// Creates a sequence of ExprInsts to generate the given /// Creates a sequence of ExprInsts to generate the given
/// expression value. Returns the value ID as well as the root /// expression value. Returns the value ID as well as the root
/// term ID, if any. /// term ID, if any.
///
/// If `gen_final_construct` is false and the value is a
/// constructor call, this returns the arguments instead. This is
/// used when codegen'ing extractors for internal terms.
fn gen_expr( fn gen_expr(
&mut self, &mut self,
typeenv: &TypeEnv, typeenv: &TypeEnv,
termenv: &TermEnv, termenv: &TermEnv,
expr: &Expr, expr: &Expr,
vars: &HashMap<VarId, (Option<TermId>, Value)>, vars: &HashMap<VarId, (Option<TermId>, Value)>,
) -> (Option<TermId>, Value) { gen_final_construct: bool,
) -> (Option<TermId>, Vec<Value>) {
match expr { match expr {
&Expr::ConstInt(ty, val) => (None, self.add_const_int(ty, val)), &Expr::ConstInt(ty, val) => (None, vec![self.add_const_int(ty, val)]),
&Expr::Let(_ty, ref bindings, ref subexpr) => { &Expr::Let(_ty, ref bindings, ref subexpr) => {
let mut vars = vars.clone(); let mut vars = vars.clone();
for &(var, _var_ty, ref var_expr) in bindings { for &(var, _var_ty, ref var_expr) in bindings {
let (var_value_term, var_value) = let (var_value_term, var_value) =
self.gen_expr(typeenv, termenv, &*var_expr, &vars); self.gen_expr(typeenv, termenv, &*var_expr, &vars, false);
let var_value = var_value[0];
vars.insert(var, (var_value_term, var_value)); vars.insert(var, (var_value_term, var_value));
} }
self.gen_expr(typeenv, termenv, &*subexpr, &vars) self.gen_expr(typeenv, termenv, &*subexpr, &vars, gen_final_construct)
}
&Expr::Var(_ty, var_id) => {
let (root_term, value) = vars.get(&var_id).cloned().unwrap();
(root_term, vec![value])
} }
&Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(),
&Expr::Term(ty, term, ref arg_exprs) => { &Expr::Term(ty, term, ref arg_exprs) => {
let termdata = &termenv.terms[term.index()]; let termdata = &termenv.terms[term.index()];
let mut arg_values_tys = vec![]; let mut arg_values_tys = vec![];
for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) { for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) {
arg_values_tys arg_values_tys.push((
.push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars).1, arg_ty)); self.gen_expr(typeenv, termenv, &*arg_expr, &vars, false).1[0],
arg_ty,
));
} }
match &termdata.kind { match &termdata.kind {
&TermKind::EnumVariant { variant } => ( &TermKind::EnumVariant { variant } => (
None, None,
self.add_create_variant(&arg_values_tys[..], ty, variant), vec![self.add_create_variant(&arg_values_tys[..], ty, variant)],
),
&TermKind::Regular { .. } if !gen_final_construct => (
Some(termdata.id),
arg_values_tys.into_iter().map(|(val, _ty)| val).collect(),
), ),
&TermKind::Regular { .. } => ( &TermKind::Regular { .. } => (
Some(termdata.id), Some(termdata.id),
self.add_construct(&arg_values_tys[..], ty, term), vec![self.add_construct(&arg_values_tys[..], ty, term)],
), ),
} }
} }
@@ -350,12 +373,8 @@ pub fn lower_rule(
tyenv: &TypeEnv, tyenv: &TypeEnv,
termenv: &TermEnv, termenv: &TermEnv,
rule: RuleId, rule: RuleId,
) -> ( is_forward_dir: bool,
Option<TermId>, ) -> Option<(PatternSequence, ExprSequence, TermId)> {
PatternSequence,
Option<TermId>,
ExprSequence,
) {
let mut pattern_seq: PatternSequence = Default::default(); let mut pattern_seq: PatternSequence = Default::default();
let mut expr_seq: ExprSequence = Default::default(); let mut expr_seq: ExprSequence = Default::default();
expr_seq.pos = termenv.rules[rule.index()].pos; expr_seq.pos = termenv.rules[rule.index()].pos;
@@ -363,113 +382,83 @@ pub fn lower_rule(
// Lower the pattern, starting from the root input value. // Lower the pattern, starting from the root input value.
let ruledata = &termenv.rules[rule.index()]; let ruledata = &termenv.rules[rule.index()];
let mut vars = HashMap::new(); let mut vars = HashMap::new();
let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars);
// Lower the expression, making use of the bound variables if is_forward_dir {
// from the pattern. let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars);
let (rhs_root_term, rhs_root) = expr_seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars); let root_term = match lhs_root_term {
// Return the root RHS value. Some(t) => t,
let output_ty = ruledata.rhs.ty(); None => {
expr_seq.add_return(output_ty, rhs_root); return None;
}
};
(lhs_root_term, pattern_seq, rhs_root_term, expr_seq) // Lower the expression, making use of the bound variables
// from the pattern.
let (_, rhs_root_vals) = expr_seq.gen_expr(
tyenv,
termenv,
&ruledata.rhs,
&vars,
/* final_construct = */ true,
);
// Return the root RHS value.
let output_ty = ruledata.rhs.ty();
assert_eq!(rhs_root_vals.len(), 1);
expr_seq.add_return(output_ty, rhs_root_vals[0]);
Some((pattern_seq, expr_seq, root_term))
} else {
let arg = pattern_seq.add_arg(0, ruledata.lhs.ty());
let _ = pattern_seq.gen_pattern(Some(arg), tyenv, termenv, &ruledata.lhs, &mut vars);
let (rhs_root_term, rhs_root_vals) = expr_seq.gen_expr(
tyenv,
termenv,
&ruledata.rhs,
&vars,
/* final_construct = */ false,
);
let root_term = match rhs_root_term {
Some(t) => t,
None => {
return None;
}
};
let termdata = &termenv.terms[root_term.index()];
for (i, (val, ty)) in rhs_root_vals
.into_iter()
.zip(termdata.arg_tys.iter())
.enumerate()
{
expr_seq.add_multi_return(i, *ty, val);
}
Some((pattern_seq, expr_seq, root_term))
}
} }
/// Reverse a sequence to form an extractor from a constructor. /// Trim the final Construct and Return ops in an ExprSequence in
pub fn reverse_rule( /// order to allow the extractor to be codegen'd.
orig_pat: &PatternSequence, pub fn trim_expr_for_extractor(mut expr: ExprSequence) -> ExprSequence {
orig_expr: &ExprSequence, let ret_inst = expr.insts.pop().unwrap();
) -> Option<( let retval = match ret_inst {
PatternSequence, ExprInst::Return { value, .. } => value,
ExprSequence, _ => panic!("Last instruction is not a return"),
)> };
{ assert_eq!(
let mut pattern_seq = PatternSequence::default(); retval,
let mut expr_seq = ExprSequence::default(); Value::Expr {
expr_seq.pos = orig_expr.pos; inst: InstId(expr.insts.len() - 1),
output: 0
let mut value_map = HashMap::new();
for (id, inst) in orig_expr.insts.iter().enumerate().rev() {
let id = InstId(id);
match inst {
&ExprInst::Return { index, ty, value } => {
let new_value = pattern_seq.add_arg(index, ty);
value_map.insert(value, new_value);
}
&ExprInst::Construct { ref inputs, ty, term } => {
let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::<Vec<_>>();
let input_ty = ty;
// Input to the Extract is the output of the Construct.
let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone();
let outputs = pattern_seq.add_extract(input, input_ty, &arg_tys[..], term);
for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) {
value_map.insert(*input, output);
}
}
&ExprInst::CreateVariant { ref inputs, ty, variant } => {
let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::<Vec<_>>();
let input_ty = ty;
// Input to the MatchVariant is the output of the CreateVariant.
let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone();
let outputs = pattern_seq.add_match_variant(input, input_ty, &arg_tys[..], variant);
for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) {
value_map.insert(*input, output);
}
}
&ExprInst::ConstInt { ty, val } => {
let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone();
pattern_seq.add_match_int(input, ty, val);
}
} }
);
let construct_inst = expr.insts.pop().unwrap();
let inputs = match construct_inst {
ExprInst::Construct { inputs, .. } => inputs,
_ => panic!("Returned value is not a construct call"),
};
for (i, (value, ty)) in inputs.into_iter().enumerate() {
expr.add_multi_return(i, ty, value);
} }
for (id, inst) in orig_pat.insts.iter().enumerate().rev() { expr
let id = InstId(id);
match inst {
&PatternInst::Extract { input, input_ty, ref arg_tys, term } => {
let mut inputs = vec![];
for i in 0..arg_tys.len() {
let value = Value::Pattern { inst: id, output: i };
let new_value = value_map.get(&value)?.clone();
inputs.push((new_value, arg_tys[i]));
}
let output = expr_seq.add_construct(&inputs[..], input_ty, term);
value_map.insert(input, output);
}
&PatternInst::MatchVariant { input, input_ty, ref arg_tys, variant } => {
let mut inputs = vec![];
for i in 0..arg_tys.len() {
let value = Value::Pattern { inst: id, output: i };
let new_value = value_map.get(&value)?.clone();
inputs.push((new_value, arg_tys[i]));
}
let output = expr_seq.add_create_variant(&inputs[..], input_ty, variant);
value_map.insert(input, output);
}
&PatternInst::MatchEqual { a, b, .. } => {
if let Some(new_a) = value_map.get(&a).cloned() {
if !value_map.contains_key(&b) {
value_map.insert(b, new_a);
}
} else if let Some(new_b) = value_map.get(&b).cloned() {
if !value_map.contains_key(&a) {
value_map.insert(a, new_b);
}
}
}
&PatternInst::MatchInt { input, ty, int_val } => {
let output = expr_seq.add_const_int(ty, int_val);
value_map.insert(input, output);
}
&PatternInst::Arg { index, ty } => {
let value = Value::Pattern { inst: id, output: 0 };
let new_value = value_map.get(&value)?.clone();
expr_seq.add_multi_return(index, ty, new_value);
}
}
}
Some((pattern_seq, expr_seq))
} }