diff --git a/cranelift/isle/TODO b/cranelift/isle/TODO index 8fd0e60ad2..d588cd0037 100644 --- a/cranelift/isle/TODO +++ b/cranelift/isle/TODO @@ -1,7 +1,3 @@ -- Convert series of if-lets from MatchVariants into a match { ... }. - - Should be pretty simple; just need to recognize adjacent edges in - priority-order that are all MatchVariants on the same input. - - Document semantics carefully, especially wrt extractors. - -- Verify that priorities work as expected. +- Build out an initial set of bindings for Cranelift LowerCtx with extractors + for instruction info. diff --git a/cranelift/isle/isle_examples/test2.isle b/cranelift/isle/isle_examples/test2.isle index 62e905a4ba..5c0977a702 100644 --- a/cranelift/isle/isle_examples/test2.isle +++ b/cranelift/isle/isle_examples/test2.isle @@ -2,7 +2,8 @@ (type A (enum (A1 (x B) (y B)))) (type B (enum - (B1 (x u32)))) + (B1 (x u32)) + (B2 (x u32)))) (decl A2B (A) B) @@ -14,6 +15,10 @@ (A2B (A.A1 (B.B1 x) _)) (B.B1 x)) +(rule 0 + (A2B (A.A1 (B.B2 x) _)) + (B.B1 x)) + (rule -1 (A2B (A.A1 _ _)) (B.B1 42)) diff --git a/cranelift/isle/src/codegen.rs b/cranelift/isle/src/codegen.rs index a4b56d5d0c..12839d819d 100644 --- a/cranelift/isle/src/codegen.rs +++ b/cranelift/isle/src/codegen.rs @@ -2,7 +2,7 @@ use crate::error::Error; use crate::ir::{lower_rule, ExprInst, ExprSequence, InstId, PatternInst, PatternSequence, Value}; -use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId}; +use crate::sema::{RuleId, TermEnv, TermId, TermKind, Type, TypeEnv, TypeId, Variant}; use std::collections::{HashMap, HashSet}; use std::fmt::Write; @@ -309,7 +309,7 @@ impl TrieNode { // ranges. Maybe the last edge we saw with the op // we're inserting can have its range expanded, // however. - if last_edge_with_op.is_some() && edges[last_edge_with_op.unwrap()].symbol == op { + if last_edge_with_op.is_some() { // Move it to the end of the run of equal-unit-range ops. edges.swap(last_edge_with_op.unwrap(), i - 1); edge = Some(i - 1); @@ -901,6 +901,38 @@ impl<'a> Codegen<'a> { Ok(()) } + fn match_variant_binders( + &self, + variant: &Variant, + arg_tys: &[TypeId], + id: InstId, + ctx: &mut BodyContext, + ) -> Vec { + arg_tys + .iter() + .zip(variant.fields.iter()) + .enumerate() + .map(|(i, (ty, field))| { + let value = Value::Pattern { + inst: id, + output: i, + }; + let valuename = self.value_name(&value); + let fieldname = &self.typeenv.syms[field.name.index()]; + match &self.typeenv.types[ty.index()] { + &Type::Primitive(..) => { + self.define_val(&value, ctx, /* is_ref = */ false); + format!("{}: {}", fieldname, valuename) + } + &Type::Enum { .. } => { + self.define_val(&value, ctx, /* is_ref = */ true); + format!("{}: ref {}", fieldname, valuename) + } + } + }) + .collect::>() + } + /// Returns a `bool` indicating whether this pattern inst is /// infallible. fn generate_pattern_inst( @@ -961,29 +993,7 @@ impl<'a> Codegen<'a> { let ty_name = self.type_name(input_ty, /* is_ref = */ Some("&")); let variant = &variants[variant.index()]; let variantname = &self.typeenv.syms[variant.name.index()]; - let args = arg_tys - .iter() - .zip(variant.fields.iter()) - .enumerate() - .map(|(i, (ty, field))| { - let value = Value::Pattern { - inst: id, - output: i, - }; - let valuename = self.value_name(&value); - let fieldname = &self.typeenv.syms[field.name.index()]; - match &self.typeenv.types[ty.index()] { - &Type::Primitive(..) => { - self.define_val(&value, ctx, /* is_ref = */ false); - format!("{}: {}", fieldname, valuename) - } - &Type::Enum { .. } => { - self.define_val(&value, ctx, /* is_ref = */ true); - format!("{}: ref {}", fieldname, valuename) - } - } - }) - .collect::>(); + let args = self.match_variant_binders(variant, &arg_tys[..], id, ctx); writeln!( code, "{}if let {}::{} {{ {} }} = {} {{", @@ -1097,34 +1107,162 @@ impl<'a> Codegen<'a> { &TrieNode::Decision { ref edges } => { let subindent = format!("{} ", indent); // if this is a decision node, generate each match op - // in turn (in priority order). - for &TrieEdge { - ref symbol, - ref node, - .. - } in edges - { - match symbol { - &TrieSymbol::EndOfMatch => { - returned = self.generate_body(code, depth + 1, node, indent, ctx)?; - } - &TrieSymbol::Match { ref op } => { - let id = InstId(depth); - let infallible = - self.generate_pattern_inst(code, id, op, indent, ctx)?; - let sub_returned = - self.generate_body(code, depth + 1, node, &subindent, ctx)?; - writeln!(code, "{}}}", indent)?; - if infallible && sub_returned { - returned = true; + // in turn (in priority order). Sort the ops within + // each priority, and gather together adjacent + // MatchVariant ops with the same input and disjoint + // variants in order to create a `match` rather than a + // chain of if-lets. + let mut edges = edges.clone(); + edges.sort_by(|e1, e2| (-e1.range.0, &e1.symbol).cmp(&(-e2.range.0, &e2.symbol))); + + let mut i = 0; + while i < edges.len() { + let mut last = i; + let mut adjacent_variants = HashSet::new(); + let mut adjacent_variant_input = None; + log::trace!("edge: {:?}", edges[i]); + while last < edges.len() { + match &edges[last].symbol { + &TrieSymbol::Match { + op: PatternInst::MatchVariant { input, variant, .. }, + } => { + if adjacent_variant_input.is_none() { + adjacent_variant_input = Some(input); + } + if adjacent_variant_input == Some(input) + && !adjacent_variants.contains(&variant) + { + adjacent_variants.insert(variant); + last += 1; + } else { + break; + } + } + _ => { break; } } } + + // edges[i..last] is a run of adjacent + // MatchVariants (possibly an empty one). Only use + // a `match` form if there are at least two + // adjacent options. + if last - i > 1 { + self.generate_body_matches(code, depth, &edges[i..last], indent, ctx)?; + i = last; + continue; + } else { + let &TrieEdge { + ref symbol, + ref node, + .. + } = &edges[i]; + i += 1; + + match symbol { + &TrieSymbol::EndOfMatch => { + returned = + self.generate_body(code, depth + 1, node, indent, ctx)?; + } + &TrieSymbol::Match { ref op } => { + let id = InstId(depth); + let infallible = + self.generate_pattern_inst(code, id, op, indent, ctx)?; + let sub_returned = + self.generate_body(code, depth + 1, node, &subindent, ctx)?; + writeln!(code, "{}}}", indent)?; + if infallible && sub_returned { + returned = true; + break; + } + } + } + } } } } Ok(returned) } + + fn generate_body_matches( + &self, + code: &mut dyn Write, + depth: usize, + edges: &[TrieEdge], + indent: &str, + ctx: &mut BodyContext, + ) -> Result<(), Error> { + let (input, input_ty) = match &edges[0].symbol { + &TrieSymbol::Match { + op: + PatternInst::MatchVariant { + input, input_ty, .. + }, + } => (input, input_ty), + _ => unreachable!(), + }; + let (input_ty_sym, variants) = match &self.typeenv.types[input_ty.index()] { + &Type::Enum { + ref name, + ref variants, + .. + } => (name, variants), + _ => unreachable!(), + }; + let input_ty_name = &self.typeenv.syms[input_ty_sym.index()]; + + // Emit the `match`. + writeln!( + code, + "{}match {} {{", + indent, + self.value_by_ref(&input, ctx) + )?; + + // Emit each case. + for &TrieEdge { + ref symbol, + ref node, + .. + } in edges + { + let id = InstId(depth); + let (variant, arg_tys) = match symbol { + &TrieSymbol::Match { + op: + PatternInst::MatchVariant { + variant, + ref arg_tys, + .. + }, + } => (variant, arg_tys), + _ => unreachable!(), + }; + + let variantinfo = &variants[variant.index()]; + let variantname = &self.typeenv.syms[variantinfo.name.index()]; + let fields = self.match_variant_binders(variantinfo, arg_tys, id, ctx); + writeln!( + code, + "{} &{}::{} {{ {} }} => {{", + indent, + input_ty_name, + variantname, + fields.join(", ") + )?; + let subindent = format!("{} ", indent); + self.generate_body(code, depth + 1, node, &subindent, ctx)?; + writeln!(code, "{} }}", indent)?; + } + + // Always add a catchall, because we don't do exhaustiveness + // checking on the MatcHVariants. + writeln!(code, "{} _ => {{}}", indent)?; + + writeln!(code, "{}}}", indent)?; + + Ok(()) + } }