diff --git a/cranelift/codegen/meta/src/cdsl/ast.rs b/cranelift/codegen/meta/src/cdsl/ast.rs new file mode 100644 index 0000000000..1df8ea3050 --- /dev/null +++ b/cranelift/codegen/meta/src/cdsl/ast.rs @@ -0,0 +1,653 @@ +use crate::cdsl::formats::FormatRegistry; +use crate::cdsl::inst::{BoundInstruction, Instruction, InstructionPredicate}; +use crate::cdsl::operands::{OperandKind, OperandKindFields}; +use crate::cdsl::types::{LaneType, ValueType}; +use crate::cdsl::typevar::{TypeSetBuilder, TypeVar}; + +use cranelift_entity::{entity_impl, PrimaryMap}; + +use std::fmt; + +pub enum Expr { + Var(VarIndex), + Literal(Literal), + Apply(Apply), +} + +impl Expr { + pub fn maybe_literal(&self) -> Option<&Literal> { + match &self { + Expr::Literal(lit) => Some(lit), + _ => None, + } + } + + pub fn maybe_var(&self) -> Option { + if let Expr::Var(var) = &self { + Some(*var) + } else { + None + } + } + + pub fn unwrap_var(&self) -> VarIndex { + self.maybe_var() + .expect("tried to unwrap a non-Var content in Expr::unwrap_var") + } + + pub fn to_rust_code(&self, var_pool: &VarPool) -> String { + match self { + Expr::Var(var_index) => var_pool.get(*var_index).to_rust_code(), + Expr::Literal(literal) => literal.to_rust_code(), + Expr::Apply(a) => a.to_rust_code(var_pool), + } + } +} + +/// An AST definition associates a set of variables with the values produced by an expression. +pub struct Def { + pub apply: Apply, + pub defined_vars: Vec, +} + +impl Def { + pub fn to_comment_string(&self, var_pool: &VarPool) -> String { + let results = self + .defined_vars + .iter() + .map(|&x| var_pool.get(x).name) + .collect::>(); + + let results = if results.len() == 1 { + results[0].to_string() + } else { + format!("({})", results.join(", ")) + }; + + format!("{} << {}", results, self.apply.to_comment_string(var_pool)) + } +} + +pub struct DefPool { + pool: PrimaryMap, +} + +impl DefPool { + pub fn new() -> Self { + Self { + pool: PrimaryMap::new(), + } + } + pub fn get(&self, index: DefIndex) -> &Def { + self.pool.get(index).unwrap() + } + pub fn get_mut(&mut self, index: DefIndex) -> &mut Def { + self.pool.get_mut(index).unwrap() + } + pub fn next_index(&self) -> DefIndex { + self.pool.next_key() + } + pub fn create(&mut self, apply: Apply, defined_vars: Vec) -> DefIndex { + self.pool.push(Def { + apply, + defined_vars, + }) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct DefIndex(u32); +entity_impl!(DefIndex); + +#[derive(Debug, Clone)] +enum LiteralValue { + /// A value of an enumerated immediate operand. + /// + /// Some immediate operand kinds like `intcc` and `floatcc` have an enumerated range of values + /// corresponding to a Rust enum type. An `Enumerator` object is an AST leaf node representing one + /// of the values. + Enumerator(&'static str), + + /// A bitwise value of an immediate operand, used for bitwise exact floating point constants. + Bits(u64), + + /// A value of an integer immediate operand. + Int(i64), +} + +#[derive(Clone)] +pub struct Literal { + kind: OperandKind, + value: LiteralValue, +} + +impl fmt::Debug for Literal { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!( + fmt, + "Literal(kind={}, value={:?})", + self.kind.name, self.value + ) + } +} + +impl Literal { + pub fn enumerator_for(kind: &OperandKind, value: &'static str) -> Self { + if let OperandKindFields::ImmEnum(values) = &kind.fields { + assert!( + values.get(value).is_some(), + format!( + "nonexistent value '{}' in enumeration '{}'", + value, kind.name + ) + ); + } else { + panic!("enumerator is for enum values"); + } + Self { + kind: kind.clone(), + value: LiteralValue::Enumerator(value), + } + } + + pub fn bits(kind: &OperandKind, bits: u64) -> Self { + match kind.fields { + OperandKindFields::ImmValue => {} + _ => panic!("bits_of is for immediate scalar types"), + } + Self { + kind: kind.clone(), + value: LiteralValue::Bits(bits), + } + } + + pub fn constant(kind: &OperandKind, value: i64) -> Self { + match kind.fields { + OperandKindFields::ImmValue => {} + _ => panic!("bits_of is for immediate scalar types"), + } + Self { + kind: kind.clone(), + value: LiteralValue::Int(value), + } + } + + pub fn to_rust_code(&self) -> String { + let maybe_values = match &self.kind.fields { + OperandKindFields::ImmEnum(values) => Some(values), + OperandKindFields::ImmValue => None, + _ => panic!("impossible per construction"), + }; + + match self.value { + LiteralValue::Enumerator(value) => { + format!("{}::{}", self.kind.rust_type, maybe_values.unwrap()[value]) + } + LiteralValue::Bits(bits) => format!("{}::with_bits({:#x})", self.kind.rust_type, bits), + LiteralValue::Int(val) => val.to_string(), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub enum PatternPosition { + Source, + Destination, +} + +/// A free variable. +/// +/// When variables are used in `XForms` with source and destination patterns, they are classified +/// as follows: +/// +/// Input values: Uses in the source pattern with no preceding def. These may appear as inputs in +/// the destination pattern too, but no new inputs can be introduced. +/// +/// Output values: Variables that are defined in both the source and destination pattern. These +/// values may have uses outside the source pattern, and the destination pattern must compute the +/// same value. +/// +/// Intermediate values: Values that are defined in the source pattern, but not in the destination +/// pattern. These may have uses outside the source pattern, so the defining instruction can't be +/// deleted immediately. +/// +/// Temporary values are defined only in the destination pattern. +pub struct Var { + pub name: &'static str, + + /// The `Def` defining this variable in a source pattern. + pub src_def: Option, + + /// The `Def` defining this variable in a destination pattern. + pub dst_def: Option, + + /// TypeVar representing the type of this variable. + type_var: Option, + + /// Is this the original type variable, or has it be redefined with set_typevar? + is_original_type_var: bool, +} + +impl Var { + fn new(name: &'static str) -> Self { + Self { + name, + src_def: None, + dst_def: None, + type_var: None, + is_original_type_var: false, + } + } + + /// Is this an input value to the src pattern? + pub fn is_input(&self) -> bool { + self.src_def.is_none() && self.dst_def.is_none() + } + + /// Is this an output value, defined in both src and dst patterns? + pub fn is_output(&self) -> bool { + self.src_def.is_some() && self.dst_def.is_some() + } + + /// Is this an intermediate value, defined only in the src pattern? + pub fn is_intermediate(&self) -> bool { + self.src_def.is_some() && self.dst_def.is_none() + } + + /// Is this a temp value, defined only in the dst pattern? + pub fn is_temp(&self) -> bool { + self.src_def.is_none() && self.dst_def.is_some() + } + + /// Get the def of this variable according to the position. + pub fn get_def(&self, position: PatternPosition) -> Option { + match position { + PatternPosition::Source => self.src_def, + PatternPosition::Destination => self.dst_def, + } + } + + pub fn set_def(&mut self, position: PatternPosition, def: DefIndex) { + assert!( + self.get_def(position).is_none(), + format!("redefinition of variable {}", self.name) + ); + match position { + PatternPosition::Source => { + self.src_def = Some(def); + } + PatternPosition::Destination => { + self.dst_def = Some(def); + } + } + } + + /// Get the type variable representing the type of this variable. + pub fn get_or_create_typevar(&mut self) -> TypeVar { + match &self.type_var { + Some(tv) => tv.clone(), + None => { + // Create a new type var in which we allow all types. + let tv = TypeVar::new( + format!("typeof_{}", self.name), + format!("Type of the pattern variable {:?}", self), + TypeSetBuilder::all(), + ); + self.type_var = Some(tv.clone()); + self.is_original_type_var = true; + tv + } + } + } + pub fn get_typevar(&self) -> Option { + self.type_var.clone() + } + pub fn set_typevar(&mut self, tv: TypeVar) { + self.is_original_type_var = if let Some(previous_tv) = &self.type_var { + *previous_tv == tv + } else { + false + }; + self.type_var = Some(tv); + } + + /// Check if this variable has a free type variable. If not, the type of this variable is + /// computed from the type of another variable. + pub fn has_free_typevar(&self) -> bool { + match &self.type_var { + Some(tv) => tv.base.is_none() && self.is_original_type_var, + None => false, + } + } + + pub fn to_rust_code(&self) -> String { + self.name.into() + } + fn rust_type(&self) -> String { + self.type_var.as_ref().unwrap().to_rust_code() + } +} + +impl fmt::Debug for Var { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + fmt.write_fmt(format_args!( + "Var({}{}{})", + self.name, + if self.src_def.is_some() { ", src" } else { "" }, + if self.dst_def.is_some() { ", dst" } else { "" } + )) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct VarIndex(u32); +entity_impl!(VarIndex); + +pub struct VarPool { + pool: PrimaryMap, +} + +impl VarPool { + pub fn new() -> Self { + Self { + pool: PrimaryMap::new(), + } + } + pub fn get(&self, index: VarIndex) -> &Var { + self.pool.get(index).unwrap() + } + pub fn get_mut(&mut self, index: VarIndex) -> &mut Var { + self.pool.get_mut(index).unwrap() + } + pub fn create(&mut self, name: &'static str) -> VarIndex { + self.pool.push(Var::new(name)) + } +} + +pub enum ApplyTarget { + Inst(Instruction), + Bound(BoundInstruction), +} + +impl ApplyTarget { + pub fn inst(&self) -> &Instruction { + match &self { + ApplyTarget::Inst(inst) => inst, + ApplyTarget::Bound(bound_inst) => &bound_inst.inst, + } + } +} + +impl Into for &Instruction { + fn into(self) -> ApplyTarget { + ApplyTarget::Inst(self.clone()) + } +} + +impl Into for BoundInstruction { + fn into(self) -> ApplyTarget { + ApplyTarget::Bound(self) + } +} + +pub fn bind(target: impl Into, lane_type: impl Into) -> BoundInstruction { + let value_type = ValueType::from(lane_type.into()); + + let (inst, value_types) = match target.into() { + ApplyTarget::Inst(inst) => (inst, vec![value_type]), + ApplyTarget::Bound(bound_inst) => { + let mut new_value_types = bound_inst.value_types; + new_value_types.push(value_type); + (bound_inst.inst, new_value_types) + } + }; + + match &inst.polymorphic_info { + Some(poly) => { + assert!( + value_types.len() <= 1 + poly.other_typevars.len(), + format!("trying to bind too many types for {}", inst.name) + ); + } + None => { + panic!(format!( + "trying to bind a type for {} which is not a polymorphic instruction", + inst.name + )); + } + } + + BoundInstruction { inst, value_types } +} + +/// Apply an instruction to arguments. +/// +/// An `Apply` AST expression is created by using function call syntax on instructions. This +/// applies to both bound and unbound polymorphic instructions. +pub struct Apply { + pub inst: Instruction, + pub args: Vec, + pub value_types: Vec, +} + +impl Apply { + pub fn new(target: ApplyTarget, args: Vec) -> Self { + let (inst, value_types) = match target.into() { + ApplyTarget::Inst(inst) => (inst, Vec::new()), + ApplyTarget::Bound(bound_inst) => (bound_inst.inst, bound_inst.value_types), + }; + + // Basic check on number of arguments. + assert!( + inst.operands_in.len() == args.len(), + format!("incorrect number of arguments in instruction {}", inst.name) + ); + + // Check that the kinds of Literals arguments match the expected operand. + for &imm_index in &inst.imm_opnums { + let arg = &args[imm_index]; + if let Some(literal) = arg.maybe_literal() { + let op = &inst.operands_in[imm_index]; + assert!( + op.kind.name == literal.kind.name, + format!( + "Passing literal of kind {} to field of wrong kind {}", + literal.kind.name, op.kind.name + ) + ); + } + } + + Self { + inst, + args, + value_types, + } + } + + fn to_comment_string(&self, var_pool: &VarPool) -> String { + let args = self + .args + .iter() + .map(|arg| arg.to_rust_code(var_pool)) + .collect::>() + .join(", "); + + let mut inst_and_bound_types = vec![self.inst.name.to_string()]; + inst_and_bound_types.extend(self.value_types.iter().map(|vt| vt.to_string())); + let inst_name = inst_and_bound_types.join("."); + + format!("{}({})", inst_name, args) + } + + fn to_rust_code(&self, var_pool: &VarPool) -> String { + let args = self + .args + .iter() + .map(|arg| arg.to_rust_code(var_pool)) + .collect::>() + .join(", "); + format!("{}({})", self.inst.name, args) + } + + fn inst_predicate( + &self, + format_registry: &FormatRegistry, + var_pool: &VarPool, + ) -> InstructionPredicate { + let iform = format_registry.get(self.inst.format); + + let mut pred = InstructionPredicate::new(); + for (format_field, &op_num) in iform.imm_fields.iter().zip(self.inst.imm_opnums.iter()) { + let arg = &self.args[op_num]; + if arg.maybe_var().is_some() { + // Ignore free variables for now. + continue; + } + pred = pred.and(InstructionPredicate::new_is_field_equal( + &format_field, + arg.to_rust_code(var_pool), + )); + } + + // Add checks for any bound secondary type variables. We can't check the controlling type + // variable this way since it may not appear as the type of an operand. + if self.value_types.len() > 1 { + let poly = self + .inst + .polymorphic_info + .as_ref() + .expect("must have polymorphic info if it has bounded types"); + for (bound_type, type_var) in + self.value_types[1..].iter().zip(poly.other_typevars.iter()) + { + pred = pred.and(InstructionPredicate::new_typevar_check( + &self.inst, type_var, bound_type, + )); + } + } + + pred + } + + /// Same as `inst_predicate()`, but also check the controlling type variable. + pub fn inst_predicate_with_ctrl_typevar( + &self, + format_registry: &FormatRegistry, + var_pool: &VarPool, + ) -> InstructionPredicate { + let mut pred = self.inst_predicate(format_registry, var_pool); + + if !self.value_types.is_empty() { + let bound_type = &self.value_types[0]; + let poly = self.inst.polymorphic_info.as_ref().unwrap(); + let type_check = if poly.use_typevar_operand { + InstructionPredicate::new_typevar_check(&self.inst, &poly.ctrl_typevar, bound_type) + } else { + InstructionPredicate::new_ctrl_typevar_check(&bound_type) + }; + pred = pred.and(type_check); + } + + pred + } + + pub fn rust_builder(&self, defined_vars: &Vec, var_pool: &VarPool) -> String { + let mut args = self + .args + .iter() + .map(|expr| expr.to_rust_code(var_pool)) + .collect::>() + .join(", "); + + // Do we need to pass an explicit type argument? + if let Some(poly) = &self.inst.polymorphic_info { + if !poly.use_typevar_operand { + args = format!("{}, {}", var_pool.get(defined_vars[0]).rust_type(), args); + } + } + + format!("{}({})", self.inst.snake_name(), args) + } +} + +// Simple helpers for legalize actions construction. + +pub enum DummyExpr { + Var(DummyVar), + Literal(Literal), + Apply(ApplyTarget, Vec), +} + +#[derive(Clone)] +pub struct DummyVar { + pub name: &'static str, +} + +impl Into for DummyVar { + fn into(self) -> DummyExpr { + DummyExpr::Var(self) + } +} +impl Into for Literal { + fn into(self) -> DummyExpr { + DummyExpr::Literal(self) + } +} + +pub fn var(name: &'static str) -> DummyVar { + DummyVar { name } +} + +pub struct DummyDef { + pub expr: DummyExpr, + pub defined_vars: Vec, +} + +pub struct ExprBuilder { + expr: DummyExpr, +} + +impl ExprBuilder { + pub fn apply(inst: ApplyTarget, args: Vec) -> Self { + let expr = DummyExpr::Apply(inst, args); + Self { expr } + } + + pub fn assign_to(self, defined_vars: Vec) -> DummyDef { + DummyDef { + expr: self.expr, + defined_vars, + } + } +} + +macro_rules! def_rhs { + // inst(a, b, c) + ($inst:ident($($src:expr),*)) => { + ExprBuilder::apply($inst.into(), vec![$($src.clone().into()),*]) + }; + + // inst.type(a, b, c) + ($inst:ident.$type:ident($($src:expr),*)) => { + ExprBuilder::apply(bind($inst, $type).into(), vec![$($src.clone().into()),*]) + }; +} + +// Helper macro to define legalization recipes. +macro_rules! def { + // x = ... + ($dest:ident = $($tt:tt)*) => { + def_rhs!($($tt)*).assign_to(vec![$dest.clone()]) + }; + + // (x, y, ...) = ... + (($($dest:ident),*) = $($tt:tt)*) => { + def_rhs!($($tt)*).assign_to(vec![$($dest.clone()),*]) + }; + + // An instruction with no results. + ($($tt:tt)*) => { + def_rhs!($($tt)*).assign_to(Vec::new()) + } +} diff --git a/cranelift/codegen/meta/src/cdsl/mod.rs b/cranelift/codegen/meta/src/cdsl/mod.rs index 540370624a..a6c8be95fe 100644 --- a/cranelift/codegen/meta/src/cdsl/mod.rs +++ b/cranelift/codegen/meta/src/cdsl/mod.rs @@ -3,6 +3,8 @@ //! This module defines the classes that are used to define Cranelift //! instructions and other entities. +#[macro_use] +pub mod ast; pub mod formats; pub mod inst; pub mod isa; @@ -12,6 +14,7 @@ pub mod settings; pub mod type_inference; pub mod types; pub mod typevar; +pub mod xform; /// A macro that converts boolean settings into predicates to look more natural. #[macro_export] diff --git a/cranelift/codegen/meta/src/cdsl/type_inference.rs b/cranelift/codegen/meta/src/cdsl/type_inference.rs index de104f4b47..e3a1e9bb06 100644 --- a/cranelift/codegen/meta/src/cdsl/type_inference.rs +++ b/cranelift/codegen/meta/src/cdsl/type_inference.rs @@ -1,5 +1,658 @@ -use crate::cdsl::typevar::TypeVar; +use crate::cdsl::ast::{Def, DefIndex, DefPool, Var, VarIndex, VarPool}; +use crate::cdsl::typevar::{DerivedFunc, TypeSet, TypeVar}; +use std::collections::{HashMap, HashSet}; +use std::iter::FromIterator; + +#[derive(Hash, PartialEq, Eq)] pub enum Constraint { + /// Constraint specifying that a type var tv1 must be wider than or equal to type var tv2 at + /// runtime. This requires that: + /// 1) They have the same number of lanes + /// 2) In a lane tv1 has at least as many bits as tv2. WiderOrEq(TypeVar, TypeVar), + + /// Constraint specifying that two derived type vars must have the same runtime type. + Eq(TypeVar, TypeVar), + + /// Constraint specifying that a type var must belong to some typeset. + InTypeset(TypeVar, TypeSet), +} + +impl Constraint { + fn translate_with TypeVar>(&self, func: F) -> Constraint { + match self { + Constraint::WiderOrEq(lhs, rhs) => { + let lhs = func(&lhs); + let rhs = func(&rhs); + Constraint::WiderOrEq(lhs, rhs) + } + Constraint::Eq(lhs, rhs) => { + let lhs = func(&lhs); + let rhs = func(&rhs); + Constraint::Eq(lhs, rhs) + } + Constraint::InTypeset(tv, ts) => { + let tv = func(&tv); + Constraint::InTypeset(tv, ts.clone()) + } + } + } + + /// Creates a new constraint by replacing type vars by their hashmap equivalent. + fn translate_with_map( + &self, + original_to_own_typevar: &HashMap<&TypeVar, TypeVar>, + ) -> Constraint { + self.translate_with(|tv| substitute(original_to_own_typevar, tv)) + } + + /// Creates a new constraint by replacing type vars by their canonical equivalent. + fn translate_with_env(&self, type_env: &TypeEnvironment) -> Constraint { + self.translate_with(|tv| type_env.get_equivalent(tv)) + } + + fn is_trivial(&self) -> bool { + match self { + Constraint::WiderOrEq(lhs, rhs) => { + // Trivially true. + if lhs == rhs { + return true; + } + + let ts1 = lhs.get_typeset(); + let ts2 = rhs.get_typeset(); + + // Trivially true. + if ts1.is_wider_or_equal(&ts2) { + return true; + } + + // Trivially false. + if ts1.is_narrower(&ts2) { + return true; + } + + // Trivially false. + if (&ts1.lanes & &ts2.lanes).len() == 0 { + return true; + } + + self.is_concrete() + } + Constraint::Eq(lhs, rhs) => lhs == rhs || self.is_concrete(), + Constraint::InTypeset(_, _) => { + // The way InTypeset are made, they would always be trivial if we were applying the + // same logic as the Python code did, so ignore this. + self.is_concrete() + } + } + } + + /// Returns true iff all the referenced type vars are singletons. + fn is_concrete(&self) -> bool { + match self { + Constraint::WiderOrEq(lhs, rhs) => { + lhs.singleton_type().is_some() && rhs.singleton_type().is_some() + } + Constraint::Eq(lhs, rhs) => { + lhs.singleton_type().is_some() && rhs.singleton_type().is_some() + } + Constraint::InTypeset(tv, _) => tv.singleton_type().is_some(), + } + } + + fn typevar_args(&self) -> Vec<&TypeVar> { + match self { + Constraint::WiderOrEq(lhs, rhs) => vec![lhs, rhs], + Constraint::Eq(lhs, rhs) => vec![lhs, rhs], + Constraint::InTypeset(tv, _) => vec![tv], + } + } +} + +#[derive(Clone, Copy)] +enum TypeEnvRank { + Singleton = 5, + Input = 4, + Intermediate = 3, + Output = 2, + Temp = 1, + Internal = 0, +} + +/// Class encapsulating the necessary bookkeeping for type inference. +pub struct TypeEnvironment { + vars: HashSet, + ranks: HashMap, + equivalency_map: HashMap, + pub constraints: Vec, +} + +impl TypeEnvironment { + fn new() -> Self { + TypeEnvironment { + vars: HashSet::new(), + ranks: HashMap::new(), + equivalency_map: HashMap::new(), + constraints: Vec::new(), + } + } + + fn register(&mut self, var_index: VarIndex, var: &mut Var) { + self.vars.insert(var_index); + let rank = if var.is_input() { + TypeEnvRank::Input + } else if var.is_intermediate() { + TypeEnvRank::Intermediate + } else if var.is_output() { + TypeEnvRank::Output + } else { + assert!(var.is_temp()); + TypeEnvRank::Temp + }; + self.ranks.insert(var.get_or_create_typevar(), rank); + } + + fn add_constraint(&mut self, constraint: Constraint) { + if self + .constraints + .iter() + .find(|&item| item == &constraint) + .is_some() + { + return; + } + + // Check extra conditions for InTypeset constraints. + if let Constraint::InTypeset(tv, _) = &constraint { + assert!(tv.base.is_none()); + assert!(tv.name.starts_with("typeof_")); + } + + self.constraints.push(constraint); + } + + /// Returns the canonical representative of the equivalency class of the given argument, or + /// duplicates it if it's not there yet. + pub fn get_equivalent(&self, tv: &TypeVar) -> TypeVar { + let mut tv = tv; + while let Some(found) = self.equivalency_map.get(tv) { + tv = found; + } + match &tv.base { + Some(parent) => self + .get_equivalent(&parent.type_var) + .derived(parent.derived_func), + None => tv.clone(), + } + } + + /// Get the rank of tv in the partial order: + /// - TVs directly associated with a Var get their rank from the Var (see register()). + /// - Internally generated non-derived TVs implicitly get the lowest rank (0). + /// - Derived variables get their rank from their free typevar. + /// - Singletons have the highest rank. + /// - TVs associated with vars in a source pattern have a higher rank than TVs associated with + /// temporary vars. + fn rank(&self, tv: &TypeVar) -> u8 { + let actual_tv = match tv.base { + Some(_) => tv.free_typevar(), + None => Some(tv.clone()), + }; + + let rank = match actual_tv { + Some(actual_tv) => match self.ranks.get(&actual_tv) { + Some(rank) => Some(*rank), + None => { + assert!( + !actual_tv.name.starts_with("typeof_"), + format!("variable {} should be explicitly ranked", actual_tv.name) + ); + None + } + }, + None => None, + }; + + let rank = match rank { + Some(rank) => rank, + None => { + if tv.singleton_type().is_some() { + TypeEnvRank::Singleton + } else { + TypeEnvRank::Internal + } + } + }; + + rank as u8 + } + + /// Record the fact that the free tv1 is part of the same equivalence class as tv2. The + /// canonical representative of the merged class is tv2's canonical representative. + fn record_equivalent(&mut self, tv1: TypeVar, tv2: TypeVar) { + assert!(tv1.base.is_none()); + assert!(self.get_equivalent(&tv1) == tv1); + if let Some(tv2_base) = &tv2.base { + // Ensure there are no cycles. + assert!(self.get_equivalent(&tv2_base.type_var) != tv1); + } + self.equivalency_map.insert(tv1, tv2); + } + + /// Get the free typevars in the current type environment. + pub fn free_typevars(&self, var_pool: &mut VarPool) -> Vec { + let mut typevars = Vec::new(); + typevars.extend(self.equivalency_map.keys().cloned()); + typevars.extend( + self.vars + .iter() + .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()), + ); + + let set: HashSet = HashSet::from_iter( + typevars + .iter() + .map(|tv| self.get_equivalent(tv).free_typevar()) + .filter(|opt_tv| { + // Filter out singleton types. + return opt_tv.is_some(); + }) + .map(|tv| tv.unwrap()), + ); + Vec::from_iter(set) + } + + /// Normalize by collapsing any roots that don't correspond to a concrete type var AND have a + /// single type var derived from them or equivalent to them. + /// + /// e.g. if we have a root of the tree that looks like: + /// + /// typeof_a typeof_b + /// \\ / + /// typeof_x + /// | + /// half_width(1) + /// | + /// 1 + /// + /// we want to collapse the linear path between 1 and typeof_x. The resulting graph is: + /// + /// typeof_a typeof_b + /// \\ / + /// typeof_x + fn normalize(&mut self, var_pool: &mut VarPool) { + let source_tvs: HashSet = HashSet::from_iter( + self.vars + .iter() + .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()), + ); + + let mut children: HashMap> = HashMap::new(); + + // Insert all the parents found by the derivation relationship. + for type_var in self.equivalency_map.values() { + if type_var.base.is_none() { + continue; + } + + let parent_tv = type_var.free_typevar(); + if parent_tv.is_none() { + // Ignore this type variable, it's a singleton. + continue; + } + let parent_tv = parent_tv.unwrap(); + + children + .entry(parent_tv) + .or_insert(HashSet::new()) + .insert(type_var.clone()); + } + + // Insert all the explicit equivalency links. + for (equivalent_tv, canon_tv) in self.equivalency_map.iter() { + children + .entry(canon_tv.clone()) + .or_insert(HashSet::new()) + .insert(equivalent_tv.clone()); + } + + // Remove links that are straight paths up to typevar of variables. + for free_root in self.free_typevars(var_pool) { + let mut root = &free_root; + while !source_tvs.contains(&root) + && children.contains_key(&root) + && children.get(&root).unwrap().len() == 1 + { + let child = children.get(&root).unwrap().iter().next().unwrap(); + assert_eq!(self.equivalency_map[child], root.clone()); + self.equivalency_map.remove(child); + root = child; + } + } + } + + /// Extract a clean type environment from self, that only mentions type vars associated with + /// real variables. + fn extract(self, var_pool: &mut VarPool) -> TypeEnvironment { + let vars_tv: HashSet = HashSet::from_iter( + self.vars + .iter() + .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()), + ); + + let mut new_equivalency_map: HashMap = HashMap::new(); + for tv in &vars_tv { + let canon_tv = self.get_equivalent(tv); + if *tv != canon_tv { + new_equivalency_map.insert(tv.clone(), canon_tv.clone()); + } + + // Sanity check: the translated type map should only refer to real variables. + assert!(vars_tv.contains(tv)); + let canon_free_tv = canon_tv.free_typevar(); + assert!(canon_free_tv.is_none() || vars_tv.contains(&canon_free_tv.unwrap())); + } + + let mut new_constraints: HashSet = HashSet::new(); + for constraint in &self.constraints { + let constraint = constraint.translate_with_env(&self); + if constraint.is_trivial() || new_constraints.contains(&constraint) { + continue; + } + + // Sanity check: translated constraints should refer only to real variables. + for arg in constraint.typevar_args() { + assert!(vars_tv.contains(arg)); + let arg_free_tv = arg.free_typevar(); + assert!(arg_free_tv.is_none() || vars_tv.contains(&arg_free_tv.unwrap())); + } + + new_constraints.insert(constraint); + } + + TypeEnvironment { + vars: self.vars, + ranks: self.ranks, + equivalency_map: new_equivalency_map, + constraints: Vec::from_iter(new_constraints), + } + } +} + +/// Replaces an external type variable according to the following rules: +/// - if a local copy is present in the map, return it. +/// - or if it's derived, create a local derived one that recursively substitutes the parent. +/// - or return itself. +fn substitute(map: &HashMap<&TypeVar, TypeVar>, external_type_var: &TypeVar) -> TypeVar { + match map.get(&external_type_var) { + Some(own_type_var) => own_type_var.clone(), + None => match &external_type_var.base { + Some(parent) => { + let parent_substitute = substitute(map, &parent.type_var); + TypeVar::derived(&parent_substitute, parent.derived_func) + } + None => external_type_var.clone(), + }, + } +} + +/// Normalize a (potentially derived) typevar using the following rules: +/// +/// - vector and width derived functions commute +/// {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) -> +/// {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base)) +/// +/// - half/double pairs collapse +/// {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base +/// {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base +fn canonicalize_derivations(tv: TypeVar) -> TypeVar { + let base = match &tv.base { + Some(base) => base, + None => return tv, + }; + + let derived_func = base.derived_func; + + if let Some(base_base) = &base.type_var.base { + let base_base_tv = &base_base.type_var; + match (derived_func, base_base.derived_func) { + (DerivedFunc::HalfWidth, DerivedFunc::DoubleWidth) + | (DerivedFunc::DoubleWidth, DerivedFunc::HalfWidth) + | (DerivedFunc::HalfVector, DerivedFunc::DoubleVector) + | (DerivedFunc::DoubleVector, DerivedFunc::HalfVector) => { + // Cancelling bijective transformations. This doesn't hide any overflow issues + // since derived type sets are checked upon derivaion, and base typesets are only + // allowed to shrink. + return canonicalize_derivations(base_base_tv.clone()); + } + (DerivedFunc::HalfWidth, DerivedFunc::HalfVector) + | (DerivedFunc::HalfWidth, DerivedFunc::DoubleVector) + | (DerivedFunc::DoubleWidth, DerivedFunc::DoubleVector) + | (DerivedFunc::DoubleWidth, DerivedFunc::HalfVector) => { + // Arbitrarily put WIDTH derivations before VECTOR derivations, since they commute. + return canonicalize_derivations( + base_base_tv + .derived(derived_func) + .derived(base_base.derived_func), + ); + } + _ => {} + }; + } + + canonicalize_derivations(base.type_var.clone()).derived(derived_func) +} + +/// Given typevars tv1 and tv2 (which could be derived from one another), constrain their typesets +/// to be the same. When one is derived from the other, repeat the constrain process until +/// a fixed point is reached. +fn constrain_fixpoint(tv1: &TypeVar, tv2: &TypeVar) { + loop { + let old_tv1_ts = tv1.get_typeset().clone(); + tv2.constrain_types(tv1.clone()); + if tv1.get_typeset() == old_tv1_ts { + break; + } + } + + let old_tv2_ts = tv2.get_typeset().clone(); + tv1.constrain_types(tv2.clone()); + // The above loop should ensure that all reference cycles have been handled. + assert!(old_tv2_ts == tv2.get_typeset()); +} + +/// Unify tv1 and tv2 in the given type environment. tv1 must have a rank greater or equal to tv2's +/// one, modulo commutations. +fn unify(tv1: &TypeVar, tv2: &TypeVar, type_env: &mut TypeEnvironment) -> Result<(), String> { + let tv1 = canonicalize_derivations(type_env.get_equivalent(tv1)); + let tv2 = canonicalize_derivations(type_env.get_equivalent(tv2)); + + if tv1 == tv2 { + // Already unified. + return Ok(()); + } + + if type_env.rank(&tv2) < type_env.rank(&tv1) { + // Make sure tv1 always has the smallest rank, since real variables have the higher rank + // and we want them to be the canonical representatives of their equivalency classes. + return unify(&tv2, &tv1, type_env); + } + + constrain_fixpoint(&tv1, &tv2); + + if tv1.get_typeset().size() == 0 || tv2.get_typeset().size() == 0 { + return Err(format!( + "Error: empty type created when unifying {} and {}", + tv1.name, tv2.name + )); + } + + let base = match &tv1.base { + Some(base) => base, + None => { + type_env.record_equivalent(tv1, tv2); + return Ok(()); + } + }; + + if let Some(inverse) = base.derived_func.inverse() { + return unify(&base.type_var, &tv2.derived(inverse), type_env); + } + + type_env.add_constraint(Constraint::Eq(tv1, tv2)); + Ok(()) +} + +/// Perform type inference on one Def in the current type environment and return an updated type +/// environment or error. +/// +/// At a high level this works by creating fresh copies of each formal type var in the Def's +/// instruction's signature, and unifying the formal typevar with the corresponding actual typevar. +fn infer_definition( + def: &Def, + var_pool: &mut VarPool, + type_env: TypeEnvironment, + last_type_index: &mut usize, +) -> Result { + let apply = &def.apply; + let inst = &apply.inst; + + let mut type_env = type_env; + let free_formal_tvs = inst.all_typevars(); + + let mut original_to_own_typevar: HashMap<&TypeVar, TypeVar> = HashMap::new(); + for &tv in &free_formal_tvs { + assert!(original_to_own_typevar + .insert( + tv, + TypeVar::copy_from(tv, format!("own_{}", last_type_index)) + ) + .is_none()); + *last_type_index += 1; + } + + // Update the mapping with any explicity bound type vars: + for (i, value_type) in apply.value_types.iter().enumerate() { + let singleton = TypeVar::new_singleton(value_type.clone()); + assert!(original_to_own_typevar + .insert(free_formal_tvs[i], singleton) + .is_some()); + } + + // Get fresh copies for each typevar in the signature (both free and derived). + let mut formal_tvs = Vec::new(); + formal_tvs.extend(inst.value_results.iter().map(|&i| { + substitute( + &original_to_own_typevar, + inst.operands_out[i].type_var().unwrap(), + ) + })); + formal_tvs.extend(inst.value_opnums.iter().map(|&i| { + substitute( + &original_to_own_typevar, + inst.operands_in[i].type_var().unwrap(), + ) + })); + + // Get the list of actual vars. + let mut actual_vars = Vec::new(); + actual_vars.extend(inst.value_results.iter().map(|&i| def.defined_vars[i])); + actual_vars.extend( + inst.value_opnums + .iter() + .map(|&i| apply.args[i].unwrap_var()), + ); + + // Get the list of the actual TypeVars. + let mut actual_tvs = Vec::new(); + for var_index in actual_vars { + let var = var_pool.get_mut(var_index); + type_env.register(var_index, var); + actual_tvs.push(var.get_or_create_typevar()); + } + + // Make sure we start unifying with the control type variable first, by putting it at the + // front of both vectors. + if let Some(poly) = &inst.polymorphic_info { + let own_ctrl_tv = &original_to_own_typevar[&poly.ctrl_typevar]; + let ctrl_index = formal_tvs.iter().position(|tv| tv == own_ctrl_tv).unwrap(); + if ctrl_index != 0 { + formal_tvs.swap(0, ctrl_index); + actual_tvs.swap(0, ctrl_index); + } + } + + // Unify each actual type variable with the corresponding formal type variable. + for (actual_tv, formal_tv) in actual_tvs.iter().zip(&formal_tvs) { + if let Err(msg) = unify(actual_tv, formal_tv, &mut type_env) { + return Err(format!( + "fail ti on {} <: {}: {}", + actual_tv.name, formal_tv.name, msg + )); + } + } + + // Add any instruction specific constraints. + for constraint in &inst.constraints { + type_env.add_constraint(constraint.translate_with_map(&original_to_own_typevar)); + } + + Ok(type_env) +} + +/// Perform type inference on an transformation. Return an updated type environment or error. +pub fn infer_transform( + src: DefIndex, + dst: &Vec, + def_pool: &DefPool, + var_pool: &mut VarPool, +) -> Result { + let mut type_env = TypeEnvironment::new(); + let mut last_type_index = 0; + + // Execute type inference on the source pattern. + type_env = infer_definition(def_pool.get(src), var_pool, type_env, &mut last_type_index) + .map_err(|err| format!("In src pattern: {}", err))?; + + // Collect the type sets once after applying the source patterm; we'll compare the typesets + // after we've also considered the destination pattern, and will emit supplementary InTypeset + // checks if they don't match. + let src_typesets = type_env + .vars + .iter() + .map(|&var_index| { + let var = var_pool.get_mut(var_index); + let tv = type_env.get_equivalent(&var.get_or_create_typevar()); + (var_index, tv.get_typeset().clone()) + }) + .collect::>(); + + // Execute type inference on the destination pattern. + for (i, &def_index) in dst.iter().enumerate() { + let def = def_pool.get(def_index); + type_env = infer_definition(def, var_pool, type_env, &mut last_type_index) + .map_err(|err| format!("line {}: {}", i, err))?; + } + + for (var_index, src_typeset) in src_typesets { + let var = var_pool.get(var_index); + if !var.has_free_typevar() { + continue; + } + let tv = type_env.get_equivalent(&var.get_typevar().unwrap()); + let new_typeset = tv.get_typeset(); + assert!( + new_typeset.is_subset(&src_typeset), + "type sets can only get narrower" + ); + if new_typeset != src_typeset { + type_env.add_constraint(Constraint::InTypeset(tv.clone(), new_typeset.clone())); + } + } + + type_env.normalize(var_pool); + + Ok(type_env.extract(var_pool)) } diff --git a/cranelift/codegen/meta/src/cdsl/xform.rs b/cranelift/codegen/meta/src/cdsl/xform.rs new file mode 100644 index 0000000000..c748459441 --- /dev/null +++ b/cranelift/codegen/meta/src/cdsl/xform.rs @@ -0,0 +1,416 @@ +use crate::cdsl::ast::{ + Apply, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition, VarIndex, VarPool, +}; +use crate::cdsl::inst::Instruction; +use crate::cdsl::type_inference::{infer_transform, TypeEnvironment}; +use crate::cdsl::typevar::TypeVar; + +use cranelift_entity::{entity_impl, PrimaryMap}; + +use std::collections::{HashMap, HashSet}; +use std::iter::FromIterator; + +/// An instruction transformation consists of a source and destination pattern. +/// +/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A +/// pattern may optionally have a sequence of TypeConstraints, that additionally limit the set of +/// cases when it applies. +/// +/// The source pattern can contain only a single instruction. +pub struct Transform { + pub src: DefIndex, + pub dst: Vec, + pub var_pool: VarPool, + pub def_pool: DefPool, + pub type_env: TypeEnvironment, +} + +type SymbolTable = HashMap<&'static str, VarIndex>; + +impl Transform { + fn new(src: DummyDef, dst: Vec) -> Self { + let mut var_pool = VarPool::new(); + let mut def_pool = DefPool::new(); + + let mut input_vars: Vec = Vec::new(); + let mut defined_vars: Vec = Vec::new(); + + // Maps variable names to our own Var copies. + let mut symbol_table: SymbolTable = SymbolTable::new(); + + // Rewrite variables in src and dst using our own copies. + let src = rewrite_def_list( + PatternPosition::Source, + vec![src], + &mut symbol_table, + &mut input_vars, + &mut defined_vars, + &mut var_pool, + &mut def_pool, + )[0]; + + let num_src_inputs = input_vars.len(); + + let dst = rewrite_def_list( + PatternPosition::Destination, + dst, + &mut symbol_table, + &mut input_vars, + &mut defined_vars, + &mut var_pool, + &mut def_pool, + ); + + // Sanity checks. + for &var_index in &input_vars { + assert!( + var_pool.get(var_index).is_input(), + format!("'{:?}' used as both input and def", var_pool.get(var_index)) + ); + } + assert!( + input_vars.len() == num_src_inputs, + format!( + "extra input vars in dst pattern: {:?}", + input_vars + .iter() + .map(|&i| var_pool.get(i)) + .skip(num_src_inputs) + .collect::>() + ) + ); + + // Perform type inference and cleanup. + let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap(); + + // Sanity check: the set of inferred free type variables should be a subset of the type + // variables corresponding to Vars appearing in the source pattern. + { + let free_typevars: HashSet = + HashSet::from_iter(type_env.free_typevars(&mut var_pool)); + let src_tvs = HashSet::from_iter( + input_vars + .clone() + .iter() + .chain( + defined_vars + .iter() + .filter(|&&var_index| !var_pool.get(var_index).is_temp()), + ) + .map(|&var_index| var_pool.get(var_index).get_typevar()) + .filter(|maybe_var| maybe_var.is_some()) + .map(|var| var.unwrap()), + ); + if !free_typevars.is_subset(&src_tvs) { + let missing_tvs = (&free_typevars - &src_tvs) + .iter() + .map(|tv| tv.name.clone()) + .collect::>() + .join(", "); + panic!("Some free vars don't appear in src: {}", missing_tvs); + } + } + + for &var_index in input_vars.iter().chain(defined_vars.iter()) { + let var = var_pool.get_mut(var_index); + let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar()); + var.set_typevar(canon_tv); + } + + Self { + src, + dst, + var_pool, + def_pool, + type_env, + } + } + + fn verify_legalize(&self) { + let def = self.def_pool.get(self.src); + for &var_index in def.defined_vars.iter() { + let defined_var = self.var_pool.get(var_index); + assert!( + defined_var.is_output(), + format!("{:?} not defined in the destination pattern", defined_var) + ); + } + } +} + +/// Given a list of symbols defined in a Def, rewrite them to local symbols. Yield the new locals. +fn rewrite_defined_vars( + position: PatternPosition, + dummy_def: &DummyDef, + def_index: DefIndex, + symbol_table: &mut SymbolTable, + defined_vars: &mut Vec, + var_pool: &mut VarPool, +) -> Vec { + let mut new_defined_vars = Vec::new(); + for var in &dummy_def.defined_vars { + let own_var = match symbol_table.get(var.name) { + Some(&existing_var) => existing_var, + None => { + // Materialize the variable. + let new_var = var_pool.create(var.name); + symbol_table.insert(var.name, new_var); + defined_vars.push(new_var); + new_var + } + }; + var_pool.get_mut(own_var).set_def(position, def_index); + new_defined_vars.push(own_var); + } + new_defined_vars +} + +/// Find all uses of variables in `expr` and replace them with our own local symbols. +fn rewrite_expr( + position: PatternPosition, + dummy_expr: DummyExpr, + symbol_table: &mut SymbolTable, + input_vars: &mut Vec, + var_pool: &mut VarPool, +) -> Apply { + let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr + { + (apply_target, dummy_args) + } else { + panic!("we only rewrite apply expressions"); + }; + + assert_eq!( + apply_target.inst().operands_in.len(), + dummy_args.len(), + "number of arguments in instruction is incorrect" + ); + + let mut args = Vec::new(); + for (i, arg) in dummy_args.into_iter().enumerate() { + match arg { + DummyExpr::Var(var) => { + let own_var = match symbol_table.get(var.name) { + Some(&own_var) => { + let var = var_pool.get(own_var); + assert!( + var.is_input() || var.get_def(position).is_some(), + format!("{:?} used as both input and def", var) + ); + own_var + } + None => { + // First time we're using this variable. + let own_var = var_pool.create(var.name); + symbol_table.insert(var.name, own_var); + input_vars.push(own_var); + own_var + } + }; + args.push(Expr::Var(own_var)); + } + DummyExpr::Literal(literal) => { + assert!(!apply_target.inst().operands_in[i].is_value()); + args.push(Expr::Literal(literal)); + } + DummyExpr::Apply(..) => { + panic!("Recursive apply is not allowed."); + } + } + } + + Apply::new(apply_target, args) +} + +fn rewrite_def_list( + position: PatternPosition, + dummy_defs: Vec, + symbol_table: &mut SymbolTable, + input_vars: &mut Vec, + defined_vars: &mut Vec, + var_pool: &mut VarPool, + def_pool: &mut DefPool, +) -> Vec { + let mut new_defs = Vec::new(); + for dummy_def in dummy_defs { + let def_index = def_pool.next_index(); + + let new_defined_vars = rewrite_defined_vars( + position, + &dummy_def, + def_index, + symbol_table, + defined_vars, + var_pool, + ); + let new_apply = rewrite_expr(position, dummy_def.expr, symbol_table, input_vars, var_pool); + + assert!( + def_pool.next_index() == def_index, + "shouldn't have created new defs in the meanwhile" + ); + assert_eq!( + new_apply.inst.value_results.len(), + new_defined_vars.len(), + "number of Var results in instruction is incorrect" + ); + + new_defs.push(def_pool.create(new_apply, new_defined_vars)); + } + new_defs +} + +/// A group of related transformations. +pub struct TransformGroup { + pub name: &'static str, + pub doc: &'static str, + pub chain_with: Option, + pub isa_name: Option<&'static str>, + pub id: TransformGroupIndex, + + /// Maps Instruction camel_case names to custom legalization functions names. + pub custom_legalizes: HashMap, + pub transforms: Vec, +} + +impl TransformGroup { + pub fn rust_name(&self) -> String { + match self.isa_name { + Some(_) => { + // This is a function in the same module as the LEGALIZE_ACTIONS table referring to + // it. + self.name.to_string() + } + None => format!("crate::legalizer::{}", self.name), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct TransformGroupIndex(u32); +entity_impl!(TransformGroupIndex); + +pub struct TransformGroupBuilder { + name: &'static str, + doc: &'static str, + chain_with: Option, + isa_name: Option<&'static str>, + pub custom_legalizes: HashMap, + pub transforms: Vec, +} + +impl TransformGroupBuilder { + pub fn new(name: &'static str, doc: &'static str) -> Self { + Self { + name, + doc, + chain_with: None, + isa_name: None, + custom_legalizes: HashMap::new(), + transforms: Vec::new(), + } + } + + pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self { + assert!(self.chain_with.is_none()); + self.chain_with = Some(next_id); + self + } + + pub fn isa(mut self, isa_name: &'static str) -> Self { + assert!(self.isa_name.is_none()); + self.isa_name = Some(isa_name); + self + } + + /// Add a custom legalization action for `inst`. + /// + /// The `func_name` parameter is the fully qualified name of a Rust function which takes the + /// same arguments as the `isa::Legalize` actions. + /// + /// The custom function will be called to legalize `inst` and any return value is ignored. + pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) { + assert!( + self.custom_legalizes + .insert(inst.camel_name.clone(), func_name) + .is_none(), + format!( + "custom legalization action for {} inserted twice", + inst.name + ) + ); + } + + /// Add a legalization pattern to this group. + pub fn legalize(&mut self, src: DummyDef, dst: Vec) { + let transform = Transform::new(src, dst); + transform.verify_legalize(); + self.transforms.push(transform); + } + + pub fn finish_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex { + let next_id = owner.next_key(); + owner.add(TransformGroup { + name: self.name, + doc: self.doc, + isa_name: self.isa_name, + id: next_id, + chain_with: self.chain_with, + custom_legalizes: self.custom_legalizes, + transforms: self.transforms, + }) + } +} + +pub struct TransformGroups { + groups: PrimaryMap, +} + +impl TransformGroups { + pub fn new() -> Self { + Self { + groups: PrimaryMap::new(), + } + } + pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex { + for group in self.groups.values() { + assert!( + group.name != new_group.name, + format!("trying to insert {} for the second time", new_group.name) + ); + } + self.groups.push(new_group) + } + pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup { + &self.groups[id] + } + pub fn get_mut(&mut self, id: TransformGroupIndex) -> &mut TransformGroup { + self.groups.get_mut(id).unwrap() + } + fn next_key(&self) -> TransformGroupIndex { + self.groups.next_key() + } + pub fn by_name(&self, name: &'static str) -> &TransformGroup { + for group in self.groups.values() { + if group.name == name { + return group; + } + } + panic!(format!("transform group with name {} not found", name)); + } +} + +#[test] +#[should_panic] +fn test_double_custom_legalization() { + use crate::cdsl::formats::{FormatRegistry, InstructionFormatBuilder}; + use crate::cdsl::inst::InstructionBuilder; + + let mut format = FormatRegistry::new(); + format.insert(InstructionFormatBuilder::new("nullary")); + let dummy_inst = InstructionBuilder::new("dummy", "doc").finish(&format); + + let mut transform_group = TransformGroupBuilder::new("test", "doc"); + transform_group.custom_legalize(&dummy_inst, "custom 1"); + transform_group.custom_legalize(&dummy_inst, "custom 2"); +}