use crate::cdsl::ast::{
    Apply, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition, VarIndex, VarPool,
};
use crate::cdsl::instructions::Instruction;
use crate::cdsl::type_inference::{infer_transform, TypeEnvironment};
use crate::cdsl::typevar::TypeVar;

use cranelift_entity::{entity_impl, PrimaryMap};

use std::collections::{HashMap, HashSet};
use std::iter::FromIterator;

/// An instruction transformation consists of a source and destination pattern.
///
/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A
/// pattern may optionally have a sequence of TypeConstraints, that additionally limit the set of
/// cases when it applies.
///
/// The source pattern can contain only a single instruction.
pub struct Transform {
    /// The single definition making up the source pattern.
    pub src: DefIndex,
    /// The definitions making up the destination pattern, in order.
    pub dst: Vec<DefIndex>,
    /// Pool owning this transform's private copies of the pattern variables.
    pub var_pool: VarPool,
    /// Pool owning the rewritten definitions referenced by `src` and `dst`.
    pub def_pool: DefPool,
    /// Result of running type inference over the source and destination patterns.
    pub type_env: TypeEnvironment,
}

/// Maps a variable name used in a pattern to the transform-local variable that replaces it.
type SymbolTable = HashMap<&'static str, VarIndex>;

impl Transform {
    /// Builds a transform from a source definition and a list of destination definitions.
    ///
    /// The variables of both patterns are rewritten to transform-local copies, the patterns are
    /// sanity-checked, and type inference is run over the result.
    ///
    /// # Panics
    ///
    /// Panics if a variable is used both as an input and as a definition, if the destination
    /// pattern introduces input variables that do not appear in the source pattern, if type
    /// inference fails, or if inference leaves free type variables that do not belong to a
    /// variable of the source pattern.
    fn new(src: DummyDef, dst: Vec<DummyDef>) -> Self {
        let mut var_pool = VarPool::new();
        let mut def_pool = DefPool::new();

        let mut input_vars: Vec<VarIndex> = Vec::new();
        let mut defined_vars: Vec<VarIndex> = Vec::new();

        // Maps variable names to our own Var copies.
        let mut symbol_table: SymbolTable = SymbolTable::new();

        // Rewrite variables in src and dst using our own copies.
        let src = rewrite_def_list(
            PatternPosition::Source,
            vec![src],
            &mut symbol_table,
            &mut input_vars,
            &mut defined_vars,
            &mut var_pool,
            &mut def_pool,
        )[0];

        // Remember how many inputs the source pattern alone introduced, so that inputs first
        // seen in the destination pattern can be detected below.
        let num_src_inputs = input_vars.len();

        let dst = rewrite_def_list(
            PatternPosition::Destination,
            dst,
            &mut symbol_table,
            &mut input_vars,
            &mut defined_vars,
            &mut var_pool,
            &mut def_pool,
        );

        // Sanity checks.
        for &var_index in &input_vars {
            assert!(
                var_pool.get(var_index).is_input(),
                "'{:?}' used as both input and def",
                var_pool.get(var_index)
            );
        }
        assert!(
            input_vars.len() == num_src_inputs,
            "extra input vars in dst pattern: {:?}",
            input_vars
                .iter()
                .map(|&i| var_pool.get(i))
                .skip(num_src_inputs)
                .collect::<Vec<_>>()
        );

        // Perform type inference and cleanup.
        let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap();

        // Sanity check: the set of inferred free type variables should be a subset of the type
        // variables corresponding to Vars appearing in the source pattern.
        {
            let free_typevars: HashSet<TypeVar> =
                HashSet::from_iter(type_env.free_typevars(&mut var_pool));

            // Type variables of all inputs plus all non-temporary defined variables; variables
            // without a type variable are skipped.
            let src_tvs = HashSet::from_iter(
                input_vars
                    .iter()
                    .chain(
                        defined_vars
                            .iter()
                            .filter(|&&var_index| !var_pool.get(var_index).is_temp()),
                    )
                    .filter_map(|&var_index| var_pool.get(var_index).get_typevar()),
            );

            if !free_typevars.is_subset(&src_tvs) {
                let missing_tvs = (&free_typevars - &src_tvs)
                    .iter()
                    .map(|tv| tv.name.clone())
                    .collect::<Vec<_>>()
                    .join(", ");
                panic!("Some free vars don't appear in src: {}", missing_tvs);
            }
        }

        // Replace every variable's type variable with the equivalent one reported by the type
        // environment.
        for &var_index in input_vars.iter().chain(defined_vars.iter()) {
            let var = var_pool.get_mut(var_index);
            let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar());
            var.set_typevar(canon_tv);
        }

        Self {
            src,
            dst,
            var_pool,
            def_pool,
            type_env,
        }
    }

    /// Checks that every variable defined by the source pattern is an output.
    ///
    /// # Panics
    ///
    /// Panics if a variable defined by the source definition is not reported as an output by
    /// the variable pool, i.e. it is not defined in the destination pattern.
    fn verify_legalize(&self) {
        let def = self.def_pool.get(self.src);
        for &var_index in def.defined_vars.iter() {
            let defined_var = self.var_pool.get(var_index);
            assert!(
                defined_var.is_output(),
                "{:?} not defined in the destination pattern",
                defined_var
            );
        }
    }
}

/// Given a list of symbols defined in a Def, rewrite them to local symbols. Yield the new locals.
fn rewrite_defined_vars( position: PatternPosition, dummy_def: &DummyDef, def_index: DefIndex, symbol_table: &mut SymbolTable, defined_vars: &mut Vec, var_pool: &mut VarPool, ) -> Vec { let mut new_defined_vars = Vec::new(); for var in &dummy_def.defined_vars { let own_var = match symbol_table.get(var.name) { Some(&existing_var) => existing_var, None => { // Materialize the variable. let new_var = var_pool.create(var.name); symbol_table.insert(var.name, new_var); defined_vars.push(new_var); new_var } }; var_pool.get_mut(own_var).set_def(position, def_index); new_defined_vars.push(own_var); } new_defined_vars } /// Find all uses of variables in `expr` and replace them with our own local symbols. fn rewrite_expr( position: PatternPosition, dummy_expr: DummyExpr, symbol_table: &mut SymbolTable, input_vars: &mut Vec, var_pool: &mut VarPool, ) -> Apply { let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr { (apply_target, dummy_args) } else { panic!("we only rewrite apply expressions"); }; assert_eq!( apply_target.inst().operands_in.len(), dummy_args.len(), "number of arguments in instruction is incorrect" ); let mut args = Vec::new(); for (i, arg) in dummy_args.into_iter().enumerate() { match arg { DummyExpr::Var(var) => { let own_var = match symbol_table.get(var.name) { Some(&own_var) => { let var = var_pool.get(own_var); assert!( var.is_input() || var.get_def(position).is_some(), format!("{:?} used as both input and def", var) ); own_var } None => { // First time we're using this variable. let own_var = var_pool.create(var.name); symbol_table.insert(var.name, own_var); input_vars.push(own_var); own_var } }; args.push(Expr::Var(own_var)); } DummyExpr::Literal(literal) => { assert!(!apply_target.inst().operands_in[i].is_value()); args.push(Expr::Literal(literal)); } DummyExpr::Apply(..) 
=> { panic!("Recursive apply is not allowed."); } } } Apply::new(apply_target, args) } fn rewrite_def_list( position: PatternPosition, dummy_defs: Vec, symbol_table: &mut SymbolTable, input_vars: &mut Vec, defined_vars: &mut Vec, var_pool: &mut VarPool, def_pool: &mut DefPool, ) -> Vec { let mut new_defs = Vec::new(); for dummy_def in dummy_defs { let def_index = def_pool.next_index(); let new_defined_vars = rewrite_defined_vars( position, &dummy_def, def_index, symbol_table, defined_vars, var_pool, ); let new_apply = rewrite_expr(position, dummy_def.expr, symbol_table, input_vars, var_pool); assert!( def_pool.next_index() == def_index, "shouldn't have created new defs in the meanwhile" ); assert_eq!( new_apply.inst.value_results.len(), new_defined_vars.len(), "number of Var results in instruction is incorrect" ); new_defs.push(def_pool.create(new_apply, new_defined_vars)); } new_defs } /// A group of related transformations. pub struct TransformGroup { pub name: &'static str, pub doc: &'static str, pub chain_with: Option, pub isa_name: Option<&'static str>, pub id: TransformGroupIndex, /// Maps Instruction camel_case names to custom legalization functions names. pub custom_legalizes: HashMap, pub transforms: Vec, } impl TransformGroup { pub fn rust_name(&self) -> String { match self.isa_name { Some(_) => { // This is a function in the same module as the LEGALIZE_ACTIONS table referring to // it. 
self.name.to_string() } None => format!("crate::legalizer::{}", self.name), } } } #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct TransformGroupIndex(u32); entity_impl!(TransformGroupIndex); pub struct TransformGroupBuilder { name: &'static str, doc: &'static str, chain_with: Option, isa_name: Option<&'static str>, pub custom_legalizes: HashMap, pub transforms: Vec, } impl TransformGroupBuilder { pub fn new(name: &'static str, doc: &'static str) -> Self { Self { name, doc, chain_with: None, isa_name: None, custom_legalizes: HashMap::new(), transforms: Vec::new(), } } pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self { assert!(self.chain_with.is_none()); self.chain_with = Some(next_id); self } pub fn isa(mut self, isa_name: &'static str) -> Self { assert!(self.isa_name.is_none()); self.isa_name = Some(isa_name); self } /// Add a custom legalization action for `inst`. /// /// The `func_name` parameter is the fully qualified name of a Rust function which takes the /// same arguments as the `isa::Legalize` actions. /// /// The custom function will be called to legalize `inst` and any return value is ignored. pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) { assert!( self.custom_legalizes .insert(inst.camel_name.clone(), func_name) .is_none(), format!( "custom legalization action for {} inserted twice", inst.name ) ); } /// Add a legalization pattern to this group. 
pub fn legalize(&mut self, src: DummyDef, dst: Vec) { let transform = Transform::new(src, dst); transform.verify_legalize(); self.transforms.push(transform); } pub fn build_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex { let next_id = owner.next_key(); owner.add(TransformGroup { name: self.name, doc: self.doc, isa_name: self.isa_name, id: next_id, chain_with: self.chain_with, custom_legalizes: self.custom_legalizes, transforms: self.transforms, }) } } pub struct TransformGroups { groups: PrimaryMap, } impl TransformGroups { pub fn new() -> Self { Self { groups: PrimaryMap::new(), } } pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex { for group in self.groups.values() { assert!( group.name != new_group.name, format!("trying to insert {} for the second time", new_group.name) ); } self.groups.push(new_group) } pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup { &self.groups[id] } pub fn get_mut(&mut self, id: TransformGroupIndex) -> &mut TransformGroup { self.groups.get_mut(id).unwrap() } fn next_key(&self) -> TransformGroupIndex { self.groups.next_key() } pub fn by_name(&self, name: &'static str) -> &TransformGroup { for group in self.groups.values() { if group.name == name { return group; } } panic!(format!("transform group with name {} not found", name)); } } #[test] #[should_panic] fn test_double_custom_legalization() { use crate::cdsl::formats::{FormatRegistry, InstructionFormatBuilder}; use crate::cdsl::instructions::{InstructionBuilder, InstructionGroupBuilder}; let mut format = FormatRegistry::new(); format.insert(InstructionFormatBuilder::new("nullary")); let mut inst_group = InstructionGroupBuilder::new("test", "", &format); inst_group.push(InstructionBuilder::new("dummy", "doc")); let inst_group = inst_group.build(); let dummy_inst = inst_group.by_name("dummy"); let mut transform_group = TransformGroupBuilder::new("test", "doc"); transform_group.custom_legalize(&dummy_inst, "custom 1"); 
transform_group.custom_legalize(&dummy_inst, "custom 2"); }