Working extractor and constructor generation from rules!

This commit is contained in:
Chris Fallin
2021-09-04 20:03:21 -07:00
parent be1140e80a
commit d7efd9f219
2 changed files with 139 additions and 139 deletions

View File

@@ -1,8 +1,8 @@
//! Generate Rust code from a series of Sequences.
use crate::error::Error;
use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value};
use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId};
use crate::{error::Error, ir::reverse_rule};
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
@@ -448,29 +448,40 @@ impl<'a> TermFunctionsBuilder<'a> {
let rule = RuleId(rule);
let prio = self.termenv.rules[rule.index()].prio.unwrap_or(0);
let (lhs_root, pattern, rhs_root, expr) = lower_rule(self.typeenv, self.termenv, rule);
log::trace!(
"build:\n- rule {:?}\n- lhs_root {:?} rhs_root {:?}\n- pattern {:?}\n- expr {:?}",
self.termenv.rules[rule.index()],
lhs_root,
rhs_root,
pattern,
expr
);
if let Some(input_root_term) = lhs_root {
if let Some((pattern, expr, lhs_root)) = lower_rule(
self.typeenv,
self.termenv,
rule,
/* forward_dir = */ true,
) {
log::trace!(
"build:\n- rule {:?}\n- fwd pattern {:?}\n- fwd expr {:?}",
self.termenv.rules[rule.index()],
pattern,
expr
);
self.builders_by_input
.entry(input_root_term)
.or_insert_with(|| TermFunctionBuilder::new(input_root_term))
.entry(lhs_root)
.or_insert_with(|| TermFunctionBuilder::new(lhs_root))
.add_rule(prio, pattern.clone(), expr.clone());
}
if let Some(output_root_term) = rhs_root {
if let Some((reverse_pattern, reverse_expr)) = reverse_rule(&pattern, &expr) {
self.builders_by_output
.entry(output_root_term)
.or_insert_with(|| TermFunctionBuilder::new(output_root_term))
.add_rule(prio, reverse_pattern, reverse_expr);
}
if let Some((pattern, expr, rhs_root)) = lower_rule(
self.typeenv,
self.termenv,
rule,
/* forward_dir = */ false,
) {
log::trace!(
"build:\n- rule {:?}\n- rev pattern {:?}\n- rev expr {:?}",
self.termenv.rules[rule.index()],
pattern,
expr
);
self.builders_by_output
.entry(rhs_root)
.or_insert_with(|| TermFunctionBuilder::new(rhs_root))
.add_rule(prio, pattern, expr);
}
}
}

View File

@@ -72,7 +72,11 @@ pub enum ExprInst {
},
/// Set the Nth return value. Produces no values.
Return { index: usize, ty: TypeId, value: Value },
Return {
index: usize,
ty: TypeId,
value: Value,
},
}
impl ExprInst {
@@ -294,7 +298,11 @@ impl ExprSequence {
}
fn add_return(&mut self, ty: TypeId, value: Value) {
self.add_inst(ExprInst::Return { index: 0, ty, value });
self.add_inst(ExprInst::Return {
index: 0,
ty,
value,
});
}
fn add_multi_return(&mut self, index: usize, ty: TypeId, value: Value) {
@@ -304,40 +312,55 @@ impl ExprSequence {
/// Creates a sequence of ExprInsts to generate the given
/// expression value. Returns the value ID as well as the root
/// term ID, if any.
///
/// If `gen_final_construct` is false and the value is a
/// constructor call, this returns the arguments instead. This is
/// used when codegen'ing extractors for internal terms.
fn gen_expr(
&mut self,
typeenv: &TypeEnv,
termenv: &TermEnv,
expr: &Expr,
vars: &HashMap<VarId, (Option<TermId>, Value)>,
) -> (Option<TermId>, Value) {
gen_final_construct: bool,
) -> (Option<TermId>, Vec<Value>) {
match expr {
&Expr::ConstInt(ty, val) => (None, self.add_const_int(ty, val)),
&Expr::ConstInt(ty, val) => (None, vec![self.add_const_int(ty, val)]),
&Expr::Let(_ty, ref bindings, ref subexpr) => {
let mut vars = vars.clone();
for &(var, _var_ty, ref var_expr) in bindings {
let (var_value_term, var_value) =
self.gen_expr(typeenv, termenv, &*var_expr, &vars);
self.gen_expr(typeenv, termenv, &*var_expr, &vars, false);
let var_value = var_value[0];
vars.insert(var, (var_value_term, var_value));
}
self.gen_expr(typeenv, termenv, &*subexpr, &vars)
self.gen_expr(typeenv, termenv, &*subexpr, &vars, gen_final_construct)
}
&Expr::Var(_ty, var_id) => {
let (root_term, value) = vars.get(&var_id).cloned().unwrap();
(root_term, vec![value])
}
&Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(),
&Expr::Term(ty, term, ref arg_exprs) => {
let termdata = &termenv.terms[term.index()];
let mut arg_values_tys = vec![];
for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) {
arg_values_tys
.push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars).1, arg_ty));
arg_values_tys.push((
self.gen_expr(typeenv, termenv, &*arg_expr, &vars, false).1[0],
arg_ty,
));
}
match &termdata.kind {
&TermKind::EnumVariant { variant } => (
None,
self.add_create_variant(&arg_values_tys[..], ty, variant),
vec![self.add_create_variant(&arg_values_tys[..], ty, variant)],
),
&TermKind::Regular { .. } if !gen_final_construct => (
Some(termdata.id),
arg_values_tys.into_iter().map(|(val, _ty)| val).collect(),
),
&TermKind::Regular { .. } => (
Some(termdata.id),
self.add_construct(&arg_values_tys[..], ty, term),
vec![self.add_construct(&arg_values_tys[..], ty, term)],
),
}
}
@@ -350,12 +373,8 @@ pub fn lower_rule(
tyenv: &TypeEnv,
termenv: &TermEnv,
rule: RuleId,
) -> (
Option<TermId>,
PatternSequence,
Option<TermId>,
ExprSequence,
) {
is_forward_dir: bool,
) -> Option<(PatternSequence, ExprSequence, TermId)> {
let mut pattern_seq: PatternSequence = Default::default();
let mut expr_seq: ExprSequence = Default::default();
expr_seq.pos = termenv.rules[rule.index()].pos;
@@ -363,113 +382,83 @@ pub fn lower_rule(
// Lower the pattern, starting from the root input value.
let ruledata = &termenv.rules[rule.index()];
let mut vars = HashMap::new();
let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars);
// Lower the expression, making use of the bound variables
// from the pattern.
let (rhs_root_term, rhs_root) = expr_seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars);
// Return the root RHS value.
let output_ty = ruledata.rhs.ty();
expr_seq.add_return(output_ty, rhs_root);
if is_forward_dir {
let lhs_root_term = pattern_seq.gen_pattern(None, tyenv, termenv, &ruledata.lhs, &mut vars);
let root_term = match lhs_root_term {
Some(t) => t,
None => {
return None;
}
};
(lhs_root_term, pattern_seq, rhs_root_term, expr_seq)
// Lower the expression, making use of the bound variables
// from the pattern.
let (_, rhs_root_vals) = expr_seq.gen_expr(
tyenv,
termenv,
&ruledata.rhs,
&vars,
/* final_construct = */ true,
);
// Return the root RHS value.
let output_ty = ruledata.rhs.ty();
assert_eq!(rhs_root_vals.len(), 1);
expr_seq.add_return(output_ty, rhs_root_vals[0]);
Some((pattern_seq, expr_seq, root_term))
} else {
let arg = pattern_seq.add_arg(0, ruledata.lhs.ty());
let _ = pattern_seq.gen_pattern(Some(arg), tyenv, termenv, &ruledata.lhs, &mut vars);
let (rhs_root_term, rhs_root_vals) = expr_seq.gen_expr(
tyenv,
termenv,
&ruledata.rhs,
&vars,
/* final_construct = */ false,
);
let root_term = match rhs_root_term {
Some(t) => t,
None => {
return None;
}
};
let termdata = &termenv.terms[root_term.index()];
for (i, (val, ty)) in rhs_root_vals
.into_iter()
.zip(termdata.arg_tys.iter())
.enumerate()
{
expr_seq.add_multi_return(i, *ty, val);
}
Some((pattern_seq, expr_seq, root_term))
}
}
/// Reverse a sequence to form an extractor from a constructor.
pub fn reverse_rule(
orig_pat: &PatternSequence,
orig_expr: &ExprSequence,
) -> Option<(
PatternSequence,
ExprSequence,
)>
{
let mut pattern_seq = PatternSequence::default();
let mut expr_seq = ExprSequence::default();
expr_seq.pos = orig_expr.pos;
let mut value_map = HashMap::new();
for (id, inst) in orig_expr.insts.iter().enumerate().rev() {
let id = InstId(id);
match inst {
&ExprInst::Return { index, ty, value } => {
let new_value = pattern_seq.add_arg(index, ty);
value_map.insert(value, new_value);
}
&ExprInst::Construct { ref inputs, ty, term } => {
let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::<Vec<_>>();
let input_ty = ty;
// Input to the Extract is the output of the Construct.
let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone();
let outputs = pattern_seq.add_extract(input, input_ty, &arg_tys[..], term);
for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) {
value_map.insert(*input, output);
}
}
&ExprInst::CreateVariant { ref inputs, ty, variant } => {
let arg_tys = inputs.iter().map(|(_, ty)| *ty).collect::<Vec<_>>();
let input_ty = ty;
// Input to the MatchVariant is the output of the CreateVariant.
let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone();
let outputs = pattern_seq.add_match_variant(input, input_ty, &arg_tys[..], variant);
for (input, output) in inputs.iter().map(|(val, _)| val).zip(outputs.into_iter()) {
value_map.insert(*input, output);
}
}
&ExprInst::ConstInt { ty, val } => {
let input = value_map.get(&Value::Expr { inst: id, output: 0 })?.clone();
pattern_seq.add_match_int(input, ty, val);
}
/// Trim the final Construct and Return ops in an ExprSequence in
/// order to allow the extractor to be codegen'd.
pub fn trim_expr_for_extractor(mut expr: ExprSequence) -> ExprSequence {
let ret_inst = expr.insts.pop().unwrap();
let retval = match ret_inst {
ExprInst::Return { value, .. } => value,
_ => panic!("Last instruction is not a return"),
};
assert_eq!(
retval,
Value::Expr {
inst: InstId(expr.insts.len() - 1),
output: 0
}
);
let construct_inst = expr.insts.pop().unwrap();
let inputs = match construct_inst {
ExprInst::Construct { inputs, .. } => inputs,
_ => panic!("Returned value is not a construct call"),
};
for (i, (value, ty)) in inputs.into_iter().enumerate() {
expr.add_multi_return(i, ty, value);
}
for (id, inst) in orig_pat.insts.iter().enumerate().rev() {
let id = InstId(id);
match inst {
&PatternInst::Extract { input, input_ty, ref arg_tys, term } => {
let mut inputs = vec![];
for i in 0..arg_tys.len() {
let value = Value::Pattern { inst: id, output: i };
let new_value = value_map.get(&value)?.clone();
inputs.push((new_value, arg_tys[i]));
}
let output = expr_seq.add_construct(&inputs[..], input_ty, term);
value_map.insert(input, output);
}
&PatternInst::MatchVariant { input, input_ty, ref arg_tys, variant } => {
let mut inputs = vec![];
for i in 0..arg_tys.len() {
let value = Value::Pattern { inst: id, output: i };
let new_value = value_map.get(&value)?.clone();
inputs.push((new_value, arg_tys[i]));
}
let output = expr_seq.add_create_variant(&inputs[..], input_ty, variant);
value_map.insert(input, output);
}
&PatternInst::MatchEqual { a, b, .. } => {
if let Some(new_a) = value_map.get(&a).cloned() {
if !value_map.contains_key(&b) {
value_map.insert(b, new_a);
}
} else if let Some(new_b) = value_map.get(&b).cloned() {
if !value_map.contains_key(&a) {
value_map.insert(a, new_b);
}
}
}
&PatternInst::MatchInt { input, ty, int_val } => {
let output = expr_seq.add_const_int(ty, int_val);
value_map.insert(input, output);
}
&PatternInst::Arg { index, ty } => {
let value = Value::Pattern { inst: id, output: 0 };
let new_value = value_map.get(&value)?.clone();
expr_seq.add_multi_return(index, ty, new_value);
}
}
}
Some((pattern_seq, expr_seq))
expr
}