diff --git a/cranelift/isle/isle/src/codegen.rs b/cranelift/isle/isle/src/codegen.rs
index 85896d999d..5bc60e341a 100644
--- a/cranelift/isle/isle/src/codegen.rs
+++ b/cranelift/isle/isle/src/codegen.rs
@@ -1,12 +1,9 @@
 //! Generate Rust code from a series of Sequences.
 
-use crate::ir::{ExprInst, InstId, PatternInst, Value};
-use crate::log;
-use crate::sema::{ExternalSig, ReturnKind, TermEnv, TermId, Type, TypeEnv, TypeId, Variant};
-use crate::trie::{TrieEdge, TrieNode, TrieSymbol};
-use crate::{StableMap, StableSet};
-use std::borrow::Cow;
-use std::collections::BTreeMap;
+use crate::sema::{ExternalSig, ReturnKind, Sym, Term, TermEnv, TermId, Type, TypeEnv, TypeId};
+use crate::serialize::{Block, ControlFlow, EvalStep, MatchArm};
+use crate::trie_again::{Binding, BindingId, Constraint, RuleSet};
+use crate::StableSet;
 use std::fmt::Write;
 
 /// Options for code generation.
@@ -21,35 +18,78 @@ pub struct CodegenOptions {
 pub fn codegen(
     typeenv: &TypeEnv,
     termenv: &TermEnv,
-    tries: &BTreeMap<TermId, TrieNode>,
+    terms: &[(TermId, RuleSet)],
     options: &CodegenOptions,
 ) -> String {
-    Codegen::compile(typeenv, termenv, tries).generate_rust(options)
+    Codegen::compile(typeenv, termenv, terms).generate_rust(options)
 }
 
 #[derive(Clone, Debug)]
 struct Codegen<'a> {
     typeenv: &'a TypeEnv,
     termenv: &'a TermEnv,
-    functions_by_term: &'a BTreeMap<TermId, TrieNode>,
+    terms: &'a [(TermId, RuleSet)],
 }
 
-#[derive(Clone, Debug, Default)]
-struct BodyContext {
-    /// For each value: (is_ref, ty).
-    values: StableMap<Value, (bool, TypeId)>,
+struct BodyContext<'a, W> {
+    out: &'a mut W,
+    ruleset: &'a RuleSet,
+    indent: String,
+    is_ref: StableSet<BindingId>,
+    is_bound: StableSet<BindingId>,
+}
+
+impl<'a, W: Write> BodyContext<'a, W> {
+    fn new(out: &'a mut W, ruleset: &'a RuleSet) -> Self {
+        Self {
+            out,
+            ruleset,
+            indent: Default::default(),
+            is_ref: Default::default(),
+            is_bound: Default::default(),
+        }
+    }
+
+    fn enter_scope(&mut self) -> StableSet<BindingId> {
+        let new = self.is_bound.clone();
+        std::mem::replace(&mut self.is_bound, new)
+    }
+
+    fn begin_block(&mut self) -> std::fmt::Result {
+        self.indent.push_str("    ");
+        writeln!(self.out, " {{")
+    }
+
+    fn end_block(&mut self, scope: StableSet<BindingId>) -> std::fmt::Result {
+        self.is_bound = scope;
+        self.end_block_without_newline()?;
+        writeln!(self.out)
+    }
+
+    fn end_block_without_newline(&mut self) -> std::fmt::Result {
+        self.indent.truncate(self.indent.len() - 4);
+        write!(self.out, "{}}}", &self.indent)
+    }
+
+    fn set_ref(&mut self, binding: BindingId, is_ref: bool) {
+        if is_ref {
+            self.is_ref.insert(binding);
+        } else {
+            debug_assert!(!self.is_ref.contains(&binding));
+        }
+    }
 }
 
 impl<'a> Codegen<'a> {
     fn compile(
         typeenv: &'a TypeEnv,
         termenv: &'a TermEnv,
-        tries: &'a BTreeMap<TermId, TrieNode>,
+        terms: &'a [(TermId, RuleSet)],
     ) -> Codegen<'a> {
         Codegen {
            typeenv,
            termenv,
-            functions_by_term: tries,
+            terms,
        }
    }
 
@@ -59,7 +99,7 @@ impl<'a> Codegen<'a> {
         self.generate_header(&mut code, options);
         self.generate_ctx_trait(&mut code);
         self.generate_internal_types(&mut code);
-        self.generate_internal_term_constructors(&mut code);
+        self.generate_internal_term_constructors(&mut code).unwrap();
         code
     }
 
@@ -270,777 +310,454 @@ impl<'a> Codegen<'a> {
         }
     }
 
-    fn value_name(&self, value: &Value) -> String {
-        match value {
-            &Value::Pattern { inst, output } => format!("pattern{}_{}", inst.index(), output),
-            &Value::Expr { inst, output } => format!("expr{}_{}", inst.index(), output),
-        }
-    }
+    fn generate_internal_term_constructors(&self, code: &mut String) -> std::fmt::Result {
+        for &(termid, ref ruleset) in self.terms.iter() {
+            let root = crate::serialize::serialize(ruleset);
+            let mut ctx = BodyContext::new(code, ruleset);
 
-    fn ty_prim(&self, ty: TypeId) -> bool {
-        self.typeenv.types[ty.index()].is_prim()
-    }
-
-    fn value_binder(&self, value: &Value, is_ref: bool, ty: TypeId) -> String {
-        let prim = self.ty_prim(ty);
-        if prim || !is_ref {
-            format!("{}", self.value_name(value))
-        } else {
-            format!("ref {}", self.value_name(value))
-        }
-    }
-
-    fn value_by_ref(&self, value: &Value, ctx: &BodyContext) -> String {
-        let raw_name = self.value_name(value);
-        let &(is_ref, ty) = ctx.values.get(value).unwrap();
-        let prim = self.ty_prim(ty);
-        if is_ref || prim {
-            raw_name
-        } else {
-            format!("&{}", raw_name)
-        }
-    }
-
-    fn value_by_val(&self, value: &Value, ctx: &BodyContext) -> String {
-        let raw_name = self.value_name(value);
-        let &(is_ref, _) = ctx.values.get(value).unwrap();
-        if is_ref {
-            format!("{}.clone()", raw_name)
-        } else {
-            raw_name
-        }
-    }
-
-    fn define_val(&self, value: &Value, ctx: &mut BodyContext, is_ref: bool, ty: TypeId) {
-        let is_ref = !self.ty_prim(ty) && is_ref;
-        ctx.values.insert(value.clone(), (is_ref, ty));
-    }
-
-    fn const_int(&self, val: i128, ty: TypeId) -> String {
-        let is_bool = match &self.typeenv.types[ty.index()] {
-            &Type::Primitive(_, name, _) => &self.typeenv.syms[name.index()] == "bool",
-            _ => unreachable!(),
-        };
-        if is_bool {
-            format!("{}", val != 0)
-        } else {
-            let ty_name = self.type_name(ty, /* by_ref = */ false);
-            if ty_name == "i128" {
-                format!("{}i128", val)
-            } else {
-                format!("{}i128 as {}", val, ty_name)
-            }
-        }
-    }
-
-    fn generate_internal_term_constructors(&self, code: &mut String) {
-        for (&termid, trie) in self.functions_by_term {
             let termdata = &self.termenv.terms[termid.index()];
-
-            // Skip terms that are enum variants or that have external
-            // constructors/extractors.
-            if !termdata.has_constructor() || termdata.has_external_constructor() {
-                continue;
-            }
+            let term_name = &self.typeenv.syms[termdata.name.index()];
+            writeln!(ctx.out)?;
+            writeln!(
+                ctx.out,
+                "{}// Generated as internal constructor for term {}.",
+                &ctx.indent, term_name,
+            )?;
 
             let sig = termdata.constructor_sig(self.typeenv).unwrap();
+            writeln!(
+                ctx.out,
+                "{}pub fn {}<C: Context>(",
+                &ctx.indent, sig.func_name
+            )?;
 
-            let args = sig
-                .param_tys
-                .iter()
-                .enumerate()
-                .map(|(i, &ty)| format!("arg{}: {}", i, self.type_name(ty, true)))
-                .collect::<Vec<_>>()
-                .join(", ");
-            assert_eq!(sig.ret_tys.len(), 1);
+            writeln!(ctx.out, "{}    ctx: &mut C,", &ctx.indent)?;
+            for (i, &ty) in sig.param_tys.iter().enumerate() {
+                let (is_ref, sym) = self.ty(ty);
+                write!(ctx.out, "{}    arg{}: ", &ctx.indent, i)?;
+                write!(
+                    ctx.out,
+                    "{}{}",
+                    if is_ref { "&" } else { "" },
+                    &self.typeenv.syms[sym.index()]
+                )?;
+                if let Some(binding) = ctx.ruleset.find_binding(&Binding::Argument {
+                    index: i.try_into().unwrap(),
+                }) {
+                    ctx.set_ref(binding, is_ref);
+                }
+                writeln!(ctx.out, ",")?;
+            }
 
-            let ret = self.type_name(sig.ret_tys[0], false);
-            let ret = match sig.ret_kind {
-                ReturnKind::Iterator => format!("impl ContextIter<Context = C, Output = {}>", ret),
-                ReturnKind::Option => format!("Option<{}>", ret),
-                ReturnKind::Plain => ret,
+            write!(ctx.out, "{}) -> ", &ctx.indent)?;
+            let (_, ret) = self.ty(sig.ret_tys[0]);
+            let ret = &self.typeenv.syms[ret.index()];
+            match sig.ret_kind {
+                ReturnKind::Iterator => {
+                    write!(ctx.out, "impl ContextIter<Context = C, Output = {}>", ret)?
+                }
+                ReturnKind::Option => write!(ctx.out, "Option<{}>", ret)?,
+                ReturnKind::Plain => write!(ctx.out, "{}", ret)?,
             };
-            let term_name = &self.typeenv.syms[termdata.name.index()];
-            writeln!(
-                code,
-                "\n// Generated as internal constructor for term {}.",
-                term_name,
-            )
-            .unwrap();
-            writeln!(
-                code,
-                "pub fn {}<C: Context>(ctx: &mut C, {}) -> {} {{",
-                sig.func_name, args, ret,
-            )
-            .unwrap();
+            let scope = ctx.enter_scope();
+            ctx.begin_block()?;
 
             if sig.ret_kind == ReturnKind::Iterator {
-                writeln!(code, "let mut returns = ConstructorVec::new();").unwrap();
+                writeln!(
+                    ctx.out,
+                    "{}let mut returns = ConstructorVec::new();",
+                    &ctx.indent
+                )?;
             }
 
-            let mut body_ctx: BodyContext = Default::default();
-            let returned = self.generate_body(
-                code,
-                /* depth = */ 0,
-                trie,
-                "    ",
-                &mut body_ctx,
-                sig.ret_kind,
-            );
-            if !returned {
-                let ret_expr = match sig.ret_kind {
-                    ReturnKind::Plain => Cow::from(format!(
-                        "unreachable!(\"no rule matched for term {{}} at {{}}; should it be partial?\", {:?}, {:?})",
-                        term_name,
-                        termdata
-                            .decl_pos
-                            .pretty_print_line(&self.typeenv.filenames[..])
-                    )),
-                    ReturnKind::Option => Cow::from("None"),
-                    ReturnKind::Iterator => {
-                        Cow::from("ContextIterWrapper::from(returns.into_iter())")
+            self.emit_block(&mut ctx, &root, sig.ret_kind)?;
+
+            match (sig.ret_kind, root.steps.last()) {
+                (ReturnKind::Iterator, _) => {
+                    writeln!(
+                        ctx.out,
+                        "{}return ContextIterWrapper::from(returns.into_iter());",
+                        &ctx.indent
+                    )?;
                 }
-                };
-                write!(code, "    return {};", ret_expr).unwrap();
-            }
+                (_, Some(EvalStep { check: ControlFlow::Return { .. }, .. })) => {
+                    // If there's an outermost fallback, no need for another `return` statement.
+                }
+                (ReturnKind::Option, _) => {
+                    writeln!(ctx.out, "{}None", &ctx.indent)?
+                }
+                (ReturnKind::Plain, _) => {
+                    writeln!(ctx.out,
+                        "unreachable!(\"no rule matched for term {{}} at {{}}; should it be partial?\", {:?}, {:?})",
+                        term_name,
+                        termdata
+                            .decl_pos
+                            .pretty_print_line(&self.typeenv.filenames[..])
+                    )?
+                }
+            }
 
-            writeln!(code, "}}").unwrap();
+            ctx.end_block(scope)?;
+        }
+        Ok(())
+    }
+
+    fn ty(&self, typeid: TypeId) -> (bool, Sym) {
+        match &self.typeenv.types[typeid.index()] {
+            &Type::Primitive(_, sym, _) => (false, sym),
+            &Type::Enum { name, .. } => (true, name),
         }
     }
 
-    fn generate_expr_inst(
+    fn emit_block<W: Write>(
         &self,
-        code: &mut String,
-        id: InstId,
-        inst: &ExprInst,
-        indent: &str,
-        ctx: &mut BodyContext,
-        returns: &mut Vec<(usize, String)>,
-    ) -> bool {
-        log!("generate_expr_inst: {:?}", inst);
-        let mut new_scope = false;
-        match inst {
-            &ExprInst::ConstInt { ty, val } => {
-                let value = Value::Expr {
-                    inst: id,
-                    output: 0,
-                };
-                self.define_val(&value, ctx, /* is_ref = */ false, ty);
-                let name = self.value_name(&value);
-                let ty_name = self.type_name(ty, /* by_ref = */ false);
-                writeln!(
-                    code,
-                    "{}let {}: {} = {};",
-                    indent,
-                    name,
-                    ty_name,
-                    self.const_int(val, ty)
-                )
-                .unwrap();
+        ctx: &mut BodyContext<W>,
+        block: &Block,
+        ret_kind: ReturnKind,
+    ) -> std::fmt::Result {
+        if !matches!(ret_kind, ReturnKind::Iterator) {
+            // Loops are only allowed if we're returning an iterator.
+            assert!(!block
+                .steps
+                .iter()
+                .any(|c| matches!(c.check, ControlFlow::Loop { .. })));
+
+            // Unless we're returning an iterator, a case which returns a result must be the last
+            // case in a block.
+            if let Some(result_pos) = block
+                .steps
+                .iter()
+                .position(|c| matches!(c.check, ControlFlow::Return { .. }))
+            {
+                assert_eq!(block.steps.len() - 1, result_pos);
             }
-            &ExprInst::ConstPrim { ty, val } => {
-                let value = Value::Expr {
-                    inst: id,
-                    output: 0,
-                };
-                self.define_val(&value, ctx, /* is_ref = */ false, ty);
-                let name = self.value_name(&value);
-                let ty_name = self.type_name(ty, /* by_ref = */ false);
-                writeln!(
-                    code,
-                    "{}let {}: {} = {};",
-                    indent,
-                    name,
-                    ty_name,
-                    self.typeenv.syms[val.index()],
-                )
-                .unwrap();
+        }
+
+        for case in block.steps.iter() {
+            for &expr in case.bind_order.iter() {
+                write!(ctx.out, "{}let v{} = ", &ctx.indent, expr.index())?;
+                self.emit_expr(ctx, expr)?;
+                writeln!(ctx.out, ";")?;
+                ctx.is_bound.insert(expr);
             }
-            &ExprInst::CreateVariant {
-                ref inputs,
-                ty,
-                variant,
-            } => {
-                let variantinfo = match &self.typeenv.types[ty.index()] {
-                    &Type::Primitive(..) => panic!("CreateVariant with primitive type"),
-                    &Type::Enum { ref variants, .. } => &variants[variant.index()],
-                };
-                let mut input_fields = vec![];
-                for ((input_value, _), field) in inputs.iter().zip(variantinfo.fields.iter()) {
-                    let field_name = &self.typeenv.syms[field.name.index()];
-                    let value_expr = self.value_by_val(input_value, ctx);
-                    input_fields.push(format!("{}: {}", field_name, value_expr));
+
+            match &case.check {
+                // Use a shorthand notation if there's only one match arm.
+                ControlFlow::Match { source, arms } if arms.len() == 1 => {
+                    let arm = &arms[0];
+                    let scope = ctx.enter_scope();
+                    match arm.constraint {
+                        Constraint::ConstInt { .. } | Constraint::ConstPrim { .. } => {
+                            write!(ctx.out, "{}if ", &ctx.indent)?;
+                            self.emit_expr(ctx, *source)?;
+                            write!(ctx.out, " == ")?;
+                            self.emit_constraint(ctx, *source, arm)?;
+                        }
+                        Constraint::Variant { .. } | Constraint::Some => {
+                            write!(ctx.out, "{}if let ", &ctx.indent)?;
+                            self.emit_constraint(ctx, *source, arm)?;
+                            write!(ctx.out, " = ")?;
+                            self.emit_source(ctx, *source, arm.constraint)?;
+                        }
+                    }
+                    ctx.begin_block()?;
+                    self.emit_block(ctx, &arm.body, ret_kind)?;
+                    ctx.end_block(scope)?;
                 }
-                let output = Value::Expr {
-                    inst: id,
-                    output: 0,
-                };
-                let outputname = self.value_name(&output);
-                let full_variant_name = format!(
-                    "{}::{}",
-                    self.type_name(ty, false),
-                    self.typeenv.syms[variantinfo.name.index()]
-                );
-                if input_fields.is_empty() {
-                    writeln!(
-                        code,
-                        "{}let {} = {};",
-                        indent, outputname, full_variant_name
-                    )
-                    .unwrap();
-                } else {
-                    writeln!(
-                        code,
-                        "{}let {} = {} {{",
-                        indent, outputname, full_variant_name
-                    )
-                    .unwrap();
-                    for input_field in input_fields {
-                        writeln!(code, "{}    {},", indent, input_field).unwrap();
+                ControlFlow::Match { source, arms } => {
+                    let scope = ctx.enter_scope();
+                    write!(ctx.out, "{}match ", &ctx.indent)?;
+                    self.emit_source(ctx, *source, arms[0].constraint)?;
+                    ctx.begin_block()?;
+                    for arm in arms.iter() {
+                        let scope = ctx.enter_scope();
+                        write!(ctx.out, "{}", &ctx.indent)?;
+                        self.emit_constraint(ctx, *source, arm)?;
+                        write!(ctx.out, " =>")?;
+                        ctx.begin_block()?;
+                        self.emit_block(ctx, &arm.body, ret_kind)?;
+                        ctx.end_block(scope)?;
                     }
-                    writeln!(code, "{}}};", indent).unwrap();
+                    // Always add a catchall, because we don't do exhaustiveness checking on the
+                    // match arms.
+                    writeln!(ctx.out, "{}_ => {{}}", &ctx.indent)?;
+                    ctx.end_block(scope)?;
                 }
-                self.define_val(&output, ctx, /* is_ref = */ false, ty);
-            }
-            &ExprInst::Construct {
-                ref inputs,
-                term,
-                infallible,
-                multi,
-                ..
-            } => {
-                let mut input_exprs = vec![];
-                for (input_value, input_ty) in inputs {
-                    let value_expr = if self.typeenv.types[input_ty.index()].is_prim() {
-                        self.value_by_val(input_value, ctx)
-                    } else {
-                        self.value_by_ref(input_value, ctx)
+
+                ControlFlow::Equal { a, b, body } => {
+                    let scope = ctx.enter_scope();
+                    write!(ctx.out, "{}if ", &ctx.indent)?;
+                    self.emit_expr(ctx, *a)?;
+                    write!(ctx.out, " == ")?;
+                    self.emit_expr(ctx, *b)?;
+                    ctx.begin_block()?;
+                    self.emit_block(ctx, body, ret_kind)?;
+                    ctx.end_block(scope)?;
+                }
+
+                ControlFlow::Loop { result, body } => {
+                    let source = match &ctx.ruleset.bindings[result.index()] {
+                        Binding::Iterator { source } => source,
+                        _ => unreachable!("Loop from a non-Iterator"),
                     };
-                    input_exprs.push(value_expr);
+                    let scope = ctx.enter_scope();
+                    write!(ctx.out, "{}let mut v{} = ", &ctx.indent, source.index())?;
+                    self.emit_expr(ctx, *source)?;
+                    writeln!(ctx.out, ";")?;
+                    write!(
+                        ctx.out,
+                        "{}while let Some(v{}) = v{}.next(ctx)",
+                        &ctx.indent,
+                        result.index(),
+                        source.index()
+                    )?;
+                    ctx.is_bound.insert(*result);
+                    ctx.begin_block()?;
+                    self.emit_block(ctx, body, ret_kind)?;
+                    ctx.end_block(scope)?;
                 }
-                let output = Value::Expr {
-                    inst: id,
-                    output: 0,
-                };
-                let outputname = self.value_name(&output);
-                let termdata = &self.termenv.terms[term.index()];
-                let sig = termdata.constructor_sig(self.typeenv).unwrap();
-                assert_eq!(input_exprs.len(), sig.param_tys.len());
-
-                if !multi {
-                    let fallible_try = if infallible { "" } else { "?" };
+                &ControlFlow::Return { pos, result } => {
                     writeln!(
-                        code,
-                        "{}let {} = {}(ctx, {}){};",
-                        indent,
-                        outputname,
-                        sig.full_name,
-                        input_exprs.join(", "),
-                        fallible_try,
-                    )
-                    .unwrap();
-                } else {
-                    writeln!(
-                        code,
-                        "{}let mut iter = {}(ctx, {});",
-                        indent,
-                        sig.full_name,
-                        input_exprs.join(", "),
-                    )
-                    .unwrap();
-                    writeln!(
-                        code,
-                        "{}while let Some({}) = iter.next(ctx) {{",
-                        indent, outputname,
-                    )
-                    .unwrap();
-                    new_scope = true;
+                        ctx.out,
+                        "{}// Rule at {}.",
+                        &ctx.indent,
+                        pos.pretty_print_line(&self.typeenv.filenames)
+                    )?;
+                    write!(ctx.out, "{}", &ctx.indent)?;
+                    match ret_kind {
+                        ReturnKind::Plain => write!(ctx.out, "return ")?,
+                        ReturnKind::Option => write!(ctx.out, "return Some(")?,
+                        ReturnKind::Iterator => write!(ctx.out, "returns.push(")?,
+                    }
+                    self.emit_expr(ctx, result)?;
+                    if ctx.is_ref.contains(&result) {
+                        write!(ctx.out, ".clone()")?;
+                    }
+                    match ret_kind {
+                        ReturnKind::Plain => writeln!(ctx.out, ";")?,
+                        ReturnKind::Option | ReturnKind::Iterator => writeln!(ctx.out, ");")?,
+                    }
                 }
-                self.define_val(&output, ctx, /* is_ref = */ false, termdata.ret_ty);
-            }
-            &ExprInst::Return {
-                index, ref value, ..
-            } => {
-                let value_expr = self.value_by_val(value, ctx);
-                returns.push((index, value_expr));
             }
         }
-
-        new_scope
+        Ok(())
     }
 
-    fn match_variant_binders(
-        &self,
-        variant: &Variant,
-        arg_tys: &[TypeId],
-        id: InstId,
-        ctx: &mut BodyContext,
-    ) -> Vec<String> {
-        arg_tys
-            .iter()
-            .zip(variant.fields.iter())
-            .enumerate()
-            .map(|(i, (&ty, field))| {
-                let value = Value::Pattern {
-                    inst: id,
-                    output: i,
-                };
-                let valuename = self.value_binder(&value, /* is_ref = */ true, ty);
-                let fieldname = &self.typeenv.syms[field.name.index()];
-                self.define_val(&value, ctx, /* is_ref = */ true, field.ty);
-                format!("{}: {}", fieldname, valuename)
-            })
-            .collect::<Vec<_>>()
-    }
+    fn emit_expr<W: Write>(
+        &self,
+        ctx: &mut BodyContext<W>,
+        result: BindingId,
+    ) -> std::fmt::Result {
+        if ctx.is_bound.contains(&result) {
+            return write!(ctx.out, "v{}", result.index());
+        }
 
-    /// Returns a `bool` indicating whether this pattern inst is
-    /// infallible, and the number of scopes opened.
-    fn generate_pattern_inst(
-        &self,
-        code: &mut String,
-        id: InstId,
-        inst: &PatternInst,
-        indent: &str,
-        ctx: &mut BodyContext,
-    ) -> (bool, usize) {
-        match inst {
-            &PatternInst::Arg { index, ty } => {
-                let output = Value::Pattern {
-                    inst: id,
-                    output: 0,
-                };
-                let outputname = self.value_name(&output);
-                let is_ref = match &self.typeenv.types[ty.index()] {
-                    &Type::Primitive(..) => false,
-                    _ => true,
-                };
-                writeln!(code, "{}let {} = arg{};", indent, outputname, index).unwrap();
-                self.define_val(
-                    &Value::Pattern {
-                        inst: id,
-                        output: 0,
-                    },
-                    ctx,
-                    is_ref,
-                    ty,
-                );
-                (true, 0)
+        let binding = &ctx.ruleset.bindings[result.index()];
+
+        let mut call =
+            |term: TermId,
+             parameters: &[BindingId],
+             get_sig: fn(&Term, &TypeEnv) -> Option<ExternalSig>| {
+                let termdata = &self.termenv.terms[term.index()];
+                let sig = get_sig(termdata, self.typeenv).unwrap();
+                if let &[ret_ty] = &sig.ret_tys[..] {
+                    let (is_ref, _) = self.ty(ret_ty);
+                    if is_ref {
+                        ctx.set_ref(result, true);
+                        write!(ctx.out, "&")?;
+                    }
+                }
+                write!(ctx.out, "{}(ctx", sig.full_name)?;
+                debug_assert_eq!(parameters.len(), sig.param_tys.len());
+                for (&parameter, &arg_ty) in parameters.iter().zip(sig.param_tys.iter()) {
+                    let (is_ref, _) = self.ty(arg_ty);
+                    write!(ctx.out, ", ")?;
+                    let (before, after) = match (is_ref, ctx.is_ref.contains(&parameter)) {
+                        (false, true) => ("", ".clone()"),
+                        (true, false) => ("&", ""),
+                        _ => ("", ""),
+                    };
+                    write!(ctx.out, "{}", before)?;
+                    self.emit_expr(ctx, parameter)?;
+                    write!(ctx.out, "{}", after)?;
+                }
+                write!(ctx.out, ")")
+            };
+
+        match binding {
+            &Binding::ConstInt { val, ty } => self.emit_int(ctx, val, ty),
+            Binding::ConstPrim { val } => write!(ctx.out, "{}", &self.typeenv.syms[val.index()]),
+            Binding::Argument { index } => write!(ctx.out, "arg{}", index.index()),
+            Binding::Extractor { term, parameter } => {
+                call(*term, std::slice::from_ref(parameter), Term::extractor_sig)
             }
-            &PatternInst::MatchEqual { ref a, ref b, .. } => {
-                let a = self.value_by_ref(a, ctx);
-                let b = self.value_by_ref(b, ctx);
-                writeln!(code, "{}if {} == {} {{", indent, a, b).unwrap();
-                (false, 1)
-            }
-            &PatternInst::MatchInt {
-                ref input,
-                int_val,
+            Binding::Constructor {
+                term, parameters, ..
+            } => call(*term, &parameters[..], Term::constructor_sig),
+
+            Binding::MakeVariant {
                 ty,
-                ..
-            } => {
-                let int_val = self.const_int(int_val, ty);
-                let input = self.value_by_val(input, ctx);
-                writeln!(code, "{}if {} == {} {{", indent, input, int_val).unwrap();
-                (false, 1)
-            }
-            &PatternInst::MatchPrim { ref input, val, .. } => {
-                let input = self.value_by_val(input, ctx);
-                let sym = &self.typeenv.syms[val.index()];
-                writeln!(code, "{}if {} == {} {{", indent, input, sym).unwrap();
-                (false, 1)
-            }
-            &PatternInst::MatchVariant {
-                ref input,
-                input_ty,
                 variant,
-                ref arg_tys,
+                fields,
             } => {
-                let input = self.value_by_ref(input, ctx);
-                let variants = match &self.typeenv.types[input_ty.index()] {
-                    &Type::Primitive(..) => panic!("primitive type input to MatchVariant"),
-                    &Type::Enum { ref variants, .. } => variants,
+                let (name, variants) = match &self.typeenv.types[ty.index()] {
+                    Type::Enum { name, variants, .. } => (name, variants),
+                    _ => unreachable!("MakeVariant with primitive type"),
                 };
-                let ty_name = self.type_name(input_ty, /* is_ref = */ true);
                 let variant = &variants[variant.index()];
-                let variantname = &self.typeenv.syms[variant.name.index()];
-                let args = self.match_variant_binders(variant, &arg_tys[..], id, ctx);
-                let args = if args.is_empty() {
-                    "".to_string()
-                } else {
-                    format!("{{ {} }}", args.join(", "))
-                };
-                writeln!(
-                    code,
-                    "{}if let {}::{} {} = {} {{",
-                    indent, ty_name, variantname, args, input
-                )
-                .unwrap();
-                (false, 1)
-            }
-            &PatternInst::Extract {
-                ref inputs,
-                ref output_tys,
-                term,
-                infallible,
-                multi,
-                ..
-            } => {
-                let termdata = &self.termenv.terms[term.index()];
-                let sig = termdata.extractor_sig(self.typeenv).unwrap();
-
-                let input_values = inputs
-                    .iter()
-                    .map(|input| self.value_by_ref(input, ctx))
-                    .collect::<Vec<_>>();
-                let output_binders = output_tys
-                    .iter()
-                    .enumerate()
-                    .map(|(i, &ty)| {
-                        let output_val = Value::Pattern {
-                            inst: id,
-                            output: i,
-                        };
-                        self.define_val(&output_val, ctx, /* is_ref = */ false, ty);
-                        self.value_binder(&output_val, /* is_ref = */ false, ty)
-                    })
-                    .collect::<Vec<_>>();
-
-                let bind_pattern = format!(
-                    "{open_paren}{vars}{close_paren}",
-                    open_paren = if output_binders.len() == 1 { "" } else { "(" },
-                    vars = output_binders.join(", "),
-                    close_paren = if output_binders.len() == 1 { "" } else { ")" }
-                );
-                let etor_call = format!(
-                    "{name}(ctx, {args})",
-                    name = sig.full_name,
-                    args = input_values.join(", ")
-                );
-
-                if multi {
-                    writeln!(code, "{indent}let mut iter = {etor_call};").unwrap();
-                    writeln!(
-                        code,
-                        "{indent}while let Some({bind_pattern}) = iter.next(ctx) {{",
-                    )
-                    .unwrap();
-                    (false, 1)
-                } else if infallible {
-                    writeln!(code, "{indent}let {bind_pattern} = {etor_call};").unwrap();
-                    (true, 0)
-                } else {
-                    writeln!(code, "{indent}if let Some({bind_pattern}) = {etor_call} {{").unwrap();
-                    (false, 1)
+                write!(
+                    ctx.out,
+                    "{}::{}",
+                    &self.typeenv.syms[name.index()],
+                    &self.typeenv.syms[variant.name.index()]
+                )?;
+                if !fields.is_empty() {
+                    ctx.begin_block()?;
+                    for (field, value) in variant.fields.iter().zip(fields.iter()) {
+                        write!(
+                            ctx.out,
+                            "{}{}: ",
+                            &ctx.indent,
+                            &self.typeenv.syms[field.name.index()],
+                        )?;
+                        self.emit_expr(ctx, *value)?;
+                        if ctx.is_ref.contains(&value) {
+                            write!(ctx.out, ".clone()")?;
+                        }
+                        writeln!(ctx.out, ",")?;
+                    }
+                    ctx.end_block_without_newline()?;
                 }
+                Ok(())
             }
-            &PatternInst::Expr {
-                ref seq, output_ty, ..
-            } if seq.is_const_int().is_some() => {
-                let (ty, val) = seq.is_const_int().unwrap();
-                assert_eq!(ty, output_ty);
-                let output = Value::Pattern {
-                    inst: id,
-                    output: 0,
-                };
-                writeln!(
-                    code,
-                    "{}let {} = {};",
-                    indent,
-                    self.value_name(&output),
-                    self.const_int(val, ty),
-                )
-                .unwrap();
-                self.define_val(&output, ctx, /* is_ref = */ false, ty);
-                (true, 0)
+            &Binding::MatchSome { source } => {
+                self.emit_expr(ctx, source)?;
+                write!(ctx.out, "?")
+            }
+            &Binding::MatchTuple { source, field } => {
+                self.emit_expr(ctx, source)?;
+                write!(ctx.out, ".{}", field.index())
             }
-            &PatternInst::Expr {
-                ref seq, output_ty, ..
-            } => {
-                let closure_name = format!("closure{}", id.index());
-                writeln!(code, "{}let mut {} = || {{", indent, closure_name).unwrap();
-                let subindent = format!("{}    ", indent);
-                let mut subctx = ctx.clone();
-                let mut returns = vec![];
-                for (id, inst) in seq.insts.iter().enumerate() {
-                    let id = InstId(id);
-                    let new_scope = self.generate_expr_inst(
-                        code,
-                        id,
-                        inst,
-                        &subindent,
-                        &mut subctx,
-                        &mut returns,
-                    );
-                    assert!(!new_scope);
-                }
-                assert_eq!(returns.len(), 1);
-                writeln!(code, "{}return Some({});", subindent, returns[0].1).unwrap();
-                writeln!(code, "{}}};", indent).unwrap();
-                let output = Value::Pattern {
-                    inst: id,
-                    output: 0,
-                };
-                writeln!(
-                    code,
-                    "{}if let Some({}) = {}() {{",
-                    indent,
-                    self.value_binder(&output, /* is_ref = */ false, output_ty),
-                    closure_name
-                )
-                .unwrap();
-                self.define_val(&output, ctx, /* is_ref = */ false, output_ty);
-
-                (false, 1)
+            // These are not supposed to happen. If they do, make the generated code fail to compile
+            // so this is easier to debug than if we panic during codegen.
+            &Binding::MatchVariant { source, field, .. } => {
+                self.emit_expr(ctx, source)?;
+                write!(ctx.out, ".{} /*FIXME*/", field.index())
+            }
+            &Binding::Iterator { source } => {
+                self.emit_expr(ctx, source)?;
+                write!(ctx.out, ".next() /*FIXME*/")
             }
         }
     }
 
-    fn generate_body(
+    fn emit_source<W: Write>(
         &self,
-        code: &mut String,
-        depth: usize,
-        trie: &TrieNode,
-        indent: &str,
-        ctx: &mut BodyContext,
-        ret_kind: ReturnKind,
-    ) -> bool {
-        log!("generate_body:\n{}", trie.pretty());
-        let mut returned = false;
-        match trie {
-            &TrieNode::Empty => {}
-
-            &TrieNode::Leaf { ref output, .. } => {
-                writeln!(
-                    code,
-                    "{}// Rule at {}.",
-                    indent,
-                    output.pos.pretty_print_line(&self.typeenv.filenames[..])
-                )
-                .unwrap();
-
-                // If this is a leaf node, generate the ExprSequence and return.
-                let mut returns = vec![];
-                let mut scopes = 0;
-                let mut indent = indent.to_string();
-                let orig_indent = indent.clone();
-                for (id, inst) in output.insts.iter().enumerate() {
-                    let id = InstId(id);
-                    let new_scope =
-                        self.generate_expr_inst(code, id, inst, &indent[..], ctx, &mut returns);
-                    if new_scope {
-                        scopes += 1;
-                        indent.push_str("    ");
-                    }
-                }
-
-                assert_eq!(returns.len(), 1);
-                let (before, after) = match ret_kind {
-                    ReturnKind::Plain => ("return ", ""),
-                    ReturnKind::Option => ("return Some(", ")"),
-                    ReturnKind::Iterator => ("returns.push(", ")"),
-                };
-                writeln!(code, "{}{}{}{};", indent, before, returns[0].1, after).unwrap();
-
-                for _ in 0..scopes {
-                    writeln!(code, "{}}}", orig_indent).unwrap();
-                }
-
-                returned = ret_kind != ReturnKind::Iterator;
-            }
-
-            &TrieNode::Decision { ref edges } => {
-                // If this is a decision node, generate each match op
-                // in turn (in priority order). Gather together
-                // adjacent MatchVariant ops with the same input and
-                // disjoint variants in order to create a `match`
-                // rather than a chain of if-lets.
-
-                let mut i = 0;
-                while i < edges.len() {
-                    // Gather adjacent match variants so that we can turn these
-                    // into a `match` rather than a sequence of `if let`s.
-                    let mut last = i;
-                    let mut adjacent_variants = StableSet::new();
-                    let mut adjacent_variant_input = None;
-                    log!(
-                        "edge: prio = {:?}, symbol = {:?}",
-                        edges[i].prio,
-                        edges[i].symbol
-                    );
-                    while last < edges.len() {
-                        match &edges[last].symbol {
-                            &TrieSymbol::Match {
-                                op: PatternInst::MatchVariant { input, variant, .. },
-                            } => {
-                                if adjacent_variant_input.is_none() {
-                                    adjacent_variant_input = Some(input);
-                                }
-                                if adjacent_variant_input == Some(input)
-                                    && !adjacent_variants.contains(&variant)
-                                {
-                                    adjacent_variants.insert(variant);
-                                    last += 1;
-                                } else {
-                                    break;
-                                }
-                            }
-                            _ => {
-                                break;
-                            }
-                        }
-                    }
-
-                    // Now `edges[i..last]` is a run of adjacent `MatchVariants`
-                    // (possibly an empty one). Only use a `match` form if there
-                    // are at least two adjacent options.
-                    if last - i > 1 {
-                        self.generate_body_matches(
-                            code,
-                            depth,
-                            &edges[i..last],
-                            indent,
-                            ctx,
-                            ret_kind,
-                        );
-                        i = last;
-                        continue;
-                    } else {
-                        let &TrieEdge {
-                            ref symbol,
-                            ref node,
-                            ..
-                        } = &edges[i];
-                        i += 1;
-
-                        match symbol {
-                            &TrieSymbol::EndOfMatch => {
-                                returned = self.generate_body(
-                                    code,
-                                    depth + 1,
-                                    node,
-                                    indent,
-                                    ctx,
-                                    ret_kind,
-                                );
-                            }
-                            &TrieSymbol::Match { ref op } => {
-                                let id = InstId(depth);
-                                let (infallible, new_scopes) =
-                                    self.generate_pattern_inst(code, id, op, indent, ctx);
-                                let mut subindent = indent.to_string();
-                                for _ in 0..new_scopes {
-                                    subindent.push_str("    ");
-                                }
-                                let sub_returned = self.generate_body(
-                                    code,
-                                    depth + 1,
-                                    node,
-                                    &subindent[..],
-                                    ctx,
-                                    ret_kind,
-                                );
-                                for _ in 0..new_scopes {
-                                    writeln!(code, "{}}}", indent).unwrap();
-                                }
-                                if infallible && sub_returned {
-                                    returned = true;
-                                    break;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        returned
+        ctx: &mut BodyContext<W>,
+        source: BindingId,
+        constraint: Constraint,
+    ) -> std::fmt::Result {
+        if let Constraint::Variant { .. } = constraint {
+            if !ctx.is_ref.contains(&source) {
+                write!(ctx.out, "&")?;
+            }
+        }
+        self.emit_expr(ctx, source)
     }
 
-    fn generate_body_matches(
+    fn emit_constraint<W: Write>(
         &self,
-        code: &mut String,
-        depth: usize,
-        edges: &[TrieEdge],
-        indent: &str,
-        ctx: &mut BodyContext,
-        ret_kind: ReturnKind,
-    ) {
-        let (input, input_ty) = match &edges[0].symbol {
-            &TrieSymbol::Match {
-                op:
-                    PatternInst::MatchVariant {
-                        input, input_ty, ..
-                    },
-            } => (input, input_ty),
-            _ => unreachable!(),
-        };
-        let (input_ty_sym, variants) = match &self.typeenv.types[input_ty.index()] {
-            &Type::Enum {
-                ref name,
-                ref variants,
-                ..
-            } => (name, variants),
-            _ => unreachable!(),
-        };
-        let input_ty_name = &self.typeenv.syms[input_ty_sym.index()];
-
-        // Emit the `match`.
-        writeln!(
-            code,
-            "{}match {} {{",
-            indent,
-            self.value_by_ref(&input, ctx)
-        )
-        .unwrap();
-
-        // Emit each case.
-        for &TrieEdge {
-            ref symbol,
-            ref node,
-            ..
-        } in edges
-        {
-            let id = InstId(depth);
-            let (variant, arg_tys) = match symbol {
-                &TrieSymbol::Match {
-                    op:
-                        PatternInst::MatchVariant {
-                            variant,
-                            ref arg_tys,
-                            ..
-                        },
-                } => (variant, arg_tys),
-                _ => unreachable!(),
-            };
-
-            let variantinfo = &variants[variant.index()];
-            let variantname = &self.typeenv.syms[variantinfo.name.index()];
-            let fields = self.match_variant_binders(variantinfo, arg_tys, id, ctx);
-            let fields = if fields.is_empty() {
-                "".to_string()
-            } else {
-                format!("{{ {} }}", fields.join(", "))
-            };
-            writeln!(
-                code,
-                "{}    &{}::{} {} => {{",
-                indent, input_ty_name, variantname, fields,
-            )
-            .unwrap();
-            let subindent = format!("{}    ", indent);
-            self.generate_body(code, depth + 1, node, &subindent, ctx, ret_kind);
-            writeln!(code, "{}    }}", indent).unwrap();
+        ctx: &mut BodyContext<W>,
+        source: BindingId,
+        arm: &MatchArm,
+    ) -> std::fmt::Result {
+        let MatchArm {
+            constraint,
+            bindings,
+            ..
+        } = arm;
+        for binding in bindings.iter() {
+            if let &Some(binding) = binding {
+                ctx.is_bound.insert(binding);
+            }
         }
+        match *constraint {
+            Constraint::ConstInt { val, ty } => self.emit_int(ctx, val, ty),
+            Constraint::ConstPrim { val } => {
+                write!(ctx.out, "{}", &self.typeenv.syms[val.index()])
+            }
+            Constraint::Variant { ty, variant, .. } => {
+                let (name, variants) = match &self.typeenv.types[ty.index()] {
+                    Type::Enum { name, variants, .. } => (name, variants),
+                    _ => unreachable!("Variant constraint on primitive type"),
+                };
+                let variant = &variants[variant.index()];
+                write!(
+                    ctx.out,
+                    "&{}::{}",
+                    &self.typeenv.syms[name.index()],
+                    &self.typeenv.syms[variant.name.index()]
+                )?;
+                if !bindings.is_empty() {
+                    ctx.begin_block()?;
+                    let mut skipped_some = false;
+                    for (&binding, field) in bindings.iter().zip(variant.fields.iter()) {
+                        if let Some(binding) = binding {
+                            write!(
+                                ctx.out,
+                                "{}{}: ",
+                                &ctx.indent,
+                                &self.typeenv.syms[field.name.index()]
+                            )?;
+                            let (is_ref, _) = self.ty(field.ty);
+                            if is_ref {
+                                ctx.set_ref(binding, true);
+                                write!(ctx.out, "ref ")?;
+                            }
+                            writeln!(ctx.out, "v{},", binding.index())?;
+                        } else {
+                            skipped_some = true;
+                        }
+                    }
+                    if skipped_some {
+                        writeln!(ctx.out, "{}..", &ctx.indent)?;
+                    }
+                    ctx.end_block_without_newline()?;
+                }
+                Ok(())
+            }
+            Constraint::Some => {
+                write!(ctx.out, "Some(")?;
+                if let Some(binding) = bindings[0] {
+                    ctx.set_ref(binding, ctx.is_ref.contains(&source));
+                    write!(ctx.out, "v{}", binding.index())?;
+                } else {
+                    write!(ctx.out, "_")?;
+                }
+                write!(ctx.out, ")")
+            }
+        }
+    }
 
-        // Always add a catchall, because we don't do exhaustiveness
-        // checking on the MatchVariants.
-        writeln!(code, "{}    _ => {{}}", indent).unwrap();
-
-        writeln!(code, "{}}}", indent).unwrap();
+    fn emit_int<W: Write>(
+        &self,
+        ctx: &mut BodyContext<W>,
+        val: i128,
+        ty: TypeId,
+    ) -> Result<(), std::fmt::Error> {
+        // For the kinds of situations where we use ISLE, magic numbers are
+        // much more likely to be understandable if they're in hex rather than
+        // decimal.
+        // TODO: use better type info (https://github.com/bytecodealliance/wasmtime/issues/5431)
+        if val < 0
+            && self.typeenv.types[ty.index()]
+                .name(self.typeenv)
+                .starts_with('i')
+        {
+            write!(ctx.out, "-{:#X}", -val)
+        } else {
+            write!(ctx.out, "{:#X}", val)
+        }
     }
 }
diff --git a/cranelift/isle/isle/src/compile.rs b/cranelift/isle/isle/src/compile.rs
index 3b6df764fe..ab340e302d 100644
--- a/cranelift/isle/isle/src/compile.rs
+++ b/cranelift/isle/isle/src/compile.rs
@@ -3,15 +3,14 @@
 use std::path::Path;
 
 use crate::error::Errors;
-use crate::{ast, codegen, sema, trie};
+use crate::{ast, codegen, sema};
 
 /// Compile the given AST definitions into Rust source code.
 pub fn compile(defs: &ast::Defs, options: &codegen::CodegenOptions) -> Result<String, Errors> {
     let mut typeenv = sema::TypeEnv::from_ast(defs)?;
     let termenv = sema::TermEnv::from_ast(&mut typeenv, defs)?;
-    crate::overlap::check(&mut typeenv, &termenv)?;
-    let tries = trie::build_tries(&termenv);
-    Ok(codegen::codegen(&typeenv, &termenv, &tries, options))
+    let terms = crate::overlap::check(&typeenv, &termenv)?;
+    Ok(codegen::codegen(&typeenv, &termenv, &terms, options))
 }
 
 /// Compile the given files into Rust source code.
diff --git a/cranelift/isle/isle/src/ir.rs b/cranelift/isle/isle/src/ir.rs
deleted file mode 100644
index 32f3f4c4e7..0000000000
--- a/cranelift/isle/isle/src/ir.rs
+++ /dev/null
@@ -1,425 +0,0 @@
-//! Lowered matching IR.
-
-use crate::lexer::Pos;
-use crate::log;
-use crate::sema::*;
-
-declare_id!(
-    /// The id of an instruction in a `PatternSequence`.
-    InstId
-);
-
-/// A value produced by a LHS or RHS instruction.
-#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
-pub enum Value {
-    /// A value produced by an instruction in the Pattern (LHS).
-    Pattern {
-        /// The instruction that produces this value.
-        inst: InstId,
-        /// This value is the `output`th value produced by this pattern.
-        output: usize,
-    },
-    /// A value produced by an instruction in the Expr (RHS).
-    Expr {
-        /// The instruction that produces this value.
-        inst: InstId,
-        /// This value is the `output`th value produced by this expression.
-        output: usize,
-    },
-}
-
-/// A single Pattern instruction.
-#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub enum PatternInst {
-    /// Match a value as equal to another value. Produces no values.
-    MatchEqual {
-        /// The first value.
-        a: Value,
-        /// The second value.
-        b: Value,
-        /// The type of the values.
-        ty: TypeId,
-    },
-
-    /// Try matching the given value as the given integer. Produces no values.
-    MatchInt {
-        /// The value to match on.
-        input: Value,
-        /// The value's type.
-        ty: TypeId,
-        /// The integer to match against the value.
-        int_val: i128,
-    },
-
-    /// Try matching the given value as the given constant. Produces no values.
-    MatchPrim {
-        /// The value to match on.
-        input: Value,
-        /// The type of the value.
-        ty: TypeId,
-        /// The primitive to match against the value.
-        val: Sym,
-    },
-
-    /// Try matching the given value as the given variant, producing `|arg_tys|`
-    /// values as output.
-    MatchVariant {
-        /// The value to match on.
-        input: Value,
-        /// The type of the value.
-        input_ty: TypeId,
-        /// The types of values produced upon a successful match.
-        arg_tys: Vec<TypeId>,
-        /// The value type's variant that we are matching against.
-        variant: VariantId,
-    },
-
-    /// Evaluate an expression and provide the given value as the result of this
-    /// match instruction. The expression has access to the pattern-values up to
-    /// this point in the sequence.
-    Expr {
-        /// The expression to evaluate.
-        seq: ExprSequence,
-        /// The value produced by the expression.
-        output: Value,
-        /// The type of the output value.
-        output_ty: TypeId,
-    },
-
-    // NB: this has to come second-to-last, because it might be infallible, for
-    // the same reasons that `Arg` has to be last.
-    //
-    /// Invoke an extractor, taking the given values as input (the first is the
-    /// value to extract, the other are the `Input`-polarity extractor args) and
-    /// producing an output value for each `Output`-polarity extractor arg.
-    Extract {
-        /// Whether this extraction is infallible or not. `false`
-        /// comes before `true`, so fallible nodes come first.
-        infallible: bool,
-        /// The value to extract, followed by polarity extractor args.
-        inputs: Vec<Value>,
-        /// The types of the inputs.
-        input_tys: Vec<TypeId>,
-        /// The types of the output values produced upon a successful match.
-        output_tys: Vec<TypeId>,
-        /// This extractor's term.
-        term: TermId,
-        /// Is this a multi-extractor?
-        multi: bool,
-    },
-
-    // NB: This has to go last, since it is infallible, so that when we sort
-    // edges in the trie, we visit infallible edges after first having tried the
-    // more-specific fallible options.
-    //
-    /// Get the Nth input argument, which corresponds to the Nth field
-    /// of the root term.
-    Arg {
-        /// The index of the argument to get.
-        index: usize,
-        /// The type of the argument.
-        ty: TypeId,
-    },
-}
-
-/// A single Expr instruction.
-#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub enum ExprInst {
-    /// Produce a constant integer.
-    ConstInt {
-        /// This integer type.
-        ty: TypeId,
-        /// The integer value. Must fit within the type.
-        val: i128,
-    },
-
-    /// Produce a constant extern value.
-    ConstPrim {
-        /// The primitive type.
-        ty: TypeId,
-        /// The primitive value.
-        val: Sym,
-    },
-
-    /// Create a variant.
-    CreateVariant {
-        /// The input arguments that will make up this variant's fields.
-        ///
-        /// These must be in the same order as the variant's fields.
-        inputs: Vec<(Value, TypeId)>,
-        /// The enum type.
-        ty: TypeId,
-        /// The variant within the enum that we are constructing.
-        variant: VariantId,
-    },
-
-    /// Invoke a constructor.
-    Construct {
-        /// The arguments to the constructor.
-        inputs: Vec<(Value, TypeId)>,
-        /// The type of the constructor.
-        ty: TypeId,
-        /// The constructor term.
-        term: TermId,
-        /// Whether this constructor is infallible or not.
-        infallible: bool,
-        /// Is this a multi-constructor?
-        multi: bool,
-    },
-
-    /// Set the Nth return value. Produces no values.
-    Return {
-        /// The index of the return value to set.
-        index: usize,
-        /// The type of the return value.
-        ty: TypeId,
-        /// The value to set as the `index`th return value.
-        value: Value,
-    },
-}
-
-impl ExprInst {
-    /// Invoke `f` for each value in this expression.
-    pub fn visit_values<F: FnMut(Value)>(&self, mut f: F) {
-        match self {
-            &ExprInst::ConstInt { .. } => {}
-            &ExprInst::ConstPrim { .. } => {}
-            &ExprInst::Construct { ref inputs, .. }
-            | &ExprInst::CreateVariant { ref inputs, .. } => {
-                for (input, _ty) in inputs {
-                    f(*input);
-                }
-            }
-            &ExprInst::Return { value, .. } => {
-                f(value);
-            }
-        }
-    }
-}
-
-/// A linear sequence of instructions that match on and destructure an
-/// argument. A pattern is fallible (may not match). If it does not fail, its
-/// result consists of the values produced by the `PatternInst`s, which may be
-/// used by a subsequent `Expr`.
-#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
-pub struct PatternSequence {
-    /// Instruction sequence for pattern.
-    ///
-    /// `InstId` indexes into this sequence for `Value::Pattern` values.
-    pub insts: Vec<PatternInst>,
-}
-
-/// A linear sequence of instructions that produce a new value from the
-/// right-hand side of a rule, given bindings that come from a `Pattern` derived
-/// from the left-hand side.
-#[derive(Clone, Debug, PartialEq, Eq, Hash, Default, PartialOrd, Ord)]
-pub struct ExprSequence {
-    /// Instruction sequence for expression.
-    ///
-    /// `InstId` indexes into this sequence for `Value::Expr` values.
-    pub insts: Vec<ExprInst>,
-    /// Position at which the rule producing this sequence was located.
-    pub pos: Pos,
-}
-
-impl ExprSequence {
-    /// Is this expression sequence producing a constant integer?
-    ///
-    /// If so, return the integer type and the constant.
-    pub fn is_const_int(&self) -> Option<(TypeId, i128)> {
-        if self.insts.len() == 2 && matches!(&self.insts[1], &ExprInst::Return { .. }) {
-            match &self.insts[0] {
-                &ExprInst::ConstInt { ty, val } => Some((ty, val)),
-                _ => None,
-            }
-        } else {
-            None
-        }
-    }
-}
-
-impl PatternSequence {
-    fn add_inst(&mut self, inst: PatternInst) -> InstId {
-        let id = InstId(self.insts.len());
-        self.insts.push(inst);
-        id
-    }
-}
-
-/// Used as an intermediate representation of expressions in the [RuleVisitor] implementation for
-/// [PatternSequence].
-pub struct ReturnExpr {
-    seq: ExprSequence,
-    output: Value,
-    output_ty: TypeId,
-}
-
-impl RuleVisitor for PatternSequence {
-    type PatternVisitor = Self;
-    type ExprVisitor = ExprSequence;
-    type Expr = ReturnExpr;
-
-    fn add_arg(&mut self, index: usize, ty: TypeId) -> Value {
-        let inst = self.add_inst(PatternInst::Arg { index, ty });
-        Value::Pattern { inst, output: 0 }
-    }
-
-    fn add_pattern<F: FnOnce(&mut Self)>(&mut self, visitor: F) {
-        visitor(self)
-    }
-
-    fn add_expr<F>(&mut self, visitor: F) -> ReturnExpr
-    where
-        F: FnOnce(&mut ExprSequence) -> VisitedExpr<ExprSequence>,
-    {
-        let mut expr = ExprSequence::default();
-        let VisitedExpr { ty, value } = visitor(&mut expr);
-        let index = 0;
-        expr.add_inst(ExprInst::Return { index, ty, value });
-        ReturnExpr {
-            seq: expr,
-            output: value,
-            output_ty: ty,
-        }
-    }
-
-    fn expr_as_pattern(&mut self, expr: ReturnExpr) -> Value {
-        let inst = self.add_inst(PatternInst::Expr {
-            seq: expr.seq,
-            output: expr.output,
-            output_ty: expr.output_ty,
-        });
-
-        // Create values for all outputs.
-        Value::Pattern { inst, output: 0 }
-    }
-
-    fn pattern_as_expr(&mut self, pattern: Value) -> Value {
-        pattern
-    }
-}
-
-impl PatternVisitor for PatternSequence {
-    type PatternId = Value;
-
-    fn add_match_equal(&mut self, a: Value, b: Value, ty: TypeId) {
-        self.add_inst(PatternInst::MatchEqual { a, b, ty });
-    }
-
-    fn add_match_int(&mut self, input: Value, ty: TypeId, int_val: i128) {
-        self.add_inst(PatternInst::MatchInt { input, ty, int_val });
-    }
-
-    fn add_match_prim(&mut self, input: Value, ty: TypeId, val: Sym) {
-        self.add_inst(PatternInst::MatchPrim { input, ty, val });
-    }
-
-    fn add_match_variant(
-        &mut self,
-        input: Value,
-        input_ty: TypeId,
-        arg_tys: &[TypeId],
-        variant: VariantId,
-    ) -> Vec<Value> {
-        let outputs = arg_tys.len();
-        let arg_tys = arg_tys.into();
-        let inst = self.add_inst(PatternInst::MatchVariant {
-            input,
-            input_ty,
-            arg_tys,
-            variant,
-        });
-        (0..outputs)
-            .map(|output| Value::Pattern { inst, output })
-            .collect()
-    }
-
-    fn add_extract(
-        &mut self,
-        input: Value,
-        input_ty: TypeId,
-        output_tys: Vec<TypeId>,
-        term: TermId,
-        infallible: bool,
-        multi: bool,
-    ) -> Vec<Value> {
-        let outputs = output_tys.len();
-        let inst = self.add_inst(PatternInst::Extract {
-            inputs: vec![input],
-            input_tys: vec![input_ty],
-            output_tys,
-            term,
-            infallible,
-            multi,
-        });
-        (0..outputs)
-            .map(|output| Value::Pattern { inst, output })
-            .collect()
-    }
-}
-
-impl ExprSequence {
-    fn add_inst(&mut self, inst: ExprInst) -> InstId {
-        let id = InstId(self.insts.len());
-        self.insts.push(inst);
-        id
-    }
-}
-
-impl ExprVisitor for ExprSequence {
-    type ExprId = Value;
-
-    fn add_const_int(&mut self, ty: TypeId, val: i128) -> Value {
-        let inst = self.add_inst(ExprInst::ConstInt { ty, val });
-        Value::Expr { inst, output: 0 }
-    }
-
-    fn add_const_prim(&mut self, ty: TypeId, val: Sym) -> Value {
-        let inst = self.add_inst(ExprInst::ConstPrim { ty, val });
-        Value::Expr { inst, output: 0 }
-    }
-
-    fn add_create_variant(
-        &mut self,
-        inputs: Vec<(Value, TypeId)>,
-        ty: TypeId,
-        variant: VariantId,
-    ) -> Value {
-        let inst = self.add_inst(ExprInst::CreateVariant {
-            inputs,
-            ty,
-            variant,
-        });
-        Value::Expr { inst, output: 0 }
-    }
-
-    fn add_construct(
-        &mut self,
-        inputs: Vec<(Value, TypeId)>,
-        ty: TypeId,
-        term: TermId,
-        _pure: bool,
-        infallible: bool,
-        multi: bool,
-    ) -> Value {
-        let inst = self.add_inst(ExprInst::Construct {
-            inputs,
-            ty,
-            term,
-            infallible,
-            multi,
-        });
-        Value::Expr { inst, output: 0 }
-    }
-}
-
-/// Build a sequence from a rule.
-pub fn lower_rule(termenv: &TermEnv, rule: RuleId) -> (PatternSequence, ExprSequence) {
-    let ruledata = &termenv.rules[rule.index()];
-    log!("lower_rule: ruledata {:?}", ruledata);
-
-    let mut pattern_seq = PatternSequence::default();
-    let mut expr_seq = ruledata.visit(&mut pattern_seq, termenv).seq;
-    expr_seq.pos = ruledata.pos;
-    (pattern_seq, expr_seq)
-}
diff --git a/cranelift/isle/isle/src/lib.rs b/cranelift/isle/isle/src/lib.rs
index 5a4614c73a..a01d5a8da0 100644
--- a/cranelift/isle/isle/src/lib.rs
+++ b/cranelift/isle/isle/src/lib.rs
@@ -97,7 +97,7 @@ impl<K: Hash + Eq, V> Index<&K> for StableMap<K, V> {
 /// Stores disjoint sets and provides efficient operations to merge two sets, and to find a
 /// representative member of a set given any member of that set. In this implementation, sets always
 /// have at least two members, and can only be formed by the `merge` operation.
-#[derive(Debug, Default)]
+#[derive(Clone, Debug, Default)]
 pub struct DisjointSets<T> {
     parent: HashMap<T, (T, u8)>,
 }
@@ -182,6 +182,26 @@ impl<T: Copy + Debug + Eq + Hash + Ord> DisjointSets<T> {
         }
     }
 
+    /// Returns whether the given items have both been merged into the same set. If either is not
+    /// part of any set, returns `false`.
+    ///
+    /// ```
+    /// let mut sets = cranelift_isle::DisjointSets::default();
+    /// sets.merge(1, 2);
+    /// sets.merge(1, 3);
+    /// sets.merge(2, 4);
+    /// sets.merge(5, 6);
+    /// assert!(sets.in_same_set(2, 3));
+    /// assert!(sets.in_same_set(1, 4));
+    /// assert!(sets.in_same_set(3, 4));
+    /// assert!(!sets.in_same_set(4, 5));
+    /// ```
+    pub fn in_same_set(&self, x: T, y: T) -> bool {
+        let x = self.find(x);
+        let y = self.find(y);
+        x.zip(y).filter(|(x, y)| x == y).is_some()
+    }
+
     /// Remove the set containing the given item, and return all members of that set. The set is
     /// returned in sorted order. This method takes time linear in the total size of all sets.
     ///
@@ -242,11 +262,10 @@ pub mod ast;
 pub mod codegen;
 pub mod compile;
 pub mod error;
-pub mod ir;
 pub mod lexer;
 mod log;
 pub mod overlap;
 pub mod parser;
 pub mod sema;
-pub mod trie;
+pub mod serialize;
 pub mod trie_again;
diff --git a/cranelift/isle/isle/src/serialize.rs b/cranelift/isle/isle/src/serialize.rs
new file mode 100644
index 0000000000..34728759bf
--- /dev/null
+++ b/cranelift/isle/isle/src/serialize.rs
@@ -0,0 +1,846 @@
+//! Put "sea of nodes" representation of a `RuleSet` into a sequential order.
+//!
+//! We're trying to satisfy two key constraints on generated code:
+//!
+//! First, we must produce the same result as if we tested the left-hand side
+//! of every rule in descending priority order and picked the first match.
+//! But that would mean a lot of duplicated work since many rules have similar
+//! patterns. We want to evaluate in an order that gets the same answer but
+//! does as little work as possible.
+//!
+//! Second, some ISLE patterns can only be implemented in Rust using a `match`
+//! expression (or various choices of syntactic sugar). Others can only
+//! be implemented as expressions, which can't be evaluated while matching
+//! patterns in Rust. So we need to alternate between pattern matching and
+//! expression evaluation.
+//!
+//! To meet both requirements, we repeatedly partition the set of rules for a
+//! term and build a tree of Rust control-flow constructs corresponding to each
+//! partition. The root of such a tree is a [Block], and [serialize] constructs
+//! it.
+use std::cmp::Reverse;
+
+use crate::lexer::Pos;
+use crate::trie_again::{Binding, BindingId, Constraint, Rule, RuleSet};
+use crate::DisjointSets;
+
+/// Decomposes the rule-set into a tree of [Block]s.
+pub fn serialize(rules: &RuleSet) -> Block {
+    // While building the tree, we need temporary storage to keep track of
+    // different subsets of the rules as we partition them into ever smaller
+    // sets. As long as we're allowed to re-order the rules, we can ensure
+    // that every partition is contiguous; but since we plan to re-order them,
+    // we actually just store indexes into the `RuleSet` to minimize data
+    // movement. The algorithm in this module never duplicates or discards
+    // rules, so the total size of all partitions is exactly the number of
+    // rules. For all the above reasons, we can pre-allocate all the space
+    // we'll need to hold those partitions up front and share it throughout the
+    // tree.
+    //
+    // As an interesting side effect, when the algorithm finishes, this vector
+    // records the order in which rule bodies will be emitted in the generated
+    // Rust. We don't care because we could get the same information from the
+    // built tree, but it may be helpful to think about the intermediate steps
+    // as recursively sorting the rules. It may not be possible to produce the
+    // same order using a comparison sort, and the asymptotic complexity is
+    // probably worse than the O(n log n) of a comparison sort, but it's still
+    // doing sorting of some kind.
+    let mut order = Vec::from_iter(0..rules.rules.len());
+    Decomposition::new(rules).sort(&mut order)
+}
+
+/// A sequence of steps to evaluate in order. Any step may return early, so
+/// steps ordered later can assume the negation of the conditions evaluated in
+/// earlier steps.
+#[derive(Default)]
+pub struct Block {
+    /// Steps to evaluate.
+    pub steps: Vec<EvalStep>,
+}
+
+/// A step to evaluate involves possibly let-binding some expressions, then
+/// executing some control flow construct.
+pub struct EvalStep {
+    /// Before evaluating this case, emit let-bindings in this order.
+    pub bind_order: Vec<BindingId>,
+    /// The control-flow construct to execute at this point.
+    pub check: ControlFlow,
+}
+
+/// What kind of control-flow structure do we need to emit here?
+pub enum ControlFlow {
+    /// Test a binding site against one or more mutually-exclusive patterns and
+    /// branch to the appropriate block if a pattern matches.
+    Match {
+        /// Which binding site are we examining at this point?
+        source: BindingId,
+        /// What patterns do we care about?
+        arms: Vec<MatchArm>,
+    },
+    /// Test whether two binding sites have values which are equal when
+    /// evaluated on the current input.
+    Equal {
+        /// One binding site.
+        a: BindingId,
+        /// The other binding site. To ensure we always generate the same code
+        /// given the same set of ISLE rules, `b` should be strictly greater
+        /// than `a`.
+        b: BindingId,
+        /// If the test succeeds, evaluate this block.
+        body: Block,
+    },
+    /// Evaluate a block once with each value of the given binding site.
+    Loop {
+        /// A binding site of type [Binding::Iterator]. Its source binding site
+        /// must be a multi-extractor or multi-constructor call.
+        result: BindingId,
+        /// What to evaluate with each binding.
+        body: Block,
+    },
+    /// Return a result from the right-hand side of a rule. If we're building a
+    /// multi-constructor then this doesn't actually return, but adds to a list
+    /// of results instead. Otherwise this return stops evaluation before any
+    /// later steps.
+    Return {
+        /// Where was the rule defined that had this right-hand side?
+        pos: Pos,
+        /// What is the result expression which should be returned if this
+        /// rule matched?
+        result: BindingId,
+    },
+}
+
+/// One concrete pattern and the block to evaluate if the pattern matches.
+pub struct MatchArm {
+    /// The pattern to match.
+    pub constraint: Constraint,
+    /// If this pattern matches, it brings these bindings into scope. If a
+    /// binding is unused in this block, then the corresponding position in the
+    /// pattern's bindings may be `None`.
+    pub bindings: Vec<Option<BindingId>>,
+    /// Steps to evaluate if the pattern matched.
+    pub body: Block,
+}
+
+/// Given a set of rules that's been partitioned into two groups, move rules
+/// from the first partition to the second if there are higher-priority rules
+/// in the second group. In the final generated code, we'll check the rules
+/// in the first ("selected") group before any in the second ("deferred")
+/// group. But we need the result to be _as if_ we checked the rules in strict
+/// descending priority order.
+///
+/// When evaluating the relationship between one rule in the selected set and
+/// one rule in the deferred set, there are two cases where we can keep a rule
+/// in the selected set:
+/// 1. The deferred rule is lower priority than the selected rule; or
+/// 2. The two rules don't overlap, so they can't match on the same inputs.
+///
+/// In either case, if the selected rule matches then we know the deferred rule
+/// would not have been the one we wanted anyway; and if it doesn't match then
+/// the fall-through semantics of the code we generate will let us go on to
+/// check the deferred rule.
+///
+/// So a rule can stay in the selected set as long as it's in one of the above
+/// relationships with every rule in the deferred set.
+///
+/// Due to the overlap checking pass which occurs before codegen, we know that
+/// if two rules have the same priority, they do not overlap. So case 1 above
+/// can be expanded to when the deferred rule is lower _or equal_ priority
+/// to the selected rule. This much overlap checking is absolutely necessary:
+/// There are terms where codegen is impossible if we use only the unmodified
+/// case 1 and don't also check case 2.
+///
+/// Aside from the equal-priority case, though, case 2 does not seem to matter
+/// in practice. On the current backends, doing a full overlap check here does
+/// not change the generated code at all. So we don't bother.
+///
+/// Since this function never moves rules from the deferred set to the selected
+/// set, the returned partition-point is always less than or equal to the
+/// initial partition-point.
+fn respect_priority(rules: &RuleSet, order: &mut [usize], partition_point: usize) -> usize {
+    let (selected, deferred) = order.split_at_mut(partition_point);
+
+    if let Some(max_deferred_prio) = deferred.iter().map(|&idx| rules.rules[idx].prio).max() {
+        partition_in_place(selected, |&idx| rules.rules[idx].prio >= max_deferred_prio)
+    } else {
+        // If the deferred set is empty, all selected rules are fine where
+        // they are.
+        partition_point
+    }
+}
+
+/// A query which can be tested against a [Rule] to see if that rule requires
+/// the given kind of control flow around the given binding sites. These
+/// choices correspond to the identically-named variants of [ControlFlow].
+///
+/// The order of these variants is significant, because it's used as a tie-
+/// breaker in the heuristic that picks which control flow to generate next.
+///
+/// - Loops should always be chosen last. If a rule needs to run once for each
+///   value from an iterator, but only if some other condition is true, we
+///   should check the other condition first.
+///
+/// - Sorting concrete [HasControlFlow::Match] constraints first has the effect
+///   of clustering such constraints together, which is not important but means
+///   codegen could theoretically merge the cluster of matches into a single
+///   Rust `match` statement.
+#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
+enum HasControlFlow {
+    /// Find rules which have a concrete pattern constraint on the given
+    /// binding site.
+    Match(BindingId),
+
+    /// Find rules which require both given binding sites to be in the same
+    /// equivalence class.
+    Equal(BindingId, BindingId),
+
+    /// Find rules which must loop over the multiple values of the given
+    /// binding site.
+    Loop(BindingId),
+}
+
+struct PartitionResults {
+    any_matched: bool,
+    valid: usize,
+}
+
+impl HasControlFlow {
+    /// Identify which rules both satisfy this query, and are safe to evaluate
+    /// before all rules that don't satisfy the query, considering rules'
+    /// relative priorities like [respect_priority]. Partition matching rules
+    /// first in `order`. Return the number of rules which are valid with
+    /// respect to priority, as well as whether any rules matched the query at
+    /// all. No ordering is guaranteed within either partition, which allows
+    /// this function to run in linear time. That's fine because later we'll
+    /// recursively sort both partitions.
+    fn partition(self, rules: &RuleSet, order: &mut [usize]) -> PartitionResults {
+        let matching = partition_in_place(order, |&idx| {
+            let rule = &rules.rules[idx];
+            match self {
+                HasControlFlow::Match(binding_id) => rule.get_constraint(binding_id).is_some(),
+                HasControlFlow::Equal(x, y) => rule.equals.in_same_set(x, y),
+                HasControlFlow::Loop(binding_id) => rule.iterators.contains(&binding_id),
+            }
+        });
+        PartitionResults {
+            any_matched: matching > 0,
+            valid: respect_priority(rules, order, matching),
+        }
+    }
+}
+
+/// As we proceed through sorting a term's rules, the term's binding sites move
+/// through this sequence of states. This state machine helps us avoid doing
+/// the same thing with a binding site more than once in any subtree.
+#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
+enum BindingState {
+    /// Initially, all binding sites are unavailable for evaluation except for
+    /// top-level arguments, constants, and similar.
+    #[default]
+    Unavailable,
+    /// As more binding sites become available, it becomes possible to evaluate
+    /// bindings which depend on those sites.
+/// As we proceed through sorting a term's rules, the term's binding sites move
+/// through this sequence of states. This state machine helps us avoid doing
+/// the same thing with a binding site more than once in any subtree.
+#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
+enum BindingState {
+    /// Initially, all binding sites are unavailable for evaluation except for
+    /// top-level arguments, constants, and similar.
+    #[default]
+    Unavailable,
+    /// As more binding sites become available, it becomes possible to evaluate
+    /// bindings which depend on those sites.
+    Available,
+    /// Once we've decided a binding is needed in order to make progress in
+    /// matching, we emit a let-binding for it. We shouldn't evaluate it a
+    /// second time, if possible.
+    Emitted,
+    /// We can only match a constraint against a binding site if we can emit it
+    /// first. Afterward, we should not try to match a constraint against that
+    /// site again in the same subtree.
+    Matched,
+}
+
+/// A sort key used to order control-flow candidates in `best_control_flow`.
+#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
+struct Score {
+    // We prefer to match as many rules at once as possible.
+    count: usize,
+    // Break ties by preferring bindings we've already emitted.
+    state: BindingState,
+}
+
+impl Score {
+    /// Recompute this score. Returns whether this is a valid candidate; if
+    /// not, the score may not have been updated and the candidate should
+    /// be removed from further consideration. The `partition` callback is
+    /// evaluated lazily.
+    fn update(
+        &mut self,
+        state: BindingState,
+        partition: impl FnOnce() -> PartitionResults,
+    ) -> bool {
+        // Candidates which have already been matched in this partition must
+        // not be matched again. There's never anything to be gained from
+        // matching a binding site when you're in an evaluation path where you
+        // already know exactly what pattern that binding site matches. And
+        // without this check, we could go into an infinite loop: all rules in
+        // the current partition match the same pattern for this binding site,
+        // so matching on it doesn't reduce the number of rules to check and it
+        // doesn't make more binding sites available.
+        //
+        // Note that equality constraints never make a binding site `Matched`
+        // and are de-duplicated using more complicated equivalence-class
+        // checks instead.
+        if state == BindingState::Matched {
+            return false;
+        }
+        self.state = state;
+
+        // The score is not based solely on how many rules have this
+        // constraint, but on how many such rules can go into the same block
+        // without violating rule priority. This number can grow as higher-
+        // priority rules are removed from the partition, so we can't drop
+        // candidates just because this is zero. If some rule has this
+        // constraint, it will become viable in some later partition.
+        let partition = partition();
+        self.count = partition.valid;
+
+        // Only consider constraints that are present in some rule in the
+        // current partition. Note that as we partition the rule set into
+        // smaller groups, the number of rules which have a particular kind of
+        // constraint can never grow, so a candidate removed here doesn't need
+        // to be examined again in this partition.
+        partition.any_matched
+    }
+}
+
+/// A rule filter ([HasControlFlow]), plus temporary storage for the sort
+/// key used in `best_control_flow` to order these candidates. Keeping the
+/// temporary storage here lets us avoid repeated heap allocations.
+#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
+struct Candidate {
+    score: Score,
+    // Last resort tie-breaker: defer to HasControlFlow order, but prefer
+    // control-flow that sorts earlier.
+    kind: Reverse<HasControlFlow>,
+}
+
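The derived `Ord` on `Score` is lexicographic, `count` first and `state` second, and `Candidate` likewise compares `score` before `kind`; that ordering is exactly what `best_control_flow` maximizes. A sketch of the intended comparisons (not part of the patch):

```rust
#[test]
fn score_order() {
    // Matching more rules always wins, regardless of binding state...
    let fewer = Score { count: 3, state: BindingState::Matched };
    let more = Score { count: 4, state: BindingState::Unavailable };
    assert!(fewer < more);
    // ...and at equal counts, the more-advanced binding state breaks the tie,
    // preferring binding sites whose let-bindings are already emitted.
    let cold = Score { count: 3, state: BindingState::Available };
    let warm = Score { count: 3, state: BindingState::Emitted };
    assert!(cold < warm);
}
```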
+impl Candidate {
+    /// Construct a candidate where the score is not set. The score will need
+    /// to be reset by [Score::update] before use.
+    fn new(kind: HasControlFlow) -> Self {
+        Candidate {
+            score: Score::default(),
+            kind: Reverse(kind),
+        }
+    }
+}
+
+/// A single binding site to check for participation in equality constraints,
+/// plus temporary storage for the score used in `best_control_flow` to order
+/// these candidates. Keeping the temporary storage here lets us avoid repeated
+/// heap allocations.
+#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
+struct EqualCandidate {
+    score: Score,
+    // Last resort tie-breaker: prefer earlier binding sites.
+    source: Reverse<BindingId>,
+}
+
+impl EqualCandidate {
+    /// Construct a candidate where the score is not set. The score will need
+    /// to be reset by [Score::update] before use.
+    fn new(source: BindingId) -> Self {
+        EqualCandidate {
+            score: Score::default(),
+            source: Reverse(source),
+        }
+    }
+}
+
+/// State for a [Decomposition] that needs to be cloned when entering a nested
+/// scope, so that changes in that scope don't affect this one.
+#[derive(Clone, Default)]
+struct ScopedState {
+    /// The state of all binding sites at this point in the tree, indexed by
+    /// [BindingId]. Bindings which become available in nested scopes don't
+    /// magically become available in outer scopes too.
+    ready: Vec<BindingState>,
+    /// The current set of candidates for control flow to add at this point in
+    /// the tree. We can't rely on any match results that might be computed in
+    /// a nested scope, so if we still care about a candidate in the fallback
+    /// case then we need to emit the correct control flow for it again.
+    candidates: Vec<Candidate>,
+    /// The current set of binding sites which participate in equality
+    /// constraints at this point in the tree. We can't rely on any match
+    /// results that might be computed in a nested scope, so if we still care
+    /// about a candidate in the fallback case then we need to emit the correct
+    /// control flow for it again.
+    equal_candidates: Vec<EqualCandidate>,
+    /// Equivalence classes that we've established on the current path from
+    /// the root.
+    equal: DisjointSets<BindingId>,
+}
+
+/// Builder for one [Block] in the tree.
+struct Decomposition<'a> {
+    /// The complete RuleSet, shared across the whole tree.
+    rules: &'a RuleSet,
+    /// Decomposition state that is scoped to the current subtree.
+    scope: ScopedState,
+    /// Accumulator for bindings that should be emitted before the next
+    /// control-flow construct.
+    bind_order: Vec<BindingId>,
+    /// Accumulator for the final Block that we'll return as this subtree.
+    block: Block,
+}
+
+impl<'a> Decomposition<'a> {
+    /// Create a builder for the root [Block].
+    fn new(rules: &'a RuleSet) -> Decomposition<'a> {
+        let mut scope = ScopedState::default();
+        scope.ready.resize(rules.bindings.len(), Default::default());
+        let mut result = Decomposition {
+            rules,
+            scope,
+            bind_order: Default::default(),
+            block: Default::default(),
+        };
+        result.add_bindings();
+        result
+    }
+
+    /// Create a builder for a nested [Block].
+    fn new_block(&mut self) -> Decomposition {
+        Decomposition {
+            rules: self.rules,
+            scope: self.scope.clone(),
+            bind_order: Default::default(),
+            block: Default::default(),
+        }
+    }
+
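Cloning `ScopedState` at `new_block` is what keeps equivalence classes path-scoped: merges performed in a nested block vanish when control falls through to a sibling. A sketch of that behavior (not in the patch), assuming only operations this file already relies on (`DisjointSets::merge`, `DisjointSets::in_same_set`, and `TryFrom<usize>` for `BindingId`):

```rust
#[test]
fn equivalence_is_path_scoped() {
    let a: BindingId = 0usize.try_into().unwrap();
    let b: BindingId = 1usize.try_into().unwrap();
    let outer = ScopedState::default();

    // Entering a nested scope clones the state, as `new_block` does...
    let mut inner = outer.clone();
    inner.equal.merge(a, b);
    assert!(inner.equal.in_same_set(a, b));

    // ...so the merge is invisible once we fall back to the outer scope.
    assert!(!outer.equal.in_same_set(a, b));
}
```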
+    /// Ensure that every binding site's state reflects its dependencies'
+    /// states. This takes time linear in the number of bindings. Because
+    /// `trie_again` only hash-conses a binding after all its dependencies have
+    /// already been hash-consed, a single in-order pass visits a binding's
+    /// dependencies before visiting the binding itself.
+    fn add_bindings(&mut self) {
+        for (idx, binding) in self.rules.bindings.iter().enumerate() {
+            // We only add these bindings when matching a corresponding
+            // type of control flow, in `make_control_flow`.
+            if matches!(
+                binding,
+                Binding::Iterator { .. } | Binding::MatchVariant { .. } | Binding::MatchSome { .. }
+            ) {
+                continue;
+            }
+
+            // TODO: proactively put some bindings in `Emitted` state
+            // That makes them visible to the best-binding heuristic, which
+            // prefers to match on already-emitted bindings first. This helps
+            // to sort cheap computations before expensive ones.
+
+            let idx: BindingId = idx.try_into().unwrap();
+            if self.scope.ready[idx.index()] < BindingState::Available {
+                if binding
+                    .sources()
+                    .iter()
+                    .all(|&source| self.scope.ready[source.index()] >= BindingState::Available)
+                {
+                    self.set_ready(idx, BindingState::Available);
+                }
+            }
+        }
+    }
+
+    /// Determines the final evaluation order for the given subset of rules, and
+    /// builds a [Block] representing that order.
+    fn sort(mut self, mut order: &mut [usize]) -> Block {
+        while let Some(best) = self.best_control_flow(order) {
+            // Peel off all rules that have this particular control flow, and
+            // save the rest for the next iteration of the loop.
+            let partition_point = best.partition(self.rules, order).valid;
+            debug_assert!(partition_point > 0);
+            let (this, rest) = order.split_at_mut(partition_point);
+            order = rest;
+
+            // Recursively build the control-flow tree for these rules.
+            let check = self.make_control_flow(best, this);
+            // Note that `make_control_flow` may have added more let-bindings.
+            let bind_order = std::mem::take(&mut self.bind_order);
+            self.block.steps.push(EvalStep { bind_order, check });
+        }
+
+        // At this point, `best_control_flow` says the remaining rules don't
+        // have any control flow left to emit. That could be because there are
+        // no unhandled rules left, or because every candidate for control flow
+        // for the remaining rules has already been matched by some ancestor in
+        // the tree.
+        debug_assert_eq!(self.scope.candidates.len(), 0);
+        // TODO: assert something about self.equal_candidates?
+
+        // If we're building a multi-constructor, then there could be multiple
+        // rules with the same left-hand side. We'll evaluate them all, but
+        // to keep the output consistent, first sort by descending priority
+        // and break ties with the order the rules were declared. In non-multi
+        // constructors, there should be at most one rule remaining here.
+        order.sort_unstable_by_key(|&idx| (Reverse(self.rules.rules[idx].prio), idx));
+        for &idx in order.iter() {
+            let &Rule {
+                pos,
+                result,
+                ref impure,
+                ..
+            } = &self.rules.rules[idx];
+
+            // Ensure that any impure constructors are called, even if their
+            // results aren't used.
+            for &impure in impure.iter() {
+                self.use_expr(impure);
+            }
+            self.use_expr(result);
+
+            let check = ControlFlow::Return { pos, result };
+            let bind_order = std::mem::take(&mut self.bind_order);
+            self.block.steps.push(EvalStep { bind_order, check });
+        }
+
+        self.block
+    }
+
+    /// Let-bind this binding site and all its dependencies, skipping any
+    /// which are already let-bound. Also skip let-bindings for certain trivial
+    /// expressions which are safe and cheap to evaluate multiple times,
+    /// because that reduces clutter in the generated code.
+ fn use_expr(&mut self, name: BindingId) { + if self.scope.ready[name.index()] < BindingState::Emitted { + self.set_ready(name, BindingState::Emitted); + let binding = &self.rules.bindings[name.index()]; + for &source in binding.sources() { + self.use_expr(source); + } + + let should_let_bind = match binding { + Binding::ConstInt { .. } => false, + Binding::ConstPrim { .. } => false, + Binding::Argument { .. } => false, + Binding::MatchTuple { .. } => false, + + // Only let-bind variant constructors if they have some fields. + // Building a variant with no fields is cheap, but don't + // duplicate more complex expressions. + Binding::MakeVariant { fields, .. } => !fields.is_empty(), + + // By default, do let-bind: that's always safe. + _ => true, + }; + if should_let_bind { + self.bind_order.push(name); + } + } + } + + /// Build one control-flow construct and its subtree for the specified rules. + /// The rules in `order` must all have the kind of control-flow named in `best`. + fn make_control_flow(&mut self, best: HasControlFlow, order: &mut [usize]) -> ControlFlow { + match best { + HasControlFlow::Match(source) => { + self.use_expr(source); + self.add_bindings(); + let mut arms = Vec::new(); + + let get_constraint = + |idx: usize| self.rules.rules[idx].get_constraint(source).unwrap(); + + // Ensure that identical constraints are grouped together, then + // loop over each group. + order.sort_unstable_by_key(|&idx| get_constraint(idx)); + for g in group_by_mut(order, |&a, &b| get_constraint(a) == get_constraint(b)) { + // Applying a constraint moves the discriminant from + // Emitted to Matched, but only within the constraint's + // match arm; later fallthrough cases may need to match + // this discriminant again. Since `source` is in the + // `Emitted` state in the parent due to the above call + // to `use_expr`, calling `add_bindings` again after this + // wouldn't change anything. + let mut child = self.new_block(); + child.set_ready(source, BindingState::Matched); + + // Get the constraint for this group, and all of the + // binding sites that it introduces. + let constraint = get_constraint(g[0]); + let bindings = Vec::from_iter( + constraint + .bindings_for(source) + .into_iter() + .map(|b| child.rules.find_binding(&b)), + ); + + let mut changed = false; + for &binding in bindings.iter() { + if let Some(binding) = binding { + // Matching a pattern makes its bindings + // available, and also emits code to bind + // them. + child.set_ready(binding, BindingState::Emitted); + changed = true; + } + } + + // As an optimization, only propagate availability + // if we changed any binding's readiness. + if changed { + child.add_bindings(); + } + + // Recursively construct a Block for this group of rules. + let body = child.sort(g); + arms.push(MatchArm { + constraint, + bindings, + body, + }); + } + + ControlFlow::Match { source, arms } + } + + HasControlFlow::Equal(a, b) => { + // Both sides of the equality test must be evaluated before + // the condition can be tested. Go ahead and let-bind them + // so they're available without re-evaluation in fall-through + // cases. + self.use_expr(a); + self.use_expr(b); + self.add_bindings(); + + let mut child = self.new_block(); + // Never mark binding sites used in equality constraints as + // "matched", because either might need to be used again in + // a later equality check. Instead record that they're in the + // same equivalence class on this path. 
+                child.scope.equal.merge(a, b);
+                let body = child.sort(order);
+                ControlFlow::Equal { a, b, body }
+            }
+
+            HasControlFlow::Loop(source) => {
+                // Consuming a multi-term involves two binding sites:
+                // calling the multi-term to get an iterator (the `source`),
+                // and looping over the iterator to get a binding for each
+                // `result`.
+                let result = self
+                    .rules
+                    .find_binding(&Binding::Iterator { source })
+                    .unwrap();
+
+                // We must not let-bind the iterator until we're ready to
+                // consume it, because it can only be consumed once. This also
+                // means that the let-binding for `source` is not actually
+                // reusable after this point, so even though we need to emit
+                // its let-binding here, we pretend we haven't.
+                let base_state = self.scope.ready[source.index()];
+                debug_assert_eq!(base_state, BindingState::Available);
+                self.use_expr(source);
+                self.scope.ready[source.index()] = base_state;
+                self.add_bindings();
+
+                let mut child = self.new_block();
+                child.set_ready(source, BindingState::Matched);
+                child.set_ready(result, BindingState::Emitted);
+                child.add_bindings();
+                let body = child.sort(order);
+                ControlFlow::Loop { result, body }
+            }
+        }
+    }
+
+    /// Advance the given binding to a new state. The new state usually should
+    /// be greater than the existing state; but at the least it must never
+    /// go backward.
+    fn set_ready(&mut self, source: BindingId, state: BindingState) {
+        let old = &mut self.scope.ready[source.index()];
+        debug_assert!(*old <= state);
+
+        // Add candidates for this binding, but only when it first becomes
+        // available.
+        if let BindingState::Unavailable = old {
+            // A binding site can't have all of these kinds of constraint,
+            // and many have none. But `best_control_flow` has to check all
+            // candidates anyway, so let it figure out which (if any) of these
+            // are applicable. It will only check false candidates once on any
+            // partition, removing them from this list immediately.
+            self.scope.candidates.extend([
+                Candidate::new(HasControlFlow::Match(source)),
+                Candidate::new(HasControlFlow::Loop(source)),
+            ]);
+            self.scope
+                .equal_candidates
+                .push(EqualCandidate::new(source));
+        }
+
+        *old = state;
+    }
+
+    /// For the specified set of rules, heuristically choose which control-flow
+    /// will minimize redundant work when the generated code is running.
+    fn best_control_flow(&mut self, order: &mut [usize]) -> Option<HasControlFlow> {
+        // If there are no rules left, none of the candidates will match
+        // anything in the `retain_mut` call below, so short-circuit it.
+        if order.is_empty() {
+            // This is only read in a debug-assert but it's fast so just do it
+            self.scope.candidates.clear();
+            return None;
+        }
+
+        // Remove false candidates, and recompute the candidate score for the
+        // current set of rules in `order`.
+        self.scope.candidates.retain_mut(|candidate| {
+            let kind = candidate.kind.0;
+            let source = match kind {
+                HasControlFlow::Match(source) => source,
+                HasControlFlow::Loop(source) => source,
+                HasControlFlow::Equal(..) => unreachable!(),
+            };
+            let state = self.scope.ready[source.index()];
+            candidate
+                .score
+                .update(state, || kind.partition(self.rules, order))
+        });
+
+        // Find the best normal candidate.
+        let mut best = self.scope.candidates.iter().max().cloned();
+
+        // Equality constraints are more complicated. We need to identify
+        // some pair of binding sites which are constrained to be equal in at
+        // least one rule in the current partition. We do this in two steps.
+        // First, find each single binding site which participates in any
+        // equality constraint in some rule. We compute the best-case `Score`
+        // we could get, if there were another binding site where all the rules
+        // constraining this binding site require it to be equal to that one.
+        self.scope.equal_candidates.retain_mut(|candidate| {
+            let source = candidate.source.0;
+            let state = self.scope.ready[source.index()];
+            candidate.score.update(state, || {
+                let matching = partition_in_place(order, |&idx| {
+                    self.rules.rules[idx].equals.find(source).is_some()
+                });
+                PartitionResults {
+                    any_matched: matching > 0,
+                    valid: respect_priority(self.rules, order, matching),
+                }
+            })
+        });
+
+        // Now that we know which single binding sites participate in any
+        // equality constraints, we need to find the best pair of binding
+        // sites. Rules that require binding sites `x` and `y` to be equal are
+        // a subset of the intersection of rules constraining `x` and those
+        // constraining `y`. So the upper bound on the number of matching rules
+        // is whichever candidate is smaller.
+        //
+        // Do an O(n log n) sort to put the best single binding sites first.
+        // Then the O(n^2) all-pairs loop can do branch-and-bound style
+        // pruning, breaking out of a loop as soon as the remaining candidates
+        // must all produce worse results than our current best candidate.
+        //
+        // Note that `x` and `y` are reversed, to sort in descending order.
+        self.scope
+            .equal_candidates
+            .sort_unstable_by(|x, y| y.cmp(x));
+
+        let mut equals = self.scope.equal_candidates.iter();
+        while let Some(x) = equals.next() {
+            if Some(&x.score) < best.as_ref().map(|best| &best.score) {
+                break;
+            }
+            let x_id = x.source.0;
+            for y in equals.as_slice().iter() {
+                if Some(&y.score) < best.as_ref().map(|best| &best.score) {
+                    break;
+                }
+                let y_id = y.source.0;
+                // If x and y are already in the same path-scoped equivalence
+                // class, then skip this pair because we already emitted this
+                // check or a combination of equivalent checks on this path.
+                if !self.scope.equal.in_same_set(x_id, y_id) {
+                    // Sort arguments for consistency.
+                    let kind = if x_id < y_id {
+                        HasControlFlow::Equal(x_id, y_id)
+                    } else {
+                        HasControlFlow::Equal(y_id, x_id)
+                    };
+                    let pair = Candidate {
+                        kind: Reverse(kind),
+                        score: Score {
+                            count: kind.partition(self.rules, order).valid,
+                            // Only treat this as already-emitted if
+                            // both bindings are.
+                            state: x.score.state.min(y.score.state),
+                        },
+                    };
+                    if best.as_ref() < Some(&pair) {
+                        best = Some(pair);
+                    }
+                }
+            }
+        }
+
+        best.filter(|candidate| candidate.score.count > 0)
+            .map(|candidate| candidate.kind.0)
+    }
+}
+
+/// Places all elements which satisfy the predicate at the beginning of the
+/// slice, and all elements which don't at the end. Returns the number of
+/// elements in the first partition.
+///
+/// This function runs in time linear in the number of elements, and calls
+/// the predicate exactly once per element. If either partition is empty, no
+/// writes will occur in the slice, so it's okay to call this frequently with
+/// predicates that we expect won't match anything.
+fn partition_in_place<T>(xs: &mut [T], mut pred: impl FnMut(&T) -> bool) -> usize {
+    let mut iter = xs.iter_mut();
+    let mut partition_point = 0;
+    while let Some(a) = iter.next() {
+        if pred(a) {
+            partition_point += 1;
+        } else {
+            // `a` belongs in the partition at the end. If there's some later
+            // element `b` that belongs in the partition at the beginning,
+            // swap them. Working backwards from the end establishes the loop
+            // invariant that both ends of the array are partitioned correctly,
+            // and only the middle needs to be checked.
+            while let Some(b) = iter.next_back() {
+                if pred(b) {
+                    std::mem::swap(a, b);
+                    partition_point += 1;
+                    break;
+                }
+            }
+        }
+    }
+    partition_point
+}
+
+fn group_by_mut<T>(
+    mut xs: &mut [T],
+    mut pred: impl FnMut(&T, &T) -> bool,
+) -> impl Iterator<Item = &mut [T]> {
+    std::iter::from_fn(move || {
+        if xs.is_empty() {
+            None
+        } else {
+            let mid = xs
+                .windows(2)
+                .position(|w| !pred(&w[0], &w[1]))
+                .map_or(xs.len(), |x| x + 1);
+            let slice = std::mem::take(&mut xs);
+            let (group, rest) = slice.split_at_mut(mid);
+            xs = rest;
+            Some(group)
+        }
+    })
+}
+
+#[test]
+fn test_group_mut() {
+    let slice = &mut [1, 1, 1, 3, 3, 2, 2, 2];
+    let mut iter = group_by_mut(slice, |a, b| a == b);
+    assert_eq!(iter.next(), Some(&mut [1, 1, 1][..]));
+    assert_eq!(iter.next(), Some(&mut [3, 3][..]));
+    assert_eq!(iter.next(), Some(&mut [2, 2, 2][..]));
+    assert_eq!(iter.next(), None);
+}
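`group_by_mut` gets a unit test just above, and `partition_in_place` could get one in the same style. This sketch (not part of the patch) pins down the documented contract: predicate-satisfying elements come first, the count of them is returned, and no order is guaranteed within either partition.

```rust
#[test]
fn test_partition_in_place() {
    let xs = &mut [1, 4, 2, 5, 3];
    let n = partition_in_place(xs, |&x| x >= 3);
    assert_eq!(n, 3);
    // Elements satisfying the predicate all land before the rest, but the
    // relative order inside each partition is unspecified.
    assert!(xs[..n].iter().all(|&x| x >= 3));
    assert!(xs[n..].iter().all(|&x| x < 3));
}
```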
diff --git a/cranelift/isle/isle/src/trie.rs b/cranelift/isle/isle/src/trie.rs
deleted file mode 100644
index 9027f96fe3..0000000000
--- a/cranelift/isle/isle/src/trie.rs
+++ /dev/null
@@ -1,321 +0,0 @@
-//! Trie construction.
-
-use crate::ir::{lower_rule, ExprSequence, PatternInst};
-use crate::log;
-use crate::sema::{TermEnv, TermId};
-use std::collections::BTreeMap;
-
-/// Construct the tries for each term.
-pub fn build_tries(termenv: &TermEnv) -> BTreeMap<TermId, TrieNode> {
-    let mut builder = TermFunctionsBuilder::default();
-    builder.build(termenv);
-    log!("builder: {:?}", builder);
-    builder.finalize()
-}
-
-/// One "input symbol" for the decision tree that handles matching on
-/// a term. Each symbol represents one step: we either run a match op,
-/// or we finish the match.
-///
-/// Note that in the original Peepmatic scheme, the input-symbol to
-/// the FSM was specified slightly differently. The automaton
-/// responded to alphabet symbols that corresponded only to match
-/// results, and the "extra state" was used at each automaton node to
-/// represent the op to run next. This extra state differentiated
-/// nodes that would otherwise be merged together by
-/// deduplication. That scheme works well enough, but the "extra
-/// state" is slightly confusing and diverges slightly from a pure
-/// automaton.
-///
-/// Instead, here, we imagine that the user of the automaton/trie can
-/// query the possible transition edges out of the current state. Each
-/// of these edges corresponds to one possible match op to run. After
-/// running a match op, we reach a new state corresponding to
-/// successful matches up to that point.
-///
-/// However, it's a bit more subtle than this. Consider the
-/// prioritization problem. We want to give the DSL user the ability
-/// to change the order in which rules apply, for example to have a
-/// tier of "fallback rules" that apply only if more custom rules do
-/// not match.
-///
-/// A somewhat simplistic answer to this problem is "more specific
-/// rule wins". However, this implies the existence of a total
-/// ordering of linearized match sequences that may not fully capture
-/// the intuitive meaning of "more specific". Consider three left-hand
-/// sides:
-///
-/// - (A _ _)
-/// - (A (B _) _)
-/// - (A _ (B _))
-///
-/// Intuitively, the first is the least specific. Given the input `(A
-/// (B 1) (B 2))`, we can say for sure that the first should not be
-/// chosen, because either the second or third would match "more" of
-/// the input tree. But which of the second and third should be
-/// chosen? A "lexicographic ordering" rule would say that we sort
-/// left-hand sides such that the `(B _)` sub-pattern comes before the
-/// wildcard `_`, so the second rule wins. But that is arbitrarily
-/// privileging one over the other based on the order of the
-/// arguments.
-///
-/// Instead, we can accept explicit priorities from the user to allow
-/// either choice. So we need a data structure that can associate
-/// matching inputs *with priorities* to outputs.
-///
-/// Next, we build a decision tree rather than an FSM. Why? Because
-/// we're compiling to a structured language, Rust, and states become
-/// *program points* rather than *data*, we cannot easily support a
-/// DAG structure. In other words, we are not producing a FSM that we
-/// can interpret at runtime; rather we are compiling code in which
-/// each state corresponds to a sequence of statements and
-/// control-flow that branches to a next state, we naturally need
-/// nesting; we cannot codegen arbitrary state transitions in an
-/// efficient manner. We could support a limited form of DAG that
-/// reifies "diamonds" (two alternate paths that reconverge), but
-/// supporting this in a way that lets the output refer to values from
-/// either side is very complex (we need to invent phi-nodes), and the
-/// cases where we want to do this rather than invoke a sub-term (that
-/// is compiled to a separate function) are rare. Finally, note that
-/// one reason to deduplicate nodes and turn a tree back into a DAG --
-/// "output-suffix sharing" as some other instruction-rewriter
-/// engines, such as Peepmatic, do -- is not done, because all
-/// "output" occurs at leaf nodes; this is necessary because we do not
-/// want to start invoking external constructors until we are sure of
-/// the match. Some of the code-sharing advantages of the "suffix
-/// sharing" scheme can be obtained in a more flexible and
-/// user-controllable way (with less understanding of internal
-/// compiler logic needed) by factoring logic into different internal
-/// terms, which become different compiled functions. This is likely
-/// to happen anyway as part of good software engineering practice.
-///
-/// We prepare for codegen by building a "prioritized trie", where the
-/// trie associates input strings with priorities to output values.
-/// Each input string is a sequence of match operators followed by an
-/// "end of match" token, and each output is a sequence of ops that
-/// build the output expression. Each input-output mapping is
-/// associated with a priority. The goal of the trie is to generate a
-/// decision-tree procedure that lets us execute match ops in a
-/// deterministic way, eventually landing at a state that corresponds
-/// to the highest-priority matching rule and can produce the output.
-///
-/// To build this trie, we construct nodes with edges to child nodes;
-/// each edge consists of (i) one input token (a `PatternInst` or
-/// EOM), and (ii) the priority of rules along this edge. We do not
-/// merge rules of different priorities, because the logic to do so is
-/// complex and error-prone, necessitating "splits" when we merge
-/// together a set of rules over a priority range but later introduce
-/// a new possible match op in the "middle" of the range. (E.g., match
-/// op A at prio 10, B at prio 5, A at prio 0.) In fact, a previous
-/// version of the ISLE compiler worked this way, but in practice the
-/// complexity was unneeded.
-///
-/// To add a rule to this trie, we perform the usual trie-insertion
-/// logic, creating edges and subnodes where necessary. A new edge is
-/// necessary whenever an edge does not exist for the (priority,
-/// symbol) tuple.
-///
-/// Note that this means that multiple edges with a single match-op
-/// may exist, with different priorities.
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub enum TrieSymbol {
-    /// Run a match operation to continue matching a LHS.
-    Match {
-        /// The match operation to run.
-        op: PatternInst,
-    },
-    /// We successfully matched a LHS.
-    EndOfMatch,
-}
-
-impl TrieSymbol {
-    fn is_eom(&self) -> bool {
-        match self {
-            TrieSymbol::EndOfMatch => true,
-            _ => false,
-        }
-    }
-}
-
-/// A priority.
-pub type Prio = i64;
-
-/// An edge in our term trie.
-#[derive(Clone, Debug)]
-pub struct TrieEdge {
-    /// The priority for this edge's sub-trie.
-    pub prio: Prio,
-    /// The match operation to perform for this edge.
-    pub symbol: TrieSymbol,
-    /// This edge's sub-trie.
-    pub node: TrieNode,
-}
-
-/// A node in the term trie.
-#[derive(Clone, Debug)]
-pub enum TrieNode {
-    /// One or more patterns could match.
-    ///
-    /// Maybe one pattern already has matched, but there are more (higher
-    /// priority and/or same priority but more specific) patterns that could
-    /// still match.
-    Decision {
-        /// The child sub-tries that we can match from this point on.
-        edges: Vec<TrieEdge>,
-    },
-
-    /// The successful match of an LHS pattern, and here is its RHS expression.
-    Leaf {
-        /// The priority of this rule.
-        prio: Prio,
-        /// The RHS expression to evaluate upon a successful LHS pattern match.
-        output: ExprSequence,
-    },
-
-    /// No LHS pattern matches.
-    Empty,
-}
-
-impl TrieNode {
-    fn is_empty(&self) -> bool {
-        matches!(self, &TrieNode::Empty)
-    }
-
-    fn insert(
-        &mut self,
-        prio: Prio,
-        mut input: impl Iterator<Item = TrieSymbol>,
-        output: ExprSequence,
-    ) -> bool {
-        // Take one input symbol. There must be *at least* one, EOM if
-        // nothing else.
-        let op = input
-            .next()
-            .expect("Cannot insert into trie with empty input sequence");
-        let is_last = op.is_eom();
-
-        // If we are empty, turn into a decision node.
-        if self.is_empty() {
-            *self = TrieNode::Decision { edges: vec![] };
-        }
-
-        // We must be a decision node.
-        let edges = match self {
-            &mut TrieNode::Decision { ref mut edges } => edges,
-            _ => panic!("insert on leaf node!"),
-        };
-
-        // Now find or insert the appropriate edge.
-        let edge = edges
-            .iter()
-            .position(|edge| edge.symbol == op && edge.prio == prio)
-            .unwrap_or_else(|| {
-                edges.push(TrieEdge {
-                    prio,
-                    symbol: op,
-                    node: TrieNode::Empty,
-                });
-                edges.len() - 1
-            });
-
-        let edge = &mut edges[edge];
-
-        if is_last {
-            if !edge.node.is_empty() {
-                // If a leaf node already exists at an overlapping
-                // prio for this op, there are two competing rules, so
-                // we can't insert this one.
-                return false;
-            }
-            edge.node = TrieNode::Leaf { prio, output };
-            true
-        } else {
-            edge.node.insert(prio, input, output)
-        }
-    }
-
-    /// Sort edges by priority.
-    pub fn sort(&mut self) {
-        match self {
-            TrieNode::Decision { edges } => {
-                // Sort by priority, highest integer value first; then
-                // by trie symbol.
-                edges.sort_by_cached_key(|edge| (-edge.prio, edge.symbol.clone()));
-                for child in edges {
-                    child.node.sort();
-                }
-            }
-            _ => {}
-        }
-    }
-
-    /// Get a pretty-printed version of this trie, for debugging.
-    pub fn pretty(&self) -> String {
-        let mut s = String::new();
-        pretty_rec(&mut s, self, "");
-        return s;
-
-        fn pretty_rec(s: &mut String, node: &TrieNode, indent: &str) {
-            match node {
-                TrieNode::Decision { edges } => {
-                    s.push_str(indent);
-                    s.push_str("TrieNode::Decision:\n");
-
-                    let new_indent = indent.to_owned() + "    ";
-                    for edge in edges {
-                        s.push_str(indent);
-                        s.push_str(&format!(
-                            "  edge: prio = {:?}, symbol: {:?}\n",
-                            edge.prio, edge.symbol
-                        ));
-                        pretty_rec(s, &edge.node, &new_indent);
-                    }
-                }
-                TrieNode::Empty | TrieNode::Leaf { .. } => {
-                    s.push_str(indent);
-                    s.push_str(&format!("{:?}\n", node));
-                }
-            }
-        }
-    }
-}
-
-#[derive(Debug, Default)]
-struct TermFunctionsBuilder {
-    builders_by_term: BTreeMap<TermId, TrieNode>,
-}
-
-impl TermFunctionsBuilder {
-    fn build(&mut self, termenv: &TermEnv) {
-        log!("termenv: {:?}", termenv);
-        for rule in termenv.rules.iter() {
-            let (pattern, expr) = lower_rule(termenv, rule.id);
-
-            log!(
-                "build:\n- rule {:?}\n- pattern {:?}\n- expr {:?}",
-                rule,
-                pattern,
-                expr
-            );
-
-            let symbols = pattern
-                .insts
-                .into_iter()
-                .map(|op| TrieSymbol::Match { op })
-                .chain(std::iter::once(TrieSymbol::EndOfMatch));
-
-            self.builders_by_term
-                .entry(rule.root_term)
-                .or_insert(TrieNode::Empty)
-                .insert(rule.prio, symbols, expr);
-        }
-
-        for builder in self.builders_by_term.values_mut() {
-            builder.sort();
-        }
-    }
-
-    fn finalize(self) -> BTreeMap<TermId, TrieNode> {
-        self.builders_by_term
-    }
-}
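A final note on the deleted code: `TrieNode::sort` ordered edges by the negated priority, `-edge.prio`, which overflows for `i64::MIN`. The replacement `serialize` pass sorts descending with a `Reverse` key instead, which is total over the whole `i64` range. A standalone illustration of the difference:

```rust
use std::cmp::Reverse;

fn main() {
    let mut prios: Vec<i64> = vec![0, i64::MIN, 10, 5, 10];
    // Descending sort without negation; `-p` would overflow at i64::MIN
    // (a panic in debug builds, silent wraparound in release builds).
    prios.sort_by_key(|&p| Reverse(p));
    assert_eq!(prios, vec![10, 10, 5, 0, i64::MIN]);
}
```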