[meta] Add type inference, transforms and AST helpers for legalization;
This commit is contained in:
653
cranelift/codegen/meta/src/cdsl/ast.rs
Normal file
653
cranelift/codegen/meta/src/cdsl/ast.rs
Normal file
@@ -0,0 +1,653 @@
|
||||
use crate::cdsl::formats::FormatRegistry;
|
||||
use crate::cdsl::inst::{BoundInstruction, Instruction, InstructionPredicate};
|
||||
use crate::cdsl::operands::{OperandKind, OperandKindFields};
|
||||
use crate::cdsl::types::{LaneType, ValueType};
|
||||
use crate::cdsl::typevar::{TypeSetBuilder, TypeVar};
|
||||
|
||||
use cranelift_entity::{entity_impl, PrimaryMap};
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// An AST expression node: a reference to a pattern variable, an immediate
/// literal, or an instruction applied to arguments.
pub enum Expr {
    /// Reference to a variable stored in the owning `VarPool`.
    Var(VarIndex),
    /// An immediate operand value.
    Literal(Literal),
    /// An instruction application (possibly with bound value types).
    Apply(Apply),
}
|
||||
|
||||
impl Expr {
|
||||
pub fn maybe_literal(&self) -> Option<&Literal> {
|
||||
match &self {
|
||||
Expr::Literal(lit) => Some(lit),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_var(&self) -> Option<VarIndex> {
|
||||
if let Expr::Var(var) = &self {
|
||||
Some(*var)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unwrap_var(&self) -> VarIndex {
|
||||
self.maybe_var()
|
||||
.expect("tried to unwrap a non-Var content in Expr::unwrap_var")
|
||||
}
|
||||
|
||||
pub fn to_rust_code(&self, var_pool: &VarPool) -> String {
|
||||
match self {
|
||||
Expr::Var(var_index) => var_pool.get(*var_index).to_rust_code(),
|
||||
Expr::Literal(literal) => literal.to_rust_code(),
|
||||
Expr::Apply(a) => a.to_rust_code(var_pool),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An AST definition associates a set of variables with the values produced by an expression.
pub struct Def {
    /// The instruction application whose results are being bound.
    pub apply: Apply,
    /// Indices (into the owning `VarPool`) of the variables receiving the results.
    pub defined_vars: Vec<VarIndex>,
}
|
||||
|
||||
impl Def {
|
||||
pub fn to_comment_string(&self, var_pool: &VarPool) -> String {
|
||||
let results = self
|
||||
.defined_vars
|
||||
.iter()
|
||||
.map(|&x| var_pool.get(x).name)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let results = if results.len() == 1 {
|
||||
results[0].to_string()
|
||||
} else {
|
||||
format!("({})", results.join(", "))
|
||||
};
|
||||
|
||||
format!("{} << {}", results, self.apply.to_comment_string(var_pool))
|
||||
}
|
||||
}
|
||||
|
||||
/// Arena of `Def`s, addressed by `DefIndex`.
pub struct DefPool {
    pool: PrimaryMap<DefIndex, Def>,
}
|
||||
|
||||
impl DefPool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pool: PrimaryMap::new(),
|
||||
}
|
||||
}
|
||||
pub fn get(&self, index: DefIndex) -> &Def {
|
||||
self.pool.get(index).unwrap()
|
||||
}
|
||||
pub fn get_mut(&mut self, index: DefIndex) -> &mut Def {
|
||||
self.pool.get_mut(index).unwrap()
|
||||
}
|
||||
pub fn next_index(&self) -> DefIndex {
|
||||
self.pool.next_key()
|
||||
}
|
||||
pub fn create(&mut self, apply: Apply, defined_vars: Vec<VarIndex>) -> DefIndex {
|
||||
self.pool.push(Def {
|
||||
apply,
|
||||
defined_vars,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compact entity index of a `Def` within a `DefPool`.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DefIndex(u32);
entity_impl!(DefIndex);
|
||||
|
||||
/// The payload of a `Literal`: one concrete immediate value.
#[derive(Debug, Clone)]
enum LiteralValue {
    /// A value of an enumerated immediate operand.
    ///
    /// Some immediate operand kinds like `intcc` and `floatcc` have an enumerated range of values
    /// corresponding to a Rust enum type. An `Enumerator` object is an AST leaf node representing one
    /// of the values.
    Enumerator(&'static str),

    /// A bitwise value of an immediate operand, used for bitwise exact floating point constants.
    Bits(u64),

    /// A value of an integer immediate operand.
    Int(i64),
}
|
||||
|
||||
/// An immediate operand value usable as an argument in AST patterns.
#[derive(Clone)]
pub struct Literal {
    /// The operand kind this literal belongs to (e.g. an enumerated or scalar immediate).
    kind: OperandKind,
    /// The concrete value.
    value: LiteralValue,
}
|
||||
|
||||
impl fmt::Debug for Literal {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
write!(
|
||||
fmt,
|
||||
"Literal(kind={}, value={:?})",
|
||||
self.kind.name, self.value
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Literal {
|
||||
pub fn enumerator_for(kind: &OperandKind, value: &'static str) -> Self {
|
||||
if let OperandKindFields::ImmEnum(values) = &kind.fields {
|
||||
assert!(
|
||||
values.get(value).is_some(),
|
||||
format!(
|
||||
"nonexistent value '{}' in enumeration '{}'",
|
||||
value, kind.name
|
||||
)
|
||||
);
|
||||
} else {
|
||||
panic!("enumerator is for enum values");
|
||||
}
|
||||
Self {
|
||||
kind: kind.clone(),
|
||||
value: LiteralValue::Enumerator(value),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bits(kind: &OperandKind, bits: u64) -> Self {
|
||||
match kind.fields {
|
||||
OperandKindFields::ImmValue => {}
|
||||
_ => panic!("bits_of is for immediate scalar types"),
|
||||
}
|
||||
Self {
|
||||
kind: kind.clone(),
|
||||
value: LiteralValue::Bits(bits),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn constant(kind: &OperandKind, value: i64) -> Self {
|
||||
match kind.fields {
|
||||
OperandKindFields::ImmValue => {}
|
||||
_ => panic!("bits_of is for immediate scalar types"),
|
||||
}
|
||||
Self {
|
||||
kind: kind.clone(),
|
||||
value: LiteralValue::Int(value),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_rust_code(&self) -> String {
|
||||
let maybe_values = match &self.kind.fields {
|
||||
OperandKindFields::ImmEnum(values) => Some(values),
|
||||
OperandKindFields::ImmValue => None,
|
||||
_ => panic!("impossible per construction"),
|
||||
};
|
||||
|
||||
match self.value {
|
||||
LiteralValue::Enumerator(value) => {
|
||||
format!("{}::{}", self.kind.rust_type, maybe_values.unwrap()[value])
|
||||
}
|
||||
LiteralValue::Bits(bits) => format!("{}::with_bits({:#x})", self.kind.rust_type, bits),
|
||||
LiteralValue::Int(val) => val.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Which side of a legalization transform a pattern belongs to.
#[derive(Clone, Copy, Debug)]
pub enum PatternPosition {
    /// The pattern being matched and replaced.
    Source,
    /// The replacement pattern.
    Destination,
}
|
||||
|
||||
/// A free variable.
///
/// When variables are used in `XForms` with source and destination patterns, they are classified
/// as follows:
///
/// Input values: Uses in the source pattern with no preceding def. These may appear as inputs in
/// the destination pattern too, but no new inputs can be introduced.
///
/// Output values: Variables that are defined in both the source and destination pattern. These
/// values may have uses outside the source pattern, and the destination pattern must compute the
/// same value.
///
/// Intermediate values: Values that are defined in the source pattern, but not in the destination
/// pattern. These may have uses outside the source pattern, so the defining instruction can't be
/// deleted immediately.
///
/// Temporary values are defined only in the destination pattern.
pub struct Var {
    /// Human-readable variable name, also used in generated Rust code.
    pub name: &'static str,

    /// The `Def` defining this variable in a source pattern.
    pub src_def: Option<DefIndex>,

    /// The `Def` defining this variable in a destination pattern.
    pub dst_def: Option<DefIndex>,

    /// TypeVar representing the type of this variable.
    type_var: Option<TypeVar>,

    /// Is this the original type variable, or has it been redefined with set_typevar?
    is_original_type_var: bool,
}
|
||||
|
||||
impl Var {
|
||||
fn new(name: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
src_def: None,
|
||||
dst_def: None,
|
||||
type_var: None,
|
||||
is_original_type_var: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Is this an input value to the src pattern?
|
||||
pub fn is_input(&self) -> bool {
|
||||
self.src_def.is_none() && self.dst_def.is_none()
|
||||
}
|
||||
|
||||
/// Is this an output value, defined in both src and dst patterns?
|
||||
pub fn is_output(&self) -> bool {
|
||||
self.src_def.is_some() && self.dst_def.is_some()
|
||||
}
|
||||
|
||||
/// Is this an intermediate value, defined only in the src pattern?
|
||||
pub fn is_intermediate(&self) -> bool {
|
||||
self.src_def.is_some() && self.dst_def.is_none()
|
||||
}
|
||||
|
||||
/// Is this a temp value, defined only in the dst pattern?
|
||||
pub fn is_temp(&self) -> bool {
|
||||
self.src_def.is_none() && self.dst_def.is_some()
|
||||
}
|
||||
|
||||
/// Get the def of this variable according to the position.
|
||||
pub fn get_def(&self, position: PatternPosition) -> Option<DefIndex> {
|
||||
match position {
|
||||
PatternPosition::Source => self.src_def,
|
||||
PatternPosition::Destination => self.dst_def,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_def(&mut self, position: PatternPosition, def: DefIndex) {
|
||||
assert!(
|
||||
self.get_def(position).is_none(),
|
||||
format!("redefinition of variable {}", self.name)
|
||||
);
|
||||
match position {
|
||||
PatternPosition::Source => {
|
||||
self.src_def = Some(def);
|
||||
}
|
||||
PatternPosition::Destination => {
|
||||
self.dst_def = Some(def);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the type variable representing the type of this variable.
|
||||
pub fn get_or_create_typevar(&mut self) -> TypeVar {
|
||||
match &self.type_var {
|
||||
Some(tv) => tv.clone(),
|
||||
None => {
|
||||
// Create a new type var in which we allow all types.
|
||||
let tv = TypeVar::new(
|
||||
format!("typeof_{}", self.name),
|
||||
format!("Type of the pattern variable {:?}", self),
|
||||
TypeSetBuilder::all(),
|
||||
);
|
||||
self.type_var = Some(tv.clone());
|
||||
self.is_original_type_var = true;
|
||||
tv
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn get_typevar(&self) -> Option<TypeVar> {
|
||||
self.type_var.clone()
|
||||
}
|
||||
pub fn set_typevar(&mut self, tv: TypeVar) {
|
||||
self.is_original_type_var = if let Some(previous_tv) = &self.type_var {
|
||||
*previous_tv == tv
|
||||
} else {
|
||||
false
|
||||
};
|
||||
self.type_var = Some(tv);
|
||||
}
|
||||
|
||||
/// Check if this variable has a free type variable. If not, the type of this variable is
|
||||
/// computed from the type of another variable.
|
||||
pub fn has_free_typevar(&self) -> bool {
|
||||
match &self.type_var {
|
||||
Some(tv) => tv.base.is_none() && self.is_original_type_var,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_rust_code(&self) -> String {
|
||||
self.name.into()
|
||||
}
|
||||
fn rust_type(&self) -> String {
|
||||
self.type_var.as_ref().unwrap().to_rust_code()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Var {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
fmt.write_fmt(format_args!(
|
||||
"Var({}{}{})",
|
||||
self.name,
|
||||
if self.src_def.is_some() { ", src" } else { "" },
|
||||
if self.dst_def.is_some() { ", dst" } else { "" }
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Compact entity index of a `Var` within a `VarPool`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct VarIndex(u32);
entity_impl!(VarIndex);
|
||||
|
||||
/// Arena of `Var`s, addressed by `VarIndex`.
pub struct VarPool {
    pool: PrimaryMap<VarIndex, Var>,
}
|
||||
|
||||
impl VarPool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pool: PrimaryMap::new(),
|
||||
}
|
||||
}
|
||||
pub fn get(&self, index: VarIndex) -> &Var {
|
||||
self.pool.get(index).unwrap()
|
||||
}
|
||||
pub fn get_mut(&mut self, index: VarIndex) -> &mut Var {
|
||||
self.pool.get_mut(index).unwrap()
|
||||
}
|
||||
pub fn create(&mut self, name: &'static str) -> VarIndex {
|
||||
self.pool.push(Var::new(name))
|
||||
}
|
||||
}
|
||||
|
||||
/// Target of an `Apply`: either a plain instruction or one already bound to
/// concrete value types.
pub enum ApplyTarget {
    Inst(Instruction),
    Bound(BoundInstruction),
}
|
||||
|
||||
impl ApplyTarget {
|
||||
pub fn inst(&self) -> &Instruction {
|
||||
match &self {
|
||||
ApplyTarget::Inst(inst) => inst,
|
||||
ApplyTarget::Bound(bound_inst) => &bound_inst.inst,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<ApplyTarget> for &Instruction {
|
||||
fn into(self) -> ApplyTarget {
|
||||
ApplyTarget::Inst(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<ApplyTarget> for BoundInstruction {
|
||||
fn into(self) -> ApplyTarget {
|
||||
ApplyTarget::Bound(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind(target: impl Into<ApplyTarget>, lane_type: impl Into<LaneType>) -> BoundInstruction {
|
||||
let value_type = ValueType::from(lane_type.into());
|
||||
|
||||
let (inst, value_types) = match target.into() {
|
||||
ApplyTarget::Inst(inst) => (inst, vec![value_type]),
|
||||
ApplyTarget::Bound(bound_inst) => {
|
||||
let mut new_value_types = bound_inst.value_types;
|
||||
new_value_types.push(value_type);
|
||||
(bound_inst.inst, new_value_types)
|
||||
}
|
||||
};
|
||||
|
||||
match &inst.polymorphic_info {
|
||||
Some(poly) => {
|
||||
assert!(
|
||||
value_types.len() <= 1 + poly.other_typevars.len(),
|
||||
format!("trying to bind too many types for {}", inst.name)
|
||||
);
|
||||
}
|
||||
None => {
|
||||
panic!(format!(
|
||||
"trying to bind a type for {} which is not a polymorphic instruction",
|
||||
inst.name
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
BoundInstruction { inst, value_types }
|
||||
}
|
||||
|
||||
/// Apply an instruction to arguments.
///
/// An `Apply` AST expression is created by using function call syntax on instructions. This
/// applies to both bound and unbound polymorphic instructions.
pub struct Apply {
    /// The instruction being applied.
    pub inst: Instruction,
    /// Argument expressions, one per input operand.
    pub args: Vec<Expr>,
    /// Value types bound onto the instruction's type variables (empty if unbound).
    pub value_types: Vec<ValueType>,
}
|
||||
|
||||
impl Apply {
|
||||
pub fn new(target: ApplyTarget, args: Vec<Expr>) -> Self {
|
||||
let (inst, value_types) = match target.into() {
|
||||
ApplyTarget::Inst(inst) => (inst, Vec::new()),
|
||||
ApplyTarget::Bound(bound_inst) => (bound_inst.inst, bound_inst.value_types),
|
||||
};
|
||||
|
||||
// Basic check on number of arguments.
|
||||
assert!(
|
||||
inst.operands_in.len() == args.len(),
|
||||
format!("incorrect number of arguments in instruction {}", inst.name)
|
||||
);
|
||||
|
||||
// Check that the kinds of Literals arguments match the expected operand.
|
||||
for &imm_index in &inst.imm_opnums {
|
||||
let arg = &args[imm_index];
|
||||
if let Some(literal) = arg.maybe_literal() {
|
||||
let op = &inst.operands_in[imm_index];
|
||||
assert!(
|
||||
op.kind.name == literal.kind.name,
|
||||
format!(
|
||||
"Passing literal of kind {} to field of wrong kind {}",
|
||||
literal.kind.name, op.kind.name
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
inst,
|
||||
args,
|
||||
value_types,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_comment_string(&self, var_pool: &VarPool) -> String {
|
||||
let args = self
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| arg.to_rust_code(var_pool))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let mut inst_and_bound_types = vec![self.inst.name.to_string()];
|
||||
inst_and_bound_types.extend(self.value_types.iter().map(|vt| vt.to_string()));
|
||||
let inst_name = inst_and_bound_types.join(".");
|
||||
|
||||
format!("{}({})", inst_name, args)
|
||||
}
|
||||
|
||||
fn to_rust_code(&self, var_pool: &VarPool) -> String {
|
||||
let args = self
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| arg.to_rust_code(var_pool))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
format!("{}({})", self.inst.name, args)
|
||||
}
|
||||
|
||||
fn inst_predicate(
|
||||
&self,
|
||||
format_registry: &FormatRegistry,
|
||||
var_pool: &VarPool,
|
||||
) -> InstructionPredicate {
|
||||
let iform = format_registry.get(self.inst.format);
|
||||
|
||||
let mut pred = InstructionPredicate::new();
|
||||
for (format_field, &op_num) in iform.imm_fields.iter().zip(self.inst.imm_opnums.iter()) {
|
||||
let arg = &self.args[op_num];
|
||||
if arg.maybe_var().is_some() {
|
||||
// Ignore free variables for now.
|
||||
continue;
|
||||
}
|
||||
pred = pred.and(InstructionPredicate::new_is_field_equal(
|
||||
&format_field,
|
||||
arg.to_rust_code(var_pool),
|
||||
));
|
||||
}
|
||||
|
||||
// Add checks for any bound secondary type variables. We can't check the controlling type
|
||||
// variable this way since it may not appear as the type of an operand.
|
||||
if self.value_types.len() > 1 {
|
||||
let poly = self
|
||||
.inst
|
||||
.polymorphic_info
|
||||
.as_ref()
|
||||
.expect("must have polymorphic info if it has bounded types");
|
||||
for (bound_type, type_var) in
|
||||
self.value_types[1..].iter().zip(poly.other_typevars.iter())
|
||||
{
|
||||
pred = pred.and(InstructionPredicate::new_typevar_check(
|
||||
&self.inst, type_var, bound_type,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
pred
|
||||
}
|
||||
|
||||
/// Same as `inst_predicate()`, but also check the controlling type variable.
|
||||
pub fn inst_predicate_with_ctrl_typevar(
|
||||
&self,
|
||||
format_registry: &FormatRegistry,
|
||||
var_pool: &VarPool,
|
||||
) -> InstructionPredicate {
|
||||
let mut pred = self.inst_predicate(format_registry, var_pool);
|
||||
|
||||
if !self.value_types.is_empty() {
|
||||
let bound_type = &self.value_types[0];
|
||||
let poly = self.inst.polymorphic_info.as_ref().unwrap();
|
||||
let type_check = if poly.use_typevar_operand {
|
||||
InstructionPredicate::new_typevar_check(&self.inst, &poly.ctrl_typevar, bound_type)
|
||||
} else {
|
||||
InstructionPredicate::new_ctrl_typevar_check(&bound_type)
|
||||
};
|
||||
pred = pred.and(type_check);
|
||||
}
|
||||
|
||||
pred
|
||||
}
|
||||
|
||||
pub fn rust_builder(&self, defined_vars: &Vec<VarIndex>, var_pool: &VarPool) -> String {
|
||||
let mut args = self
|
||||
.args
|
||||
.iter()
|
||||
.map(|expr| expr.to_rust_code(var_pool))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
// Do we need to pass an explicit type argument?
|
||||
if let Some(poly) = &self.inst.polymorphic_info {
|
||||
if !poly.use_typevar_operand {
|
||||
args = format!("{}, {}", var_pool.get(defined_vars[0]).rust_type(), args);
|
||||
}
|
||||
}
|
||||
|
||||
format!("{}({})", self.inst.snake_name(), args)
|
||||
}
|
||||
}
|
||||
|
||||
// Simple helpers for legalize actions construction.
|
||||
|
||||
/// Unresolved expression used while building legalize actions, before variables
/// are interned into a `VarPool`.
pub enum DummyExpr {
    Var(DummyVar),
    Literal(Literal),
    Apply(ApplyTarget, Vec<DummyExpr>),
}
|
||||
|
||||
/// A named variable placeholder, resolved to a real `Var` later.
#[derive(Clone)]
pub struct DummyVar {
    pub name: &'static str,
}
|
||||
|
||||
impl Into<DummyExpr> for DummyVar {
|
||||
fn into(self) -> DummyExpr {
|
||||
DummyExpr::Var(self)
|
||||
}
|
||||
}
|
||||
impl Into<DummyExpr> for Literal {
|
||||
fn into(self) -> DummyExpr {
|
||||
DummyExpr::Literal(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorthand constructor for a `DummyVar` with the given name.
pub fn var(name: &'static str) -> DummyVar {
    DummyVar { name }
}
|
||||
|
||||
/// Unresolved definition: an expression plus the placeholder variables it defines.
pub struct DummyDef {
    pub expr: DummyExpr,
    pub defined_vars: Vec<DummyVar>,
}
|
||||
|
||||
/// Builder wrapping a `DummyExpr`, used by the `def!`/`def_rhs!` macros.
pub struct ExprBuilder {
    expr: DummyExpr,
}
|
||||
|
||||
impl ExprBuilder {
|
||||
pub fn apply(inst: ApplyTarget, args: Vec<DummyExpr>) -> Self {
|
||||
let expr = DummyExpr::Apply(inst, args);
|
||||
Self { expr }
|
||||
}
|
||||
|
||||
pub fn assign_to(self, defined_vars: Vec<DummyVar>) -> DummyDef {
|
||||
DummyDef {
|
||||
expr: self.expr,
|
||||
defined_vars,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the right-hand side of a `def!`: an `ExprBuilder` applying an
// instruction (optionally bound to one type) to cloned-and-converted arguments.
macro_rules! def_rhs {
    // inst(a, b, c)
    ($inst:ident($($src:expr),*)) => {
        ExprBuilder::apply($inst.into(), vec![$($src.clone().into()),*])
    };

    // inst.type(a, b, c)
    // Binds `$type` onto the instruction before applying it.
    ($inst:ident.$type:ident($($src:expr),*)) => {
        ExprBuilder::apply(bind($inst, $type).into(), vec![$($src.clone().into()),*])
    };
}
|
||||
|
||||
// Helper macro to define legalization recipes.
|
||||
macro_rules! def {
    // x = ...
    // Single-result definition.
    ($dest:ident = $($tt:tt)*) => {
        def_rhs!($($tt)*).assign_to(vec![$dest.clone()])
    };

    // (x, y, ...) = ...
    // Multi-result definition.
    (($($dest:ident),*) = $($tt:tt)*) => {
        def_rhs!($($tt)*).assign_to(vec![$($dest.clone()),*])
    };

    // An instruction with no results.
    // NOTE: must remain the last arm — `$($tt:tt)*` matches anything.
    ($($tt:tt)*) => {
        def_rhs!($($tt)*).assign_to(Vec::new())
    }
}
|
||||
@@ -3,6 +3,8 @@
|
||||
//! This module defines the classes that are used to define Cranelift
|
||||
//! instructions and other entities.
|
||||
|
||||
#[macro_use]
|
||||
pub mod ast;
|
||||
pub mod formats;
|
||||
pub mod inst;
|
||||
pub mod isa;
|
||||
@@ -12,6 +14,7 @@ pub mod settings;
|
||||
pub mod type_inference;
|
||||
pub mod types;
|
||||
pub mod typevar;
|
||||
pub mod xform;
|
||||
|
||||
/// A macro that converts boolean settings into predicates to look more natural.
|
||||
#[macro_export]
|
||||
|
||||
@@ -1,5 +1,658 @@
|
||||
use crate::cdsl::typevar::TypeVar;
|
||||
use crate::cdsl::ast::{Def, DefIndex, DefPool, Var, VarIndex, VarPool};
|
||||
use crate::cdsl::typevar::{DerivedFunc, TypeSet, TypeVar};
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::iter::FromIterator;
|
||||
|
||||
// Hash/Eq are required so duplicate constraints can be de-duplicated
// (see `TypeEnvironment::add_constraint` and `extract`).
#[derive(Hash, PartialEq, Eq)]
pub enum Constraint {
    /// Constraint specifying that a type var tv1 must be wider than or equal to type var tv2 at
    /// runtime. This requires that:
    /// 1) They have the same number of lanes
    /// 2) In a lane tv1 has at least as many bits as tv2.
    WiderOrEq(TypeVar, TypeVar),

    /// Constraint specifying that two derived type vars must have the same runtime type.
    Eq(TypeVar, TypeVar),

    /// Constraint specifying that a type var must belong to some typeset.
    InTypeset(TypeVar, TypeSet),
}
|
||||
|
||||
impl Constraint {
    /// Rebuilds this constraint with every referenced type var replaced by
    /// `func(tv)`; typesets are carried over unchanged.
    fn translate_with<F: Fn(&TypeVar) -> TypeVar>(&self, func: F) -> Constraint {
        match self {
            Constraint::WiderOrEq(lhs, rhs) => {
                let lhs = func(&lhs);
                let rhs = func(&rhs);
                Constraint::WiderOrEq(lhs, rhs)
            }
            Constraint::Eq(lhs, rhs) => {
                let lhs = func(&lhs);
                let rhs = func(&rhs);
                Constraint::Eq(lhs, rhs)
            }
            Constraint::InTypeset(tv, ts) => {
                let tv = func(&tv);
                Constraint::InTypeset(tv, ts.clone())
            }
        }
    }

    /// Creates a new constraint by replacing type vars by their hashmap equivalent.
    fn translate_with_map(
        &self,
        original_to_own_typevar: &HashMap<&TypeVar, TypeVar>,
    ) -> Constraint {
        self.translate_with(|tv| substitute(original_to_own_typevar, tv))
    }

    /// Creates a new constraint by replacing type vars by their canonical equivalent.
    fn translate_with_env(&self, type_env: &TypeEnvironment) -> Constraint {
        self.translate_with(|tv| type_env.get_equivalent(tv))
    }

    /// Returns true when this constraint carries no runtime information worth
    /// keeping: it is either decidable right now (trivially true OR trivially
    /// false) or fully concrete. Note that "trivially false" also returns true
    /// here — such constraints are dropped rather than reported (matching the
    /// original Python implementation; TODO confirm this is intended).
    fn is_trivial(&self) -> bool {
        match self {
            Constraint::WiderOrEq(lhs, rhs) => {
                // Trivially true.
                if lhs == rhs {
                    return true;
                }

                let ts1 = lhs.get_typeset();
                let ts2 = rhs.get_typeset();

                // Trivially true.
                if ts1.is_wider_or_equal(&ts2) {
                    return true;
                }

                // Trivially false.
                if ts1.is_narrower(&ts2) {
                    return true;
                }

                // Trivially false: no common lane count is possible.
                if (&ts1.lanes & &ts2.lanes).len() == 0 {
                    return true;
                }

                self.is_concrete()
            }
            Constraint::Eq(lhs, rhs) => lhs == rhs || self.is_concrete(),
            Constraint::InTypeset(_, _) => {
                // The way InTypeset are made, they would always be trivial if we were applying the
                // same logic as the Python code did, so ignore this.
                self.is_concrete()
            }
        }
    }

    /// Returns true iff all the referenced type vars are singletons.
    fn is_concrete(&self) -> bool {
        match self {
            Constraint::WiderOrEq(lhs, rhs) => {
                lhs.singleton_type().is_some() && rhs.singleton_type().is_some()
            }
            Constraint::Eq(lhs, rhs) => {
                lhs.singleton_type().is_some() && rhs.singleton_type().is_some()
            }
            Constraint::InTypeset(tv, _) => tv.singleton_type().is_some(),
        }
    }

    /// All type vars referenced by this constraint (used for sanity checks).
    fn typevar_args(&self) -> Vec<&TypeVar> {
        match self {
            Constraint::WiderOrEq(lhs, rhs) => vec![lhs, rhs],
            Constraint::Eq(lhs, rhs) => vec![lhs, rhs],
            Constraint::InTypeset(tv, _) => vec![tv],
        }
    }
}
|
||||
|
||||
/// Rank of a type var in the unification partial order; higher ranks win as
/// canonical representatives (see `TypeEnvironment::rank`).
#[derive(Clone, Copy)]
enum TypeEnvRank {
    Singleton = 5,
    Input = 4,
    Intermediate = 3,
    Output = 2,
    Temp = 1,
    Internal = 0,
}
|
||||
|
||||
/// Class encapsulating the necessary bookkeeping for type inference.
pub struct TypeEnvironment {
    /// Variables known to this environment.
    vars: HashSet<VarIndex>,
    /// Rank of each registered type var (see `rank()`).
    ranks: HashMap<TypeVar, TypeEnvRank>,
    /// Union-find-like mapping from a type var to a more canonical equivalent.
    equivalency_map: HashMap<TypeVar, TypeVar>,
    /// Constraints accumulated during inference.
    pub constraints: Vec<Constraint>,
}
|
||||
|
||||
impl TypeEnvironment {
    /// Creates an empty type environment.
    fn new() -> Self {
        TypeEnvironment {
            vars: HashSet::new(),
            ranks: HashMap::new(),
            equivalency_map: HashMap::new(),
            constraints: Vec::new(),
        }
    }

    /// Registers a pattern variable and records the rank of its type var based
    /// on the variable's role (input/intermediate/output/temp).
    fn register(&mut self, var_index: VarIndex, var: &mut Var) {
        self.vars.insert(var_index);
        let rank = if var.is_input() {
            TypeEnvRank::Input
        } else if var.is_intermediate() {
            TypeEnvRank::Intermediate
        } else if var.is_output() {
            TypeEnvRank::Output
        } else {
            // The four roles are exhaustive; anything else is a bug.
            assert!(var.is_temp());
            TypeEnvRank::Temp
        };
        self.ranks.insert(var.get_or_create_typevar(), rank);
    }

    /// Adds a constraint unless an identical one is already recorded.
    fn add_constraint(&mut self, constraint: Constraint) {
        if self
            .constraints
            .iter()
            .find(|&item| item == &constraint)
            .is_some()
        {
            return;
        }

        // Check extra conditions for InTypeset constraints.
        if let Constraint::InTypeset(tv, _) = &constraint {
            // InTypeset must only be used on non-derived vars of real pattern
            // variables (those are named "typeof_<var>").
            assert!(tv.base.is_none());
            assert!(tv.name.starts_with("typeof_"));
        }

        self.constraints.push(constraint);
    }

    /// Returns the canonical representative of the equivalency class of the given argument, or
    /// duplicates it if it's not there yet.
    pub fn get_equivalent(&self, tv: &TypeVar) -> TypeVar {
        // Follow the equivalency chain to its end, ...
        let mut tv = tv;
        while let Some(found) = self.equivalency_map.get(tv) {
            tv = found;
        }
        // ... then, for derived vars, canonicalize the parent and re-derive.
        match &tv.base {
            Some(parent) => self
                .get_equivalent(&parent.type_var)
                .derived(parent.derived_func),
            None => tv.clone(),
        }
    }

    /// Get the rank of tv in the partial order:
    /// - TVs directly associated with a Var get their rank from the Var (see register()).
    /// - Internally generated non-derived TVs implicitly get the lowest rank (0).
    /// - Derived variables get their rank from their free typevar.
    /// - Singletons have the highest rank.
    /// - TVs associated with vars in a source pattern have a higher rank than TVs associated with
    /// temporary vars.
    fn rank(&self, tv: &TypeVar) -> u8 {
        // For derived vars, rank the underlying free type var instead.
        let actual_tv = match tv.base {
            Some(_) => tv.free_typevar(),
            None => Some(tv.clone()),
        };

        let rank = match actual_tv {
            Some(actual_tv) => match self.ranks.get(&actual_tv) {
                Some(rank) => Some(*rank),
                None => {
                    // Every "typeof_*" var must have been register()ed.
                    assert!(
                        !actual_tv.name.starts_with("typeof_"),
                        format!("variable {} should be explicitly ranked", actual_tv.name)
                    );
                    None
                }
            },
            None => None,
        };

        // Unranked vars are either singletons (highest) or internal (lowest).
        let rank = match rank {
            Some(rank) => rank,
            None => {
                if tv.singleton_type().is_some() {
                    TypeEnvRank::Singleton
                } else {
                    TypeEnvRank::Internal
                }
            }
        };

        rank as u8
    }

    /// Record the fact that the free tv1 is part of the same equivalence class as tv2. The
    /// canonical representative of the merged class is tv2's canonical representative.
    fn record_equivalent(&mut self, tv1: TypeVar, tv2: TypeVar) {
        assert!(tv1.base.is_none());
        // tv1 must itself be canonical before being mapped away.
        assert!(self.get_equivalent(&tv1) == tv1);
        if let Some(tv2_base) = &tv2.base {
            // Ensure there are no cycles.
            assert!(self.get_equivalent(&tv2_base.type_var) != tv1);
        }
        self.equivalency_map.insert(tv1, tv2);
    }

    /// Get the free typevars in the current type environment.
    pub fn free_typevars(&self, var_pool: &mut VarPool) -> Vec<TypeVar> {
        // Candidates: all mapped-away vars plus every registered variable's tv.
        let mut typevars = Vec::new();
        typevars.extend(self.equivalency_map.keys().cloned());
        typevars.extend(
            self.vars
                .iter()
                .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()),
        );

        // Canonicalize, take each canonical tv's free root, and dedup via a set.
        let set: HashSet<TypeVar> = HashSet::from_iter(
            typevars
                .iter()
                .map(|tv| self.get_equivalent(tv).free_typevar())
                .filter(|opt_tv| {
                    // Filter out singleton types.
                    return opt_tv.is_some();
                })
                .map(|tv| tv.unwrap()),
        );
        Vec::from_iter(set)
    }

    /// Normalize by collapsing any roots that don't correspond to a concrete type var AND have a
    /// single type var derived from them or equivalent to them.
    ///
    /// e.g. if we have a root of the tree that looks like:
    ///
    ///   typeof_a   typeof_b
    ///       \\       /
    ///        typeof_x
    ///           |
    ///      half_width(1)
    ///           |
    ///           1
    ///
    /// we want to collapse the linear path between 1 and typeof_x. The resulting graph is:
    ///
    ///   typeof_a   typeof_b
    ///       \\       /
    ///        typeof_x
    fn normalize(&mut self, var_pool: &mut VarPool) {
        // Type vars that belong to real pattern variables; these must be kept.
        let source_tvs: HashSet<TypeVar> = HashSet::from_iter(
            self.vars
                .iter()
                .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()),
        );

        // Reverse edges of the equivalency/derivation forest: parent -> children.
        let mut children: HashMap<TypeVar, HashSet<TypeVar>> = HashMap::new();

        // Insert all the parents found by the derivation relationship.
        for type_var in self.equivalency_map.values() {
            if type_var.base.is_none() {
                continue;
            }

            let parent_tv = type_var.free_typevar();
            if parent_tv.is_none() {
                // Ignore this type variable, it's a singleton.
                continue;
            }
            let parent_tv = parent_tv.unwrap();

            children
                .entry(parent_tv)
                .or_insert(HashSet::new())
                .insert(type_var.clone());
        }

        // Insert all the explicit equivalency links.
        for (equivalent_tv, canon_tv) in self.equivalency_map.iter() {
            children
                .entry(canon_tv.clone())
                .or_insert(HashSet::new())
                .insert(equivalent_tv.clone());
        }

        // Remove links that are straight paths up to typevar of variables.
        for free_root in self.free_typevars(var_pool) {
            let mut root = &free_root;
            // Walk down single-child chains that don't touch a source tv,
            // deleting each equivalency link as we go.
            while !source_tvs.contains(&root)
                && children.contains_key(&root)
                && children.get(&root).unwrap().len() == 1
            {
                let child = children.get(&root).unwrap().iter().next().unwrap();
                assert_eq!(self.equivalency_map[child], root.clone());
                self.equivalency_map.remove(child);
                root = child;
            }
        }
    }

    /// Extract a clean type environment from self, that only mentions type vars associated with
    /// real variables.
    fn extract(self, var_pool: &mut VarPool) -> TypeEnvironment {
        let vars_tv: HashSet<TypeVar> = HashSet::from_iter(
            self.vars
                .iter()
                .map(|&var_index| var_pool.get_mut(var_index).get_or_create_typevar()),
        );

        // Rebuild the equivalency map with only real-variable tvs as keys.
        let mut new_equivalency_map: HashMap<TypeVar, TypeVar> = HashMap::new();
        for tv in &vars_tv {
            let canon_tv = self.get_equivalent(tv);
            if *tv != canon_tv {
                new_equivalency_map.insert(tv.clone(), canon_tv.clone());
            }

            // Sanity check: the translated type map should only refer to real variables.
            assert!(vars_tv.contains(tv));
            let canon_free_tv = canon_tv.free_typevar();
            assert!(canon_free_tv.is_none() || vars_tv.contains(&canon_free_tv.unwrap()));
        }

        // Canonicalize constraints, dropping trivial and duplicate ones.
        let mut new_constraints: HashSet<Constraint> = HashSet::new();
        for constraint in &self.constraints {
            let constraint = constraint.translate_with_env(&self);
            if constraint.is_trivial() || new_constraints.contains(&constraint) {
                continue;
            }

            // Sanity check: translated constraints should refer only to real variables.
            for arg in constraint.typevar_args() {
                assert!(vars_tv.contains(arg));
                let arg_free_tv = arg.free_typevar();
                assert!(arg_free_tv.is_none() || vars_tv.contains(&arg_free_tv.unwrap()));
            }

            new_constraints.insert(constraint);
        }

        TypeEnvironment {
            vars: self.vars,
            ranks: self.ranks,
            equivalency_map: new_equivalency_map,
            constraints: Vec::from_iter(new_constraints),
        }
    }
}
|
||||
|
||||
/// Replaces an external type variable according to the following rules:
|
||||
/// - if a local copy is present in the map, return it.
|
||||
/// - or if it's derived, create a local derived one that recursively substitutes the parent.
|
||||
/// - or return itself.
|
||||
fn substitute(map: &HashMap<&TypeVar, TypeVar>, external_type_var: &TypeVar) -> TypeVar {
|
||||
match map.get(&external_type_var) {
|
||||
Some(own_type_var) => own_type_var.clone(),
|
||||
None => match &external_type_var.base {
|
||||
Some(parent) => {
|
||||
let parent_substitute = substitute(map, &parent.type_var);
|
||||
TypeVar::derived(&parent_substitute, parent.derived_func)
|
||||
}
|
||||
None => external_type_var.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalize a (potentially derived) typevar using the following rules:
///
/// - vector and width derived functions commute
///   {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
///   {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
///
/// - half/double pairs collapse
///   {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
///   {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base
fn canonicalize_derivations(tv: TypeVar) -> TypeVar {
    // A non-derived typevar is already canonical.
    let base = match &tv.base {
        Some(base) => base,
        None => return tv,
    };

    let derived_func = base.derived_func;

    // Look at the two outermost derivations to see if a rewrite rule applies.
    if let Some(base_base) = &base.type_var.base {
        let base_base_tv = &base_base.type_var;
        match (derived_func, base_base.derived_func) {
            (DerivedFunc::HalfWidth, DerivedFunc::DoubleWidth)
            | (DerivedFunc::DoubleWidth, DerivedFunc::HalfWidth)
            | (DerivedFunc::HalfVector, DerivedFunc::DoubleVector)
            | (DerivedFunc::DoubleVector, DerivedFunc::HalfVector) => {
                // Cancelling bijective transformations. This doesn't hide any overflow issues
                // since derived type sets are checked upon derivation, and base typesets are only
                // allowed to shrink.
                return canonicalize_derivations(base_base_tv.clone());
            }
            (DerivedFunc::HalfWidth, DerivedFunc::HalfVector)
            | (DerivedFunc::HalfWidth, DerivedFunc::DoubleVector)
            | (DerivedFunc::DoubleWidth, DerivedFunc::DoubleVector)
            | (DerivedFunc::DoubleWidth, DerivedFunc::HalfVector) => {
                // Arbitrarily put WIDTH derivations before VECTOR derivations, since they commute.
                return canonicalize_derivations(
                    base_base_tv
                        .derived(derived_func)
                        .derived(base_base.derived_func),
                );
            }
            _ => {}
        };
    }

    // No rule fired at this level: canonicalize the parent, then re-apply this derivation.
    canonicalize_derivations(base.type_var.clone()).derived(derived_func)
}
|
||||
|
||||
/// Given typevars tv1 and tv2 (which could be derived from one another), constrain their typesets
|
||||
/// to be the same. When one is derived from the other, repeat the constrain process until
|
||||
/// a fixed point is reached.
|
||||
fn constrain_fixpoint(tv1: &TypeVar, tv2: &TypeVar) {
|
||||
loop {
|
||||
let old_tv1_ts = tv1.get_typeset().clone();
|
||||
tv2.constrain_types(tv1.clone());
|
||||
if tv1.get_typeset() == old_tv1_ts {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let old_tv2_ts = tv2.get_typeset().clone();
|
||||
tv1.constrain_types(tv2.clone());
|
||||
// The above loop should ensure that all reference cycles have been handled.
|
||||
assert!(old_tv2_ts == tv2.get_typeset());
|
||||
}
|
||||
|
||||
/// Unify tv1 and tv2 in the given type environment. tv1 must have a rank greater or equal to tv2's
/// one, modulo commutations.
///
/// Returns Ok(()) once the two typevars are related (via an equivalency record or an explicit
/// Eq constraint), or Err when the unification would produce an empty typeset.
fn unify(tv1: &TypeVar, tv2: &TypeVar, type_env: &mut TypeEnvironment) -> Result<(), String> {
    // Work on canonical representatives with normalized derivation chains.
    let tv1 = canonicalize_derivations(type_env.get_equivalent(tv1));
    let tv2 = canonicalize_derivations(type_env.get_equivalent(tv2));

    if tv1 == tv2 {
        // Already unified.
        return Ok(());
    }

    if type_env.rank(&tv2) < type_env.rank(&tv1) {
        // Make sure tv1 always has the smallest rank, since real variables have the higher rank
        // and we want them to be the canonical representatives of their equivalency classes.
        return unify(&tv2, &tv1, type_env);
    }

    // Propagate typeset constraints in both directions until a fixed point.
    constrain_fixpoint(&tv1, &tv2);

    if tv1.get_typeset().size() == 0 || tv2.get_typeset().size() == 0 {
        return Err(format!(
            "Error: empty type created when unifying {} and {}",
            tv1.name, tv2.name
        ));
    }

    let base = match &tv1.base {
        Some(base) => base,
        None => {
            // tv1 is a free typevar: record the equivalency directly.
            type_env.record_equivalent(tv1, tv2);
            return Ok(());
        }
    };

    if let Some(inverse) = base.derived_func.inverse() {
        // tv1 is derived through an invertible function: unify its parent with the
        // inverse-derived tv2 instead.
        return unify(&base.type_var, &tv2.derived(inverse), type_env);
    }

    // Non-invertible derivation: keep an explicit equality constraint.
    type_env.add_constraint(Constraint::Eq(tv1, tv2));
    Ok(())
}
|
||||
|
||||
/// Perform type inference on one Def in the current type environment and return an updated type
/// environment or error.
///
/// At a high level this works by creating fresh copies of each formal type var in the Def's
/// instruction's signature, and unifying the formal typevar with the corresponding actual typevar.
fn infer_definition(
    def: &Def,
    var_pool: &mut VarPool,
    type_env: TypeEnvironment,
    last_type_index: &mut usize,
) -> Result<TypeEnvironment, String> {
    let apply = &def.apply;
    let inst = &apply.inst;

    let mut type_env = type_env;
    let free_formal_tvs = inst.all_typevars();

    // Map each of the instruction's formal typevars to a fresh local copy, so that inference on
    // this Def cannot mutate the shared instruction signature. The names are uniquified with
    // `last_type_index`.
    let mut original_to_own_typevar: HashMap<&TypeVar, TypeVar> = HashMap::new();
    for &tv in &free_formal_tvs {
        assert!(original_to_own_typevar
            .insert(
                tv,
                TypeVar::copy_from(tv, format!("own_{}", last_type_index))
            )
            .is_none());
        *last_type_index += 1;
    }

    // Update the mapping with any explicitly bound type vars:
    for (i, value_type) in apply.value_types.iter().enumerate() {
        let singleton = TypeVar::new_singleton(value_type.clone());
        // A bound type must replace an entry created by the loop above.
        assert!(original_to_own_typevar
            .insert(free_formal_tvs[i], singleton)
            .is_some());
    }

    // Get fresh copies for each typevar in the signature (both free and derived).
    // Results first, then value operands.
    let mut formal_tvs = Vec::new();
    formal_tvs.extend(inst.value_results.iter().map(|&i| {
        substitute(
            &original_to_own_typevar,
            inst.operands_out[i].type_var().unwrap(),
        )
    }));
    formal_tvs.extend(inst.value_opnums.iter().map(|&i| {
        substitute(
            &original_to_own_typevar,
            inst.operands_in[i].type_var().unwrap(),
        )
    }));

    // Get the list of actual vars, in the same order as formal_tvs.
    let mut actual_vars = Vec::new();
    actual_vars.extend(inst.value_results.iter().map(|&i| def.defined_vars[i]));
    actual_vars.extend(
        inst.value_opnums
            .iter()
            .map(|&i| apply.args[i].unwrap_var()),
    );

    // Get the list of the actual TypeVars, registering each variable with the environment.
    let mut actual_tvs = Vec::new();
    for var_index in actual_vars {
        let var = var_pool.get_mut(var_index);
        type_env.register(var_index, var);
        actual_tvs.push(var.get_or_create_typevar());
    }

    // Make sure we start unifying with the control type variable first, by putting it at the
    // front of both vectors.
    if let Some(poly) = &inst.polymorphic_info {
        let own_ctrl_tv = &original_to_own_typevar[&poly.ctrl_typevar];
        let ctrl_index = formal_tvs.iter().position(|tv| tv == own_ctrl_tv).unwrap();
        if ctrl_index != 0 {
            formal_tvs.swap(0, ctrl_index);
            actual_tvs.swap(0, ctrl_index);
        }
    }

    // Unify each actual type variable with the corresponding formal type variable.
    for (actual_tv, formal_tv) in actual_tvs.iter().zip(&formal_tvs) {
        if let Err(msg) = unify(actual_tv, formal_tv, &mut type_env) {
            return Err(format!(
                "fail ti on {} <: {}: {}",
                actual_tv.name, formal_tv.name, msg
            ));
        }
    }

    // Add any instruction specific constraints.
    for constraint in &inst.constraints {
        type_env.add_constraint(constraint.translate_with_map(&original_to_own_typevar));
    }

    Ok(type_env)
}
|
||||
|
||||
/// Perform type inference on an transformation. Return an updated type environment or error.
|
||||
pub fn infer_transform(
|
||||
src: DefIndex,
|
||||
dst: &Vec<DefIndex>,
|
||||
def_pool: &DefPool,
|
||||
var_pool: &mut VarPool,
|
||||
) -> Result<TypeEnvironment, String> {
|
||||
let mut type_env = TypeEnvironment::new();
|
||||
let mut last_type_index = 0;
|
||||
|
||||
// Execute type inference on the source pattern.
|
||||
type_env = infer_definition(def_pool.get(src), var_pool, type_env, &mut last_type_index)
|
||||
.map_err(|err| format!("In src pattern: {}", err))?;
|
||||
|
||||
// Collect the type sets once after applying the source patterm; we'll compare the typesets
|
||||
// after we've also considered the destination pattern, and will emit supplementary InTypeset
|
||||
// checks if they don't match.
|
||||
let src_typesets = type_env
|
||||
.vars
|
||||
.iter()
|
||||
.map(|&var_index| {
|
||||
let var = var_pool.get_mut(var_index);
|
||||
let tv = type_env.get_equivalent(&var.get_or_create_typevar());
|
||||
(var_index, tv.get_typeset().clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Execute type inference on the destination pattern.
|
||||
for (i, &def_index) in dst.iter().enumerate() {
|
||||
let def = def_pool.get(def_index);
|
||||
type_env = infer_definition(def, var_pool, type_env, &mut last_type_index)
|
||||
.map_err(|err| format!("line {}: {}", i, err))?;
|
||||
}
|
||||
|
||||
for (var_index, src_typeset) in src_typesets {
|
||||
let var = var_pool.get(var_index);
|
||||
if !var.has_free_typevar() {
|
||||
continue;
|
||||
}
|
||||
let tv = type_env.get_equivalent(&var.get_typevar().unwrap());
|
||||
let new_typeset = tv.get_typeset();
|
||||
assert!(
|
||||
new_typeset.is_subset(&src_typeset),
|
||||
"type sets can only get narrower"
|
||||
);
|
||||
if new_typeset != src_typeset {
|
||||
type_env.add_constraint(Constraint::InTypeset(tv.clone(), new_typeset.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
type_env.normalize(var_pool);
|
||||
|
||||
Ok(type_env.extract(var_pool))
|
||||
}
|
||||
|
||||
416
cranelift/codegen/meta/src/cdsl/xform.rs
Normal file
416
cranelift/codegen/meta/src/cdsl/xform.rs
Normal file
@@ -0,0 +1,416 @@
|
||||
use crate::cdsl::ast::{
|
||||
Apply, DefIndex, DefPool, DummyDef, DummyExpr, Expr, PatternPosition, VarIndex, VarPool,
|
||||
};
|
||||
use crate::cdsl::inst::Instruction;
|
||||
use crate::cdsl::type_inference::{infer_transform, TypeEnvironment};
|
||||
use crate::cdsl::typevar::TypeVar;
|
||||
|
||||
use cranelift_entity::{entity_impl, PrimaryMap};
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::iter::FromIterator;
|
||||
|
||||
/// An instruction transformation consists of a source and destination pattern.
///
/// Patterns are expressed in *register transfer language* as tuples of Def or Expr nodes. A
/// pattern may optionally have a sequence of TypeConstraints, that additionally limit the set of
/// cases when it applies.
///
/// The source pattern can contain only a single instruction.
pub struct Transform {
    /// The single source-pattern definition, indexing into `def_pool`.
    pub src: DefIndex,
    /// The destination-pattern definitions, in order, indexing into `def_pool`.
    pub dst: Vec<DefIndex>,
    /// Local copies of all variables appearing in the patterns.
    pub var_pool: VarPool,
    /// Storage for the rewritten Defs referenced by `src` and `dst`.
    pub def_pool: DefPool,
    /// Type environment produced by running `infer_transform` on the patterns.
    pub type_env: TypeEnvironment,
}
|
||||
|
||||
/// Maps variable names to the local `Var` copies created for a single `Transform`.
type SymbolTable = HashMap<&'static str, VarIndex>;
|
||||
|
||||
impl Transform {
|
||||
fn new(src: DummyDef, dst: Vec<DummyDef>) -> Self {
|
||||
let mut var_pool = VarPool::new();
|
||||
let mut def_pool = DefPool::new();
|
||||
|
||||
let mut input_vars: Vec<VarIndex> = Vec::new();
|
||||
let mut defined_vars: Vec<VarIndex> = Vec::new();
|
||||
|
||||
// Maps variable names to our own Var copies.
|
||||
let mut symbol_table: SymbolTable = SymbolTable::new();
|
||||
|
||||
// Rewrite variables in src and dst using our own copies.
|
||||
let src = rewrite_def_list(
|
||||
PatternPosition::Source,
|
||||
vec![src],
|
||||
&mut symbol_table,
|
||||
&mut input_vars,
|
||||
&mut defined_vars,
|
||||
&mut var_pool,
|
||||
&mut def_pool,
|
||||
)[0];
|
||||
|
||||
let num_src_inputs = input_vars.len();
|
||||
|
||||
let dst = rewrite_def_list(
|
||||
PatternPosition::Destination,
|
||||
dst,
|
||||
&mut symbol_table,
|
||||
&mut input_vars,
|
||||
&mut defined_vars,
|
||||
&mut var_pool,
|
||||
&mut def_pool,
|
||||
);
|
||||
|
||||
// Sanity checks.
|
||||
for &var_index in &input_vars {
|
||||
assert!(
|
||||
var_pool.get(var_index).is_input(),
|
||||
format!("'{:?}' used as both input and def", var_pool.get(var_index))
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
input_vars.len() == num_src_inputs,
|
||||
format!(
|
||||
"extra input vars in dst pattern: {:?}",
|
||||
input_vars
|
||||
.iter()
|
||||
.map(|&i| var_pool.get(i))
|
||||
.skip(num_src_inputs)
|
||||
.collect::<Vec<_>>()
|
||||
)
|
||||
);
|
||||
|
||||
// Perform type inference and cleanup.
|
||||
let type_env = infer_transform(src, &dst, &def_pool, &mut var_pool).unwrap();
|
||||
|
||||
// Sanity check: the set of inferred free type variables should be a subset of the type
|
||||
// variables corresponding to Vars appearing in the source pattern.
|
||||
{
|
||||
let free_typevars: HashSet<TypeVar> =
|
||||
HashSet::from_iter(type_env.free_typevars(&mut var_pool));
|
||||
let src_tvs = HashSet::from_iter(
|
||||
input_vars
|
||||
.clone()
|
||||
.iter()
|
||||
.chain(
|
||||
defined_vars
|
||||
.iter()
|
||||
.filter(|&&var_index| !var_pool.get(var_index).is_temp()),
|
||||
)
|
||||
.map(|&var_index| var_pool.get(var_index).get_typevar())
|
||||
.filter(|maybe_var| maybe_var.is_some())
|
||||
.map(|var| var.unwrap()),
|
||||
);
|
||||
if !free_typevars.is_subset(&src_tvs) {
|
||||
let missing_tvs = (&free_typevars - &src_tvs)
|
||||
.iter()
|
||||
.map(|tv| tv.name.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
panic!("Some free vars don't appear in src: {}", missing_tvs);
|
||||
}
|
||||
}
|
||||
|
||||
for &var_index in input_vars.iter().chain(defined_vars.iter()) {
|
||||
let var = var_pool.get_mut(var_index);
|
||||
let canon_tv = type_env.get_equivalent(&var.get_or_create_typevar());
|
||||
var.set_typevar(canon_tv);
|
||||
}
|
||||
|
||||
Self {
|
||||
src,
|
||||
dst,
|
||||
var_pool,
|
||||
def_pool,
|
||||
type_env,
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_legalize(&self) {
|
||||
let def = self.def_pool.get(self.src);
|
||||
for &var_index in def.defined_vars.iter() {
|
||||
let defined_var = self.var_pool.get(var_index);
|
||||
assert!(
|
||||
defined_var.is_output(),
|
||||
format!("{:?} not defined in the destination pattern", defined_var)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a list of symbols defined in a Def, rewrite them to local symbols. Yield the new locals.
|
||||
fn rewrite_defined_vars(
|
||||
position: PatternPosition,
|
||||
dummy_def: &DummyDef,
|
||||
def_index: DefIndex,
|
||||
symbol_table: &mut SymbolTable,
|
||||
defined_vars: &mut Vec<VarIndex>,
|
||||
var_pool: &mut VarPool,
|
||||
) -> Vec<VarIndex> {
|
||||
let mut new_defined_vars = Vec::new();
|
||||
for var in &dummy_def.defined_vars {
|
||||
let own_var = match symbol_table.get(var.name) {
|
||||
Some(&existing_var) => existing_var,
|
||||
None => {
|
||||
// Materialize the variable.
|
||||
let new_var = var_pool.create(var.name);
|
||||
symbol_table.insert(var.name, new_var);
|
||||
defined_vars.push(new_var);
|
||||
new_var
|
||||
}
|
||||
};
|
||||
var_pool.get_mut(own_var).set_def(position, def_index);
|
||||
new_defined_vars.push(own_var);
|
||||
}
|
||||
new_defined_vars
|
||||
}
|
||||
|
||||
/// Find all uses of variables in `expr` and replace them with our own local symbols.
|
||||
fn rewrite_expr(
|
||||
position: PatternPosition,
|
||||
dummy_expr: DummyExpr,
|
||||
symbol_table: &mut SymbolTable,
|
||||
input_vars: &mut Vec<VarIndex>,
|
||||
var_pool: &mut VarPool,
|
||||
) -> Apply {
|
||||
let (apply_target, dummy_args) = if let DummyExpr::Apply(apply_target, dummy_args) = dummy_expr
|
||||
{
|
||||
(apply_target, dummy_args)
|
||||
} else {
|
||||
panic!("we only rewrite apply expressions");
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
apply_target.inst().operands_in.len(),
|
||||
dummy_args.len(),
|
||||
"number of arguments in instruction is incorrect"
|
||||
);
|
||||
|
||||
let mut args = Vec::new();
|
||||
for (i, arg) in dummy_args.into_iter().enumerate() {
|
||||
match arg {
|
||||
DummyExpr::Var(var) => {
|
||||
let own_var = match symbol_table.get(var.name) {
|
||||
Some(&own_var) => {
|
||||
let var = var_pool.get(own_var);
|
||||
assert!(
|
||||
var.is_input() || var.get_def(position).is_some(),
|
||||
format!("{:?} used as both input and def", var)
|
||||
);
|
||||
own_var
|
||||
}
|
||||
None => {
|
||||
// First time we're using this variable.
|
||||
let own_var = var_pool.create(var.name);
|
||||
symbol_table.insert(var.name, own_var);
|
||||
input_vars.push(own_var);
|
||||
own_var
|
||||
}
|
||||
};
|
||||
args.push(Expr::Var(own_var));
|
||||
}
|
||||
DummyExpr::Literal(literal) => {
|
||||
assert!(!apply_target.inst().operands_in[i].is_value());
|
||||
args.push(Expr::Literal(literal));
|
||||
}
|
||||
DummyExpr::Apply(..) => {
|
||||
panic!("Recursive apply is not allowed.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Apply::new(apply_target, args)
|
||||
}
|
||||
|
||||
/// Rewrite a list of DummyDefs into Defs stored in `def_pool`, replacing defined and used
/// variables with local copies tracked by `symbol_table`. Returns the indices of the new Defs.
fn rewrite_def_list(
    position: PatternPosition,
    dummy_defs: Vec<DummyDef>,
    symbol_table: &mut SymbolTable,
    input_vars: &mut Vec<VarIndex>,
    defined_vars: &mut Vec<VarIndex>,
    var_pool: &mut VarPool,
    def_pool: &mut DefPool,
) -> Vec<DefIndex> {
    let mut new_defs = Vec::new();
    for dummy_def in dummy_defs {
        // Reserve the index this Def will receive, so its defined vars can point back to it
        // before the Def itself is created.
        let def_index = def_pool.next_index();

        let new_defined_vars = rewrite_defined_vars(
            position,
            &dummy_def,
            def_index,
            symbol_table,
            defined_vars,
            var_pool,
        );
        let new_apply = rewrite_expr(position, dummy_def.expr, symbol_table, input_vars, var_pool);

        assert!(
            def_pool.next_index() == def_index,
            "shouldn't have created new defs in the meanwhile"
        );
        assert_eq!(
            new_apply.inst.value_results.len(),
            new_defined_vars.len(),
            "number of Var results in instruction is incorrect"
        );

        new_defs.push(def_pool.create(new_apply, new_defined_vars));
    }
    new_defs
}
|
||||
|
||||
/// A group of related transformations.
pub struct TransformGroup {
    pub name: &'static str,
    pub doc: &'static str,
    /// Another group to chain with; set via `TransformGroupBuilder::chain_with`.
    pub chain_with: Option<TransformGroupIndex>,
    /// Set when this group is specific to one ISA (see `TransformGroupBuilder::isa`).
    pub isa_name: Option<&'static str>,
    pub id: TransformGroupIndex,

    /// Maps Instruction camel_case names to custom legalization functions names.
    pub custom_legalizes: HashMap<String, &'static str>,
    pub transforms: Vec<Transform>,
}
|
||||
|
||||
impl TransformGroup {
|
||||
pub fn rust_name(&self) -> String {
|
||||
match self.isa_name {
|
||||
Some(_) => {
|
||||
// This is a function in the same module as the LEGALIZE_ACTIONS table referring to
|
||||
// it.
|
||||
self.name.to_string()
|
||||
}
|
||||
None => format!("crate::legalizer::{}", self.name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A strongly typed index referencing a `TransformGroup` stored in `TransformGroups`.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TransformGroupIndex(u32);
entity_impl!(TransformGroupIndex);
|
||||
|
||||
/// Builder accumulating legalization patterns and custom legalize actions; turned into a
/// `TransformGroup` by `finish_and_add_to`.
pub struct TransformGroupBuilder {
    name: &'static str,
    doc: &'static str,
    chain_with: Option<TransformGroupIndex>,
    isa_name: Option<&'static str>,
    pub custom_legalizes: HashMap<String, &'static str>,
    pub transforms: Vec<Transform>,
}
|
||||
|
||||
impl TransformGroupBuilder {
|
||||
pub fn new(name: &'static str, doc: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
doc,
|
||||
chain_with: None,
|
||||
isa_name: None,
|
||||
custom_legalizes: HashMap::new(),
|
||||
transforms: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chain_with(mut self, next_id: TransformGroupIndex) -> Self {
|
||||
assert!(self.chain_with.is_none());
|
||||
self.chain_with = Some(next_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn isa(mut self, isa_name: &'static str) -> Self {
|
||||
assert!(self.isa_name.is_none());
|
||||
self.isa_name = Some(isa_name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a custom legalization action for `inst`.
|
||||
///
|
||||
/// The `func_name` parameter is the fully qualified name of a Rust function which takes the
|
||||
/// same arguments as the `isa::Legalize` actions.
|
||||
///
|
||||
/// The custom function will be called to legalize `inst` and any return value is ignored.
|
||||
pub fn custom_legalize(&mut self, inst: &Instruction, func_name: &'static str) {
|
||||
assert!(
|
||||
self.custom_legalizes
|
||||
.insert(inst.camel_name.clone(), func_name)
|
||||
.is_none(),
|
||||
format!(
|
||||
"custom legalization action for {} inserted twice",
|
||||
inst.name
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/// Add a legalization pattern to this group.
|
||||
pub fn legalize(&mut self, src: DummyDef, dst: Vec<DummyDef>) {
|
||||
let transform = Transform::new(src, dst);
|
||||
transform.verify_legalize();
|
||||
self.transforms.push(transform);
|
||||
}
|
||||
|
||||
pub fn finish_and_add_to(self, owner: &mut TransformGroups) -> TransformGroupIndex {
|
||||
let next_id = owner.next_key();
|
||||
owner.add(TransformGroup {
|
||||
name: self.name,
|
||||
doc: self.doc,
|
||||
isa_name: self.isa_name,
|
||||
id: next_id,
|
||||
chain_with: self.chain_with,
|
||||
custom_legalizes: self.custom_legalizes,
|
||||
transforms: self.transforms,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// All registered transform groups, indexed by `TransformGroupIndex`.
pub struct TransformGroups {
    groups: PrimaryMap<TransformGroupIndex, TransformGroup>,
}
|
||||
|
||||
impl TransformGroups {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
groups: PrimaryMap::new(),
|
||||
}
|
||||
}
|
||||
pub fn add(&mut self, new_group: TransformGroup) -> TransformGroupIndex {
|
||||
for group in self.groups.values() {
|
||||
assert!(
|
||||
group.name != new_group.name,
|
||||
format!("trying to insert {} for the second time", new_group.name)
|
||||
);
|
||||
}
|
||||
self.groups.push(new_group)
|
||||
}
|
||||
pub fn get(&self, id: TransformGroupIndex) -> &TransformGroup {
|
||||
&self.groups[id]
|
||||
}
|
||||
pub fn get_mut(&mut self, id: TransformGroupIndex) -> &mut TransformGroup {
|
||||
self.groups.get_mut(id).unwrap()
|
||||
}
|
||||
fn next_key(&self) -> TransformGroupIndex {
|
||||
self.groups.next_key()
|
||||
}
|
||||
pub fn by_name(&self, name: &'static str) -> &TransformGroup {
|
||||
for group in self.groups.values() {
|
||||
if group.name == name {
|
||||
return group;
|
||||
}
|
||||
}
|
||||
panic!(format!("transform group with name {} not found", name));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
#[should_panic]
fn test_double_custom_legalization() {
    use crate::cdsl::formats::{FormatRegistry, InstructionFormatBuilder};
    use crate::cdsl::inst::InstructionBuilder;

    // Registering two custom legalizations for the same instruction must panic.
    let mut formats = FormatRegistry::new();
    formats.insert(InstructionFormatBuilder::new("nullary"));
    let inst = InstructionBuilder::new("dummy", "doc").finish(&formats);

    let mut group = TransformGroupBuilder::new("test", "doc");
    group.custom_legalize(&inst, "custom 1");
    group.custom_legalize(&inst, "custom 2");
}
|
||||
Reference in New Issue
Block a user