lib/codegen-meta moved into lib/codegen. (#423)
* lib/codegen-meta moved into lib/codegen. * Renamed codegen-meta and existing meta.
This commit is contained in:
59
lib/codegen/meta-python/cdsl/__init__.py
Normal file
59
lib/codegen/meta-python/cdsl/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Cranelift DSL classes.
|
||||
|
||||
This module defines the classes that are used to define Cranelift instructions
|
||||
and other entities.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
import re
|
||||
|
||||
|
||||
camel_re = re.compile('(^|_)([a-z])')
|
||||
|
||||
|
||||
def camel_case(s):
    # type: (str) -> str
    """Convert the string s to CamelCase:

    >>> camel_case('x')
    'X'
    >>> camel_case('camel_case')
    'CamelCase'
    """
    # Upper-case any letter that starts the string or follows an underscore,
    # dropping the underscore itself.
    return re.sub('(^|_)([a-z])', lambda match: match.group(2).upper(), s)
|
||||
|
||||
|
||||
def is_power_of_two(x):
    # type: (int) -> bool
    """Check if `x` is a power of two:

    >>> is_power_of_two(0)
    False
    >>> is_power_of_two(1)
    True
    >>> is_power_of_two(2)
    True
    >>> is_power_of_two(3)
    False
    """
    # Non-positive numbers are never powers of two.
    if x <= 0:
        return False
    # A power of two has exactly one bit set, so clearing the lowest set bit
    # must yield zero.
    return x & (x - 1) == 0
|
||||
|
||||
|
||||
def next_power_of_two(x):
    # type: (int) -> int
    """
    Compute the next power of two that is greater than `x`:

    >>> next_power_of_two(0)
    1
    >>> next_power_of_two(1)
    2
    >>> next_power_of_two(2)
    4
    >>> next_power_of_two(3)
    4
    >>> next_power_of_two(4)
    8
    """
    # Smear the high bit down into every lower position, producing a value of
    # the form 0b0...0111...1; adding one then gives the next power of two.
    result = x
    shift = 1
    while result & (result + 1):
        result |= result >> shift
        shift <<= 1
    return result + 1
|
||||
577
lib/codegen/meta-python/cdsl/ast.py
Normal file
577
lib/codegen/meta-python/cdsl/ast.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Abstract syntax trees.
|
||||
|
||||
This module defines classes that can be used to create abstract syntax trees
|
||||
for pattern matching and rewriting of cranelift instructions.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from . import instructions
|
||||
from .typevar import TypeVar
|
||||
from .predicates import IsEqual, And, TypePredicate, CtrlTypePredicate
|
||||
|
||||
try:
|
||||
from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa
|
||||
from typing import Optional, Set, Any # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .operands import ImmediateKind # noqa
|
||||
from .predicates import PredNode # noqa
|
||||
VarAtomMap = Dict["Var", "Atom"]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def replace_var(arg, m):
    # type: (Expr, VarAtomMap) -> Expr
    """
    Given a var v return either m[v] or a new variable v' (and remember
    m[v] = v'). Otherwise return the argument unchanged.
    """
    # Non-Var atoms (literals) pass through untouched.
    if not isinstance(arg, Var):
        return arg
    fresh = m.get(arg, Var(arg.name))  # type: Atom
    m[arg] = fresh
    return fresh
|
||||
|
||||
|
||||
class Def(object):
    """
    An AST definition associates a set of variables with the values produced by
    an expression.

    Example:

    >>> from base.instructions import iadd_cout, iconst
    >>> x = Var('x')
    >>> y = Var('y')
    >>> x << iconst(4)
    (Var(x),) << Apply(iconst, (4,))
    >>> (x, y) << iadd_cout(4, 5)
    (Var(x), Var(y)) << Apply(iadd_cout, (4, 5))

    The `<<` operator is used to create variable definitions.

    :param defs: Single variable or tuple of variables to be defined.
    :param expr: Expression generating the values.
    """

    def __init__(self, defs, expr):
        # type: (Union[Var, Tuple[Var, ...]], Apply) -> None
        # Normalize a single variable into a 1-tuple.
        if isinstance(defs, tuple):
            self.defs = defs  # type: Tuple[Var, ...]
        else:
            self.defs = (defs,)
        assert isinstance(expr, Apply)
        self.expr = expr

    def __repr__(self):
        # type: () -> str
        return "{} << {!r}".format(self.defs, self.expr)

    def __str__(self):
        # type: () -> str
        if len(self.defs) != 1:
            names = ', '.join(str(d) for d in self.defs)
            return "({}) << {!s}".format(names, self.expr)
        return "{!s} << {!s}".format(self.defs[0], self.expr)

    def copy(self, m):
        # type: (VarAtomMap) -> Def
        """
        Return a copy of this Def with vars replaced with fresh variables,
        in accordance with the map m. Update m as necessary.
        """
        copied_expr = self.expr.copy(m)
        copied_defs = []  # type: List[Var]
        for old in self.defs:
            fresh = replace_var(old, m)
            assert isinstance(fresh, Var)
            copied_defs.append(fresh)

        return Def(tuple(copied_defs), copied_expr)

    def definitions(self):
        # type: () -> Set[Var]
        """Return the set of all Vars that are defined by self."""
        return set(self.defs)

    def uses(self):
        # type: () -> Set[Var]
        """Return the set of all Vars that are used (read) by self."""
        return set(self.expr.vars())

    def vars(self):
        # type: () -> Set[Var]
        """Return the set of all Vars in self that correspond to SSA values"""
        return self.definitions() | self.uses()

    def substitution(self, other, s):
        # type: (Def, VarAtomMap) -> Optional[VarAtomMap]
        """
        If the Defs self and other agree structurally, return a variable
        substitution to transform self to other. Otherwise return None. Two
        Defs agree structurally if there exists a Var substitution that can
        transform one into the other. See Apply.substitution() for more
        details.
        """
        s = self.expr.substitution(other.expr, s)
        if s is None:
            return None

        assert len(self.defs) == len(other.defs)
        for mine, theirs in zip(self.defs, other.defs):
            assert mine not in s  # Guaranteed by SSA form
            s[mine] = theirs

        return s
|
||||
|
||||
|
||||
class Expr(object):
    """
    Base class for all AST expressions.
    """
|
||||
|
||||
|
||||
class Atom(Expr):
    """
    An Atom in the DSL is either a literal or a Var.
    """
|
||||
|
||||
|
||||
class Var(Atom):
    """
    A free variable.

    When variables are used in `XForms` with source and destination patterns,
    they are classified as follows:

    Input values
        Uses in the source pattern with no preceding def. These may appear as
        inputs in the destination pattern too, but no new inputs can be
        introduced.
    Output values
        Variables that are defined in both the source and destination pattern.
        These values may have uses outside the source pattern, and the
        destination pattern must compute the same value.
    Intermediate values
        Values that are defined in the source pattern, but not in the
        destination pattern. These may have uses outside the source pattern,
        so the defining instruction can't be deleted immediately.
    Temporary values
        Values that are defined only in the destination pattern.
    """

    # Context bits for `set_def` indicating which pattern has defines of this
    # var.
    SRCCTX = 1
    DSTCTX = 2

    def __init__(self, name, typevar=None):
        # type: (str, TypeVar) -> None
        self.name = name
        # The `Def` defining this variable in a source pattern.
        self.src_def = None  # type: Def
        # The `Def` defining this variable in a destination pattern.
        self.dst_def = None  # type: Def
        # TypeVar representing the type of this variable.
        self.typevar = typevar  # type: TypeVar
        # The original 'typeof(x)' type variable that was created for this
        # Var. It never changes; `self.typevar` above may later be replaced
        # by another typevar during type inference.
        self.original_typevar = self.typevar  # type: TypeVar

    def __str__(self):
        # type: () -> str
        return self.name

    def __repr__(self):
        # type: () -> str
        tags = [self.name]
        if self.src_def:
            tags.append("src")
        if self.dst_def:
            tags.append("dst")
        return "Var({})".format(", ".join(tags))

    def set_def(self, context, d):
        # type: (int, Def) -> None
        """
        Set the `Def` that defines this variable in the given context.

        The `context` must be one of `SRCCTX` or `DSTCTX`
        """
        if context == self.SRCCTX:
            self.src_def = d
        else:
            self.dst_def = d

    def get_def(self, context):
        # type: (int) -> Def
        """
        Get the def of this variable in context.

        The `context` must be one of `SRCCTX` or `DSTCTX`
        """
        return self.src_def if context == self.SRCCTX else self.dst_def

    def is_input(self):
        # type: () -> bool
        """Is this an input value to the src pattern?"""
        return self.src_def is None and self.dst_def is None

    def is_output(self):
        # type: () -> bool
        """Is this an output value, defined in both src and dst patterns?"""
        return self.src_def is not None and self.dst_def is not None

    def is_intermediate(self):
        # type: () -> bool
        """Is this an intermediate value, defined only in the src pattern?"""
        return self.src_def is not None and self.dst_def is None

    def is_temp(self):
        # type: () -> bool
        """Is this a temp value, defined only in the dst pattern?"""
        return self.src_def is None and self.dst_def is not None

    def get_typevar(self):
        # type: () -> TypeVar
        """Get the type variable representing the type of this variable."""
        if not self.typevar:
            # Lazily create a TypeVar allowing all types.
            fresh = TypeVar(
                'typeof_{}'.format(self),
                'Type of the pattern variable `{}`'.format(self),
                ints=True, floats=True, bools=True,
                scalars=True, simd=True, bitvecs=True,
                specials=True)
            self.original_typevar = fresh
            self.typevar = fresh
        return self.typevar

    def set_typevar(self, tv):
        # type: (TypeVar) -> None
        self.typevar = tv

    def has_free_typevar(self):
        # type: () -> bool
        """
        Check if this variable has a free type variable.

        If not, the type of this variable is computed from the type of another
        variable.
        """
        if not self.typevar or self.typevar.is_derived:
            return False
        return self.typevar is self.original_typevar

    def rust_type(self):
        # type: () -> str
        """
        Get a Rust expression that computes the type of this variable.

        It is assumed that local variables exist corresponding to the free
        type variables.
        """
        return self.typevar.rust_expr()
|
||||
|
||||
|
||||
class Apply(Expr):
    """
    Apply an instruction to arguments.

    An `Apply` AST expression is created by using function call syntax on
    instructions. This applies to both bound and unbound polymorphic
    instructions:

    >>> from base.instructions import jump, iadd
    >>> jump('next', ())
    Apply(jump, ('next', ()))
    >>> iadd.i32('x', 'y')
    Apply(iadd.i32, ('x', 'y'))

    :param inst: The instruction being applied, an `Instruction` or
                 `BoundInstruction` instance.
    :param args: Tuple of arguments.
    """

    def __init__(self, inst, args):
        # type: (instructions.MaybeBoundInst, Tuple[Expr, ...]) -> None # noqa
        # Unwrap a bound instruction into the base instruction plus its bound
        # type variables.
        if isinstance(inst, instructions.BoundInstruction):
            self.inst = inst.inst
            self.typevars = inst.typevars
        else:
            assert isinstance(inst, instructions.Instruction)
            self.inst = inst
            self.typevars = ()
        self.args = args
        assert len(self.inst.ins) == len(args)

        # Check that the kinds of Literals arguments match the expected
        # Operand.
        for idx in self.inst.imm_opnums:
            actual = self.args[idx]
            expected = self.inst.ins[idx]
            if isinstance(actual, Literal):
                assert actual.kind == expected.kind, \
                    "Passing literal {} to field of wrong kind {}."\
                    .format(actual, expected.kind)

    def __rlshift__(self, other):
        # type: (Union[Var, Tuple[Var, ...]]) -> Def
        """
        Define variables using `var << expr` or `(v1, v2) << expr`.
        """
        return Def(other, self)

    def instname(self):
        # type: () -> str
        # Instruction name followed by any bound type variables, dotted.
        parts = [self.inst.name]
        parts.extend('.{}'.format(t) for t in self.typevars)
        return ''.join(parts)

    def __repr__(self):
        # type: () -> str
        return "Apply({}, {})".format(self.instname(), self.args)

    def __str__(self):
        # type: () -> str
        arg_list = ', '.join(str(a) for a in self.args)
        return '{}({})'.format(self.instname(), arg_list)

    def rust_builder(self, defs=None):
        # type: (Sequence[Var]) -> str
        """
        Return a Rust Builder method call for instantiating this instruction
        application.

        The `defs` argument should be a list of variables defined by this
        instruction. It is used to construct a result type if necessary.
        """
        arg_list = ', '.join(str(a) for a in self.args)
        # Do we need to pass an explicit type argument?
        if self.inst.is_polymorphic and not self.inst.use_typevar_operand:
            arg_list = defs[0].rust_type() + ', ' + arg_list
        return '{}({})'.format(self.inst.snake_name(), arg_list)

    def inst_predicate(self):
        # type: () -> PredNode
        """
        Construct an instruction predicate that verifies the immediate
        operands on this instruction.

        Immediate operands in a source pattern can be either free variables or
        constants like `ConstantInt` and `Enumerator`. We don't currently
        support constraints on free variables, but we may in the future.
        """
        pred = None  # type: PredNode
        iform = self.inst.format

        # Examine all of the immediate operands.
        for field, opnum in zip(iform.imm_fields, self.inst.imm_opnums):
            value = self.args[opnum]

            # Ignore free variables for now. We may add variable predicates
            # later.
            if isinstance(value, Var):
                continue

            pred = And.combine(pred, IsEqual(field, value))

        # Add checks for any bound secondary type variables.
        # We can't check the controlling type variable this way since it may
        # not appear as the type of an operand.
        if len(self.typevars) > 1:
            for bound_ty, tv in zip(self.typevars[1:],
                                    self.inst.other_typevars):
                if bound_ty is None:
                    continue
                check = TypePredicate.typevar_check(self.inst, tv, bound_ty)
                pred = And.combine(pred, check)

        return pred

    def inst_predicate_with_ctrl_typevar(self):
        # type: () -> PredNode
        """
        Same as `inst_predicate()`, but also check the controlling type
        variable.
        """
        pred = self.inst_predicate()

        if len(self.typevars) > 0:
            bound_ty = self.typevars[0]
            check = None  # type: PredNode
            if bound_ty is not None:
                # Prefer to look at the types of input operands.
                if self.inst.use_typevar_operand:
                    check = TypePredicate.typevar_check(
                        self.inst, self.inst.ctrl_typevar, bound_ty)
                else:
                    check = CtrlTypePredicate(bound_ty)
                pred = And.combine(pred, check)

        return pred

    def copy(self, m):
        # type: (VarAtomMap) -> Apply
        """
        Return a copy of this Expr with vars replaced with fresh variables,
        in accordance with the map m. Update m as necessary.
        """
        copied_args = tuple(replace_var(a, m) for a in self.args)
        return Apply(self.inst, copied_args)

    def vars(self):
        # type: () -> Set[Var]
        """Return the set of all Vars in self that correspond to SSA values"""
        found = set()
        for opnum in self.inst.value_opnums:
            operand = self.args[opnum]
            assert isinstance(operand, Var)
            found.add(operand)
        return found

    def substitution(self, other, s):
        # type: (Apply, VarAtomMap) -> Optional[VarAtomMap]
        """
        If there is a substitution from Var->Atom that converts self to other,
        return it, otherwise return None. Note that this is strictly weaker
        than unification (see TestXForm.test_subst_enum_bad_var_const for
        example).
        """
        if self.inst != other.inst:
            return None

        # Guaranteed by self.inst == other.inst
        assert len(self.args) == len(other.args)

        for mine, theirs in zip(self.args, other.args):
            assert isinstance(mine, Atom) and isinstance(theirs, Atom)

            if isinstance(mine, Var):
                if mine not in s:
                    s[mine] = theirs
                elif s[mine] != theirs:
                    return None
            elif isinstance(theirs, Var):
                assert isinstance(mine, Literal)
                if theirs not in s:
                    s[theirs] = mine
                elif s[theirs] != mine:
                    return None
            else:
                assert (isinstance(mine, Literal) and
                        isinstance(theirs, Literal))
                # Guaranteed by self.inst == other.inst
                assert mine.kind == theirs.kind
                if mine.value != theirs.value:
                    return None

        return s
|
||||
|
||||
|
||||
class Literal(Atom):
    """
    Base class for all literal expressions in the DSL.
    """
    def __init__(self, kind, value):
        # type: (ImmediateKind, Any) -> None
        self.kind = kind
        self.value = value

    def __eq__(self, other):
        # type: (Any) -> bool
        # Can't just compare value here, as comparison Any <> Any returns Any,
        # so fall back to comparing the repr forms.
        return (isinstance(other, Literal) and
                self.kind == other.kind and
                repr(self) == repr(other))

    def __ne__(self, other):
        # type: (Any) -> bool
        return not self.__eq__(other)

    def __repr__(self):
        # type: () -> str
        return '{}.{}'.format(self.kind, self.value)
|
||||
|
||||
|
||||
class ConstantInt(Literal):
    """
    A value of an integer immediate operand.

    Immediate operands like `imm64` or `offset32` can be specified in AST
    expressions using the call syntax: `imm64(5)`, which creates a
    `ConstantInt` node.
    """

    def __init__(self, kind, value):
        # type: (ImmediateKind, int) -> None
        super(ConstantInt, self).__init__(kind, value)

    def __str__(self):
        # type: () -> str
        """
        Get the Rust expression form of this constant.
        """
        return str(self.value)
|
||||
|
||||
|
||||
class ConstantBits(Literal):
    """
    A bitwise value of an immediate operand.

    This is used to create bitwise exact floating point constants using
    `ieee32.bits(0x80000000)`.
    """

    def __init__(self, kind, bits):
        # type: (ImmediateKind, int) -> None
        # The literal value stored is the Rust constructor expression that
        # rebuilds the constant from its raw bit pattern.
        rust_expr = '{}::with_bits({:#x})'.format(kind.rust_type, bits)
        super(ConstantBits, self).__init__(kind, rust_expr)

    def __str__(self):
        # type: () -> str
        """
        Get the Rust expression form of this constant.
        """
        return str(self.value)
|
||||
|
||||
|
||||
class Enumerator(Literal):
    """
    A value of an enumerated immediate operand.

    Some immediate operand kinds like `intcc` and `floatcc` have an enumerated
    range of values corresponding to a Rust enum type. An `Enumerator` object
    is an AST leaf node representing one of the values.

    :param kind: The enumerated `ImmediateKind` containing the value.
    :param value: The textual IR representation of the value.

    `Enumerator` nodes are not usually created directly. They are created by
    using the dot syntax on immediate kinds: `intcc.ult`.
    """

    def __init__(self, kind, value):
        # type: (ImmediateKind, str) -> None
        super(Enumerator, self).__init__(kind, value)

    def __str__(self):
        # type: () -> str
        """
        Get the Rust expression form of this enumerator.
        """
        return self.kind.rust_enumerator(self.value)
|
||||
268
lib/codegen/meta-python/cdsl/formats.py
Normal file
268
lib/codegen/meta-python/cdsl/formats.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Classes for describing instruction formats."""
|
||||
from __future__ import absolute_import
|
||||
from .operands import OperandKind, VALUE, VARIABLE_ARGS
|
||||
from .operands import Operand # noqa
|
||||
|
||||
# The typing module is only required by mypy, and we don't use these imports
|
||||
# outside type comments.
|
||||
try:
|
||||
from typing import Dict, List, Tuple, Union, Any, Sequence, Iterable # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class InstructionContext(object):
    """
    Most instruction predicates refer to immediate fields of a specific
    instruction format, so their `predicate_context()` method returns the
    specific instruction format.

    Predicates that only care about the types of SSA values are independent of
    the instruction format. They can be evaluated in the context of any
    instruction.

    The singleton `InstructionContext` class serves as the predicate context
    for these predicates.
    """

    def __init__(self):
        # type: () -> None
        self.name = 'inst'
|
||||
|
||||
|
||||
# Singleton instance.
|
||||
instruction_context = InstructionContext()
|
||||
|
||||
|
||||
class InstructionFormat(object):
    """
    Every instruction opcode has a corresponding instruction format which
    determines the number of operands and their kinds. Instruction formats are
    identified structurally, i.e., the format of an instruction is derived from
    the kinds of operands used in its declaration.

    The instruction format stores two separate lists of operands: Immediates
    and values. Immediate operands (including entity references) are
    represented as explicit members in the `InstructionData` variants. The
    value operands are stored differently, depending on how many there are.
    Beyond a certain point, instruction formats switch to an external value
    list for storing value arguments. Value lists can hold an arbitrary number
    of values.

    All instruction formats must be predefined in the
    :py:mod:`cranelift.formats` module.

    :param kinds: List of `OperandKind` objects describing the operands.
    :param name: Instruction format name in CamelCase. This is used as a Rust
        variant name in both the `InstructionData` and `InstructionFormat`
        enums.
    :param typevar_operand: Index of the value input operand that is used to
        infer the controlling type variable. By default, this is `0`, the first
        `value` operand. The index is relative to the values only, ignoring
        immediate operands.
    """

    # Map (imm_kinds, num_value_operands, has_value_list) -> format
    _registry = dict()  # type: Dict[Tuple[Tuple[OperandKind, ...], int, bool], InstructionFormat] # noqa

    # All existing formats.
    all_formats = list()  # type: List[InstructionFormat]

    def __init__(self, *kinds, **kwargs):
        # type: (*Union[OperandKind, Tuple[str, OperandKind]], **Any) -> None # noqa
        self.name = kwargs.get('name', None)  # type: str
        self.parent = instruction_context

        # The number of value operands stored in the format, or `None` when
        # `has_value_list` is set.
        self.num_value_operands = 0
        # Does this format use a value list for storing value operands?
        self.has_value_list = False
        # Operand fields for the immediate operands. All other instruction
        # operands are values or variable argument lists. They are all handled
        # specially. Note: materializing this tuple also drives the updates
        # of `num_value_operands` / `has_value_list` performed by the
        # generator.
        self.imm_fields = tuple(self._process_member_names(kinds))

        # The typevar_operand argument must point to a 'value' operand.
        self.typevar_operand = kwargs.get('typevar_operand', None)  # type: int
        if self.typevar_operand is not None:
            if not self.has_value_list:
                assert self.typevar_operand < self.num_value_operands, \
                    "typevar_operand must indicate a 'value' operand"
        elif self.has_value_list or self.num_value_operands > 0:
            # Default to the first 'value' operand, if there is one.
            self.typevar_operand = 0

        # Compute a signature for the global registry.
        imm_kinds = tuple(f.kind for f in self.imm_fields)
        sig = (imm_kinds, self.num_value_operands, self.has_value_list)
        if sig in InstructionFormat._registry:
            raise RuntimeError(
                "Format '{}' has the same signature as existing format '{}'"
                .format(self.name, InstructionFormat._registry[sig]))
        InstructionFormat._registry[sig] = self
        InstructionFormat.all_formats.append(self)

    def args(self):
        # type: () -> FormatField
        """
        Provides a ValueListField, which is derived from FormatField,
        corresponding to the full ValueList of the instruction format. This
        is useful for creating predicates for instructions which use variadic
        arguments.

        Returns `None` for formats without a value list.
        """
        if self.has_value_list:
            return ValueListField(self)
        return None

    def _process_member_names(self, kinds):
        # type: (Sequence[Union[OperandKind, Tuple[str, OperandKind]]]) -> Iterable[FormatField] # noqa
        """
        Extract names of all the immediate operands in the kinds tuple.

        Each entry is either an `OperandKind` instance, or a `(member, kind)`
        pair. The member names correspond to members in the Rust
        `InstructionData` data structure.

        Updates the fields `self.num_value_operands` and `self.has_value_list`.

        Yields the immediate operand fields.
        """
        inum = 0
        for arg in kinds:
            if isinstance(arg, OperandKind):
                member = arg.default_member
                k = arg
            else:
                member, k = arg

            # We define 'immediate' as not a value or variable arguments.
            if k is VALUE:
                self.num_value_operands += 1
            elif k is VARIABLE_ARGS:
                self.has_value_list = True
            else:
                yield FormatField(self, inum, k, member)
                inum += 1

    def __str__(self):
        # type: () -> str
        args = ', '.join(
            '{}: {}'.format(f.member, f.kind) for f in self.imm_fields)
        return '{}(imms=({}), vals={})'.format(
            self.name, args, self.num_value_operands)

    def __getattr__(self, attr):
        # type: (str) -> FormatField
        """
        Make immediate instruction format members available as attributes.

        Each non-value format member becomes a corresponding `FormatField`
        attribute.

        :raises AttributeError: if `attr` is not an immediate member name.
        """
        for f in self.imm_fields:
            if f.member == attr:
                # Cache this field attribute so we won't have to search again.
                setattr(self, attr, f)
                return f

        # Fixed message: previously read "neither ... member or a ...", which
        # was ungrammatical.
        raise AttributeError(
            '{} is neither a {} member nor a '
            'normal InstructionFormat attribute'.format(attr, self.name))

    @staticmethod
    def lookup(ins, outs):
        # type: (Sequence[Operand], Sequence[Operand]) -> InstructionFormat
        """
        Find an existing instruction format that matches the given lists of
        instruction inputs and outputs.

        The `ins` and `outs` arguments correspond to the
        :py:class:`Instruction` arguments of the same name, except they must
        be tuples of :py:`Operand` objects.

        :raises RuntimeError: if no registered format matches.
        """
        # Construct a signature.
        imm_kinds = tuple(op.kind for op in ins if op.is_immediate())
        num_values = sum(1 for op in ins if op.is_value())
        has_varargs = (VARIABLE_ARGS in tuple(op.kind for op in ins))

        sig = (imm_kinds, num_values, has_varargs)
        if sig in InstructionFormat._registry:
            return InstructionFormat._registry[sig]

        # Try another value list format as an alternative.
        sig = (imm_kinds, 0, True)
        if sig in InstructionFormat._registry:
            return InstructionFormat._registry[sig]

        raise RuntimeError(
            'No instruction format matches '
            'imms={}, vals={}, varargs={}'.format(
                imm_kinds, num_values, has_varargs))

    @staticmethod
    def extract_names(globs):
        # type: (Dict[str, Any]) -> None
        """
        Given a dict mapping name -> object as returned by `globals()`, find
        all the InstructionFormat objects and set their name from the dict
        key. This is used to name a bunch of global values in a module.
        """
        for name, obj in globs.items():
            if isinstance(obj, InstructionFormat):
                assert obj.name is None
                obj.name = name
|
||||
|
||||
|
||||
class FormatField(object):
    """
    An immediate field in an instruction format.

    This corresponds to a single member of a variant of the `InstructionData`
    data type.

    :param iform: Parent `InstructionFormat`.
    :param immnum: Immediate operand number in parent.
    :param kind: Immediate Operand kind.
    :param member: Member name in `InstructionData` variant.
    """

    def __init__(self, iform, immnum, kind, member):
        # type: (InstructionFormat, int, OperandKind, str) -> None
        self.format, self.immnum = iform, immnum
        self.kind, self.member = kind, member

    def __str__(self):
        # type: () -> str
        return '{}.{}'.format(self.format.name, self.member)

    def rust_destructuring_name(self):
        # type: () -> str
        # Plain members are destructured by name, no `ref` needed.
        return self.member

    def rust_name(self):
        # type: () -> str
        return self.member
|
||||
|
||||
|
||||
class ValueListField(FormatField):
    """
    The full value list field of an instruction format.

    This corresponds to all Value-type members of a variant of the
    `InstructionData` format, which contains a ValueList.

    :param iform: Parent `InstructionFormat`.
    """
    def __init__(self, iform):
        # type: (InstructionFormat) -> None
        # Deliberately does not call FormatField.__init__: a value list has
        # no immnum/kind, only the fixed member name "args".
        self.format = iform
        self.member = "args"

    def rust_destructuring_name(self):
        # type: () -> str
        # Value lists are destructured by reference.
        return 'ref {}'.format(self.member)
|
||||
440
lib/codegen/meta-python/cdsl/instructions.py
Normal file
440
lib/codegen/meta-python/cdsl/instructions.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""Classes for defining instructions."""
|
||||
from __future__ import absolute_import
|
||||
from . import camel_case
|
||||
from .types import ValueType
|
||||
from .operands import Operand
|
||||
from .formats import InstructionFormat
|
||||
|
||||
try:
|
||||
from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa
|
||||
from typing import Dict # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .ast import Expr, Apply, Var, Def, VarAtomMap # noqa
|
||||
from .typevar import TypeVar # noqa
|
||||
from .ti import TypeConstraint # noqa
|
||||
from .xform import XForm, Rtl
|
||||
# List of operands for ins/outs:
|
||||
OpList = Union[Sequence[Operand], Operand]
|
||||
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
|
||||
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
||||
InstructionSemantics = Sequence[XForm]
|
||||
SemDefCase = Union[Rtl, Tuple[Rtl, Sequence[TypeConstraint]], XForm]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class InstructionGroup(object):
    """
    Every instruction must belong to exactly one instruction group. A given
    target architecture can support instructions from multiple groups, and it
    does not necessarily support all instructions in a group.

    New instructions are automatically added to the currently open instruction
    group.
    """

    # The currently open instruction group, or `None` when no group is open.
    _current = None  # type: InstructionGroup

    def open(self):
        # type: () -> None
        """
        Open this instruction group such that future new instructions are
        added to this group.

        Only one group may be open at a time.
        """
        assert InstructionGroup._current is None, (
                "Can't open {} since {} is already open"
                .format(self, InstructionGroup._current))
        InstructionGroup._current = self

    def close(self):
        # type: () -> None
        """
        Close this instruction group. This function should be called before
        opening another instruction group.
        """
        # Fixed typo in the assertion message: "instuction" -> "instruction".
        assert InstructionGroup._current is self, (
                "Can't close {}, the open instruction group is {}"
                .format(self, InstructionGroup._current))
        InstructionGroup._current = None

    def __init__(self, name, doc):
        # type: (str, str) -> None
        """
        :param name: Short mnemonic name for the group.
        :param doc: Documentation string for the group.
        """
        self.name = name
        self.__doc__ = doc
        self.instructions = []  # type: List[Instruction]
        # Creating a group immediately opens it for new instructions.
        self.open()

    @staticmethod
    def append(inst):
        # type: (Instruction) -> None
        """Add `inst` to the currently open group; a group must be open."""
        assert InstructionGroup._current, \
            "Open an instruction group before defining instructions."
        InstructionGroup._current.instructions.append(inst)
|
||||
|
||||
class Instruction(object):
    """
    The operands to the instruction are specified as two tuples: ``ins`` and
    ``outs``. Since the Python singleton tuple syntax is a bit awkward, it is
    allowed to specify a singleton as just the operand itself, i.e., `ins=x`
    and `ins=(x,)` are both allowed and mean the same thing.

    :param name: Instruction mnemonic, also becomes opcode name.
    :param doc: Documentation string.
    :param ins: Tuple of input operands. This can be a mix of SSA value
                operands and other operand kinds.
    :param outs: Tuple of output operands. The output operands must be SSA
                values or `variable_args`.
    :param constraints: Tuple of instruction-specific TypeConstraints.
    :param is_terminator: This is a terminator instruction.
    :param is_branch: This is a branch instruction.
    :param is_call: This is a call instruction.
    :param is_return: This is a return instruction.
    :param can_trap: This instruction can trap.
    :param can_load: This instruction can load from memory.
    :param can_store: This instruction can store to memory.
    :param other_side_effects: Instruction has other side effects.
    """

    # Boolean instruction attributes that can be passed as keyword arguments to
    # the constructor. Map attribute name to doc comment for generated Rust
    # code.
    ATTRIBS = {
        'is_terminator': 'True for instructions that terminate the EBB.',
        'is_branch': 'True for all branch or jump instructions.',
        'is_call': 'Is this a call instruction?',
        'is_return': 'Is this a return instruction?',
        'can_load': 'Can this instruction read from memory?',
        'can_store': 'Can this instruction write to memory?',
        'can_trap': 'Can this instruction cause a trap?',
        'other_side_effects':
            'Does this instruction have other side effects besides can_*',
        'writes_cpu_flags': 'Does this instruction write to CPU flags?',
    }

    def __init__(self, name, doc, ins=(), outs=(), constraints=(), **kwargs):
        # type: (str, str, OpList, OpList, ConstrList, **Any) -> None
        self.name = name
        self.camel_name = camel_case(name)
        self.__doc__ = doc
        self.ins = self._to_operand_tuple(ins)
        self.outs = self._to_operand_tuple(outs)
        self.constraints = self._to_constraint_tuple(constraints)
        self.format = InstructionFormat.lookup(self.ins, self.outs)
        self.semantics = None  # type: InstructionSemantics

        # Opcode number, assigned by gen_instr.py.
        self.number = None  # type: int

        # Indexes into `self.outs` for value results.
        # Other results are `variable_args`.
        self.value_results = tuple(
                i for i, o in enumerate(self.outs) if o.is_value())
        # Indexes into `self.ins` for value operands.
        self.value_opnums = tuple(
                i for i, o in enumerate(self.ins) if o.is_value())
        # Indexes into `self.ins` for non-value operands.
        self.imm_opnums = tuple(
                i for i, o in enumerate(self.ins) if o.is_immediate())

        self._verify_polymorphic()
        # Reject unknown attribute keywords early with a clear error.
        for attr in kwargs:
            if attr not in Instruction.ATTRIBS:
                raise AssertionError(
                        "unknown instruction attribute '" + attr + "'")
        for attr in Instruction.ATTRIBS:
            # Normalize any truthy value to a plain bool (was `not not ...`).
            setattr(self, attr, bool(kwargs.get(attr, False)))

        # Infer the 'writes_cpu_flags' field value.
        if 'writes_cpu_flags' not in kwargs:
            self.writes_cpu_flags = any(
                out.is_cpu_flags() for out in self.outs)

        InstructionGroup.append(self)

    def __str__(self):
        # type: () -> str
        prefix = ', '.join(o.name for o in self.outs)
        if prefix:
            prefix = prefix + ' = '
        suffix = ', '.join(o.name for o in self.ins)
        return '{}{} {}'.format(prefix, self.name, suffix)

    def snake_name(self):
        # type: () -> str
        """
        Get the snake_case name of this instruction.

        Keywords in Rust and Python are altered by appending a '_'
        """
        if self.name == 'return':
            return 'return_'
        else:
            return self.name

    def blurb(self):
        # type: () -> str
        """Get the first line of the doc comment"""
        for line in self.__doc__.split('\n'):
            line = line.strip()
            if line:
                return line
        return ""

    def _verify_polymorphic(self):
        # type: () -> None
        """
        Check if this instruction is polymorphic, and verify its use of type
        variables.
        """
        poly_ins = [
                i for i in self.value_opnums
                if self.ins[i].typevar.free_typevar()]
        poly_outs = [
                i for i, o in enumerate(self.outs)
                if o.is_value() and o.typevar.free_typevar()]
        self.is_polymorphic = len(poly_ins) > 0 or len(poly_outs) > 0
        if not self.is_polymorphic:
            return

        # Prefer to use the typevar_operand to infer the controlling typevar.
        self.use_typevar_operand = False
        typevar_error = None
        tv_op = self.format.typevar_operand
        if tv_op is not None and tv_op < len(self.value_opnums):
            try:
                opnum = self.value_opnums[tv_op]
                tv = self.ins[opnum].typevar
                if tv is tv.free_typevar() or tv.singleton_type() is not None:
                    self.other_typevars = self._verify_ctrl_typevar(tv)
                    self.ctrl_typevar = tv
                    self.use_typevar_operand = True
            except RuntimeError as e:
                # Remember the failure; we may still succeed by inferring
                # from the first result below.
                typevar_error = e

        if not self.use_typevar_operand:
            # The typevar_operand argument doesn't work. Can we infer from the
            # first result instead?
            if len(self.outs) == 0:
                if typevar_error:
                    raise typevar_error
                else:
                    raise RuntimeError(
                            "typevar_operand must be a free type variable")
            tv = self.outs[0].typevar
            if tv is not tv.free_typevar():
                raise RuntimeError("first result must be a free type variable")
            self.other_typevars = self._verify_ctrl_typevar(tv)
            self.ctrl_typevar = tv

    def _verify_ctrl_typevar(self, ctrl_typevar):
        # type: (TypeVar) -> List[TypeVar]
        """
        Verify that the use of TypeVars is consistent with `ctrl_typevar` as
        the controlling type variable.

        All polymorphic inputs must either be derived from `ctrl_typevar` or be
        independent free type variables only used once.

        All polymorphic results must be derived from `ctrl_typevar`.

        Return list of other type variables used, or raise an error.
        """
        other_tvs = []  # type: List[TypeVar]
        # Check value inputs.
        for opnum in self.value_opnums:
            typ = self.ins[opnum].typevar
            tv = typ.free_typevar()
            # Non-polymorphic or derived from ctrl_typevar is OK.
            if tv is None or tv is ctrl_typevar:
                continue
            # No other derived typevars allowed.
            if typ is not tv:
                raise RuntimeError(
                        "{}: type variable {} must be derived from {}"
                        .format(self.ins[opnum], typ.name, ctrl_typevar))
            # Other free type variables can only be used once each.
            if tv in other_tvs:
                raise RuntimeError(
                        "type variable {} can't be used more than once"
                        .format(tv.name))
            other_tvs.append(tv)

        # Check outputs.
        for result in self.outs:
            if not result.is_value():
                continue
            typ = result.typevar
            tv = typ.free_typevar()
            # Non-polymorphic or derived from ctrl_typevar is OK.
            if tv is None or tv is ctrl_typevar:
                continue
            raise RuntimeError(
                    "type variable in output not derived from ctrl_typevar")

        return other_tvs

    def all_typevars(self):
        # type: () -> List[TypeVar]
        """
        Get a list of all type variables in the instruction.
        """
        if self.is_polymorphic:
            return [self.ctrl_typevar] + self.other_typevars
        else:
            return []

    @staticmethod
    def _to_operand_tuple(x):
        # type: (Union[Sequence[Operand], Operand]) -> Tuple[Operand, ...]
        # Allow a single Operand instance instead of the awkward singleton
        # tuple syntax.
        if isinstance(x, Operand):
            y = (x,)  # type: Tuple[Operand, ...]
        else:
            y = tuple(x)
        for op in y:
            assert isinstance(op, Operand)
        return y

    @staticmethod
    def _to_constraint_tuple(x):
        # type: (ConstrList) -> Tuple[TypeConstraint, ...]
        """
        Allow a single TypeConstraint instance instead of the awkward singleton
        tuple syntax.
        """
        # import placed here to avoid circular dependency
        from .ti import TypeConstraint  # noqa
        if isinstance(x, TypeConstraint):
            y = (x,)  # type: Tuple[TypeConstraint, ...]
        else:
            y = tuple(x)
        for op in y:
            assert isinstance(op, TypeConstraint)
        return y

    def bind(self, *args):
        # type: (*ValueType) -> BoundInstruction
        """
        Bind a polymorphic instruction to a concrete list of type variable
        values.
        """
        assert self.is_polymorphic
        return BoundInstruction(self, args)

    def __getattr__(self, name):
        # type: (str) -> BoundInstruction
        """
        Bind a polymorphic instruction to a single type variable with dot
        syntax:

        >>> iadd.i32
        """
        assert name != 'any', 'Wildcard not allowed for ctrl_typevar'
        return self.bind(ValueType.by_name(name))

    def fully_bound(self):
        # type: () -> Tuple[Instruction, Tuple[ValueType, ...]]
        """
        Verify that all typevars have been bound, and return a
        `(inst, typevars)` pair.

        This version in `Instruction` itself allows non-polymorphic
        instructions to duck-type as `BoundInstruction`s.
        """
        assert not self.is_polymorphic, self
        return (self, ())

    def __call__(self, *args):
        # type: (*Expr) -> Apply
        """
        Create an `ast.Apply` AST node representing the application of this
        instruction to the arguments.
        """
        from .ast import Apply  # noqa
        return Apply(self, args)

    def set_semantics(self, src, *dsts):
        # type: (Union[Def, Apply], *SemDefCase) -> None
        """Set our semantics."""
        from semantics import verify_semantics
        from .xform import XForm, Rtl

        sem = []  # type: List[XForm]
        for dst in dsts:
            if isinstance(dst, Rtl):
                sem.append(XForm(Rtl(src).copy({}), dst))
            elif isinstance(dst, XForm):
                # Copy the XForm so the caller's object is not shared.
                sem.append(XForm(
                    dst.src.copy({}),
                    dst.dst.copy({}),
                    dst.constraints))
            else:
                assert isinstance(dst, tuple)
                sem.append(XForm(Rtl(src).copy({}), dst[0],
                                 constraints=dst[1]))

        verify_semantics(self, Rtl(src), sem)

        self.semantics = sem
|
||||
|
||||
class BoundInstruction(object):
    """
    A polymorphic `Instruction` bound to concrete type variables.
    """

    def __init__(self, inst, typevars):
        # type: (Instruction, Tuple[ValueType, ...]) -> None
        self.inst = inst
        self.typevars = typevars
        # At most the controlling typevar plus one value per secondary
        # typevar may be bound.
        assert len(typevars) <= 1 + len(inst.other_typevars)

    def __str__(self):
        # type: () -> str
        parts = [self.inst.name]
        parts.extend(str(tv) for tv in self.typevars)
        return '.'.join(parts)

    def bind(self, *args):
        # type: (*ValueType) -> BoundInstruction
        """
        Bind additional typevars.
        """
        return BoundInstruction(self.inst, self.typevars + args)

    def __getattr__(self, name):
        # type: (str) -> BoundInstruction
        """
        Bind an additional typevar dot syntax:

        >>> uext.i32.i8
        """
        if name == 'any':
            # This is a wild card bind represented as a None type variable.
            return self.bind(None)
        return self.bind(ValueType.by_name(name))

    def fully_bound(self):
        # type: () -> Tuple[Instruction, Tuple[ValueType, ...]]
        """
        Verify that all typevars have been bound, and return a
        `(inst, typevars)` pair.
        """
        expected = 1 + len(self.inst.other_typevars)
        if len(self.typevars) < expected:
            missing = self.inst.other_typevars[len(self.typevars) - 1:]
            unb = ', '.join(str(tv) for tv in missing)
            raise AssertionError("Unbound typevar {} in {}".format(unb, self))
        assert len(self.typevars) == expected
        return (self.inst, self.typevars)

    def __call__(self, *args):
        # type: (*Expr) -> Apply
        """
        Create an `ast.Apply` AST node representing the application of this
        instruction to the arguments.
        """
        from .ast import Apply  # noqa
        return Apply(self, args)
496
lib/codegen/meta-python/cdsl/isa.py
Normal file
496
lib/codegen/meta-python/cdsl/isa.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""Defining instruction set architectures."""
|
||||
from __future__ import absolute_import
|
||||
from collections import OrderedDict
|
||||
from .predicates import And, TypePredicate
|
||||
from .registers import RegClass, Register, Stack
|
||||
from .ast import Apply
|
||||
from .types import ValueType
|
||||
from .instructions import InstructionGroup
|
||||
|
||||
# The typing module is only required by mypy, and we don't use these imports
|
||||
# outside type comments.
|
||||
try:
|
||||
from typing import Tuple, Union, Any, Iterable, Sequence, List, Set, Dict, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .instructions import MaybeBoundInst, InstructionFormat # noqa
|
||||
from .predicates import PredNode, PredKey # noqa
|
||||
from .settings import SettingGroup # noqa
|
||||
from .registers import RegBank # noqa
|
||||
from .xform import XFormGroup # noqa
|
||||
OperandConstraint = Union[RegClass, Register, int, Stack]
|
||||
ConstraintSeq = Union[OperandConstraint, Tuple[OperandConstraint, ...]]
|
||||
# Instruction specification for encodings. Allows for predicated
|
||||
# instructions.
|
||||
InstSpec = Union[MaybeBoundInst, Apply]
|
||||
BranchRange = Sequence[int]
|
||||
# A recipe predicate consisting of an ISA predicate and an instruction
|
||||
# predicate.
|
||||
RecipePred = Tuple[PredNode, PredNode]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class TargetISA(object):
    """
    A target instruction set architecture.

    The `TargetISA` class collects everything known about a target ISA.

    :param name: Short mnemonic name for the ISA.
    :param instruction_groups: List of `InstructionGroup` instances that are
        relevant for this ISA.
    """

    def __init__(self, name, instruction_groups):
        # type: (str, Sequence[InstructionGroup]) -> None
        self.name = name
        self.settings = None  # type: SettingGroup
        self.instruction_groups = instruction_groups
        self.cpumodes = list()  # type: List[CPUMode]
        self.regbanks = list()  # type: List[RegBank]
        self.regclasses = list()  # type: List[RegClass]
        self.legalize_codes = OrderedDict()  # type: OrderedDict[XFormGroup, int]  # noqa
        # Unique copies of all predicates.
        self._predicates = dict()  # type: Dict[PredKey, PredNode]

        # Defining an ISA while an instruction group is still open would
        # silently capture later instructions; fail loudly instead.
        assert InstructionGroup._current is None,\
            "InstructionGroup {} is still open"\
            .format(InstructionGroup._current.name)

    def __str__(self):
        # type: () -> str
        return self.name

    def finish(self):
        # type: () -> TargetISA
        """
        Finish the definition of a target ISA after adding all CPU modes and
        settings.

        This computes some derived properties that are used in multiple
        places.

        :returns self:
        """
        self._collect_encoding_recipes()
        self._collect_predicates()
        self._collect_regclasses()
        self._collect_legalize_codes()
        return self

    def _collect_encoding_recipes(self):
        # type: () -> None
        """
        Collect and number all encoding recipes in use.
        """
        self.all_recipes = list()  # type: List[EncRecipe]
        rcps = set()  # type: Set[EncRecipe]
        for cpumode in self.cpumodes:
            for enc in cpumode.encodings:
                recipe = enc.recipe
                if recipe not in rcps:
                    # First encounter: number recipes in discovery order.
                    assert recipe.number is None
                    recipe.number = len(rcps)
                    rcps.add(recipe)
                    self.all_recipes.append(recipe)
                    # Make sure ISA predicates are registered.
                    if recipe.isap:
                        recipe.isap = self.unique_pred(recipe.isap)
                        self.settings.number_predicate(recipe.isap)
                    recipe.instp = self.unique_pred(recipe.instp)

    def _collect_predicates(self):
        # type: () -> None
        """
        Collect and number all predicates in use.

        Ensures that all ISA predicates have an assigned bit number in
        `self.settings`.
        """
        self.instp_number = OrderedDict()  # type: OrderedDict[PredNode, int]
        for cpumode in self.cpumodes:
            for enc in cpumode.encodings:
                instp = enc.instp
                if instp and instp not in self.instp_number:
                    # assign predicate number starting from 0.
                    n = len(self.instp_number)
                    self.instp_number[instp] = n

                # All referenced ISA predicates must have a number in
                # `self.settings`. This may cause some parent predicates to be
                # replicated here, which is OK.
                if enc.isap:
                    self.settings.number_predicate(enc.isap)

    def _collect_regclasses(self):
        # type: () -> None
        """
        Collect and number register classes.

        Every register class needs a unique index, and the classes need to be
        topologically ordered.

        We also want all the top-level register classes to be first.
        """
        # Compute subclasses and top-level classes in each bank.
        # Collect the top-level classes so they get numbered consecutively.
        for bank in self.regbanks:
            bank.finish_regclasses()
            # Always get the pressure tracking classes in first.
            if bank.pressure_tracking:
                self.regclasses.extend(bank.toprcs)

        # The limit on the number of top-level register classes can be raised.
        # This should be coordinated with the `MAX_TRACKED_TOPRCS` constant in
        # `isa/registers.rs`.
        assert len(self.regclasses) <= 4, "Too many top-level register classes"

        # Get the remaining top-level register classes which may exceed
        # `MAX_TRACKED_TOPRCS`.
        for bank in self.regbanks:
            if not bank.pressure_tracking:
                self.regclasses.extend(bank.toprcs)

        # Collect all of the non-top-level register classes.
        # They are numbered strictly after the top-level classes.
        for bank in self.regbanks:
            self.regclasses.extend(
                rc for rc in bank.classes if not rc.is_toprc())

        for idx, rc in enumerate(self.regclasses):
            rc.index = idx

        # The limit on the number of register classes can be changed. It should
        # be coordinated with the `RegClassMask` and `RegClassIndex` types in
        # `isa/registers.rs`.
        assert len(self.regclasses) <= 32, "Too many register classes"

    def _collect_legalize_codes(self):
        # type: () -> None
        """
        Make sure all legalization transforms have been assigned a code.
        """
        for cpumode in self.cpumodes:
            self.legalize_code(cpumode.default_legalize)
            for x in cpumode.type_legalize.values():
                self.legalize_code(x)

    def legalize_code(self, xgrp):
        # type: (XFormGroup) -> int
        """
        Get the legalization code for the transform group `xgrp`. Assign one if
        necessary.

        Each target ISA has its own list of legalization actions with
        associated legalize codes that appear in the encoding tables.

        This method is used to maintain the registry of legalization actions
        and their table codes.
        """
        if xgrp in self.legalize_codes:
            code = self.legalize_codes[xgrp]
        else:
            # Codes are assigned densely in first-use order.
            code = len(self.legalize_codes)
            self.legalize_codes[xgrp] = code
        return code

    def unique_pred(self, pred):
        # type: (PredNode) -> PredNode
        """
        Get a unique predicate that is equivalent to `pred`.
        """
        if pred is None:
            return pred
        # TODO: We could actually perform some algebraic simplifications. It's
        # not clear if it is worthwhile.
        k = pred.predicate_key()
        if k in self._predicates:
            return self._predicates[k]
        self._predicates[k] = pred
        return pred
||||
|
||||
|
||||
class CPUMode(object):
    """
    A CPU mode determines which instruction encodings are active.

    All instruction encodings are associated with exactly one `CPUMode`, and
    all CPU modes are associated with exactly one `TargetISA`.

    :param name: Short mnemonic name for the CPU mode.
    :param isa: Associated `TargetISA`.
    """

    def __init__(self, name, isa):
        # type: (str, TargetISA) -> None
        self.name = name
        self.isa = isa
        self.encodings = []  # type: List[Encoding]
        # Register this mode with its owning ISA.
        isa.cpumodes.append(self)

        # Tables for configuring legalization actions when no valid encoding
        # exists for an instruction.
        self.default_legalize = None  # type: XFormGroup
        self.type_legalize = OrderedDict()  # type: OrderedDict[ValueType, XFormGroup]  # noqa

    def __str__(self):
        # type: () -> str
        return self.name

    def enc(self, *args, **kwargs):
        # type: (*Any, **Any) -> None
        """
        Add a new encoding to this CPU mode.

        Arguments are the `Encoding` constructor arguments, except for the
        first `CPUMode` argument which is implied.
        """
        encoding = Encoding(self, *args, **kwargs)
        self.encodings.append(encoding)

    def legalize_type(self, default=None, **kwargs):
        # type: (XFormGroup, **XFormGroup) -> None
        """
        Configure the legalization action per controlling type variable.

        Instructions that have a controlling type variable mentioned in one of
        the arguments will be legalized according to the action specified here
        instead of using the `legalize_default` action.

        The keyword arguments are value type names:

            mode.legalize_type(i8=widen, i16=widen, i32=expand)

        The `default` argument specifies the action to take for controlling
        type variables that don't have an explicitly configured action.
        """
        if default is not None:
            self.default_legalize = default

        for type_name, xgrp in kwargs.items():
            self.type_legalize[ValueType.by_name(type_name)] = xgrp

    def legalize_monomorphic(self, xgrp):
        # type: (XFormGroup) -> None
        """
        Configure the legalization action to take for monomorphic instructions
        which don't have a controlling type variable.

        See also `legalize_type()` for polymorphic instructions.
        """
        # Monomorphic instructions are keyed under `None`.
        self.type_legalize[None] = xgrp

    def get_legalize_action(self, ty):
        # type: (ValueType) -> XFormGroup
        """
        Get the legalization action to use for `ty`.
        """
        try:
            return self.type_legalize[ty]
        except KeyError:
            return self.default_legalize
||||
|
||||
|
||||
class EncRecipe(object):
    """
    A recipe for encoding instructions with a given format.

    Many different instructions can be encoded by the same recipe, but they
    must all have the same instruction format.

    The `ins` and `outs` arguments are tuples specifying the register
    allocation constraints for the value operands and results respectively. The
    possible constraints for an operand are:

    - A `RegClass` specifying the set of allowed registers.
    - A `Register` specifying a fixed-register operand.
    - An integer indicating that this result is tied to a value operand, so
      they must use the same register.
    - A `Stack` specifying a value in a stack slot.

    The `branch_range` argument must be provided for recipes that can encode
    branch instructions. It is an `(origin, bits)` tuple describing the exact
    range that can be encoded in a branch instruction.

    For ISAs that use CPU flags in `iflags` and `fflags` value types, the
    `clobbers_flags` is used to indicate instruction encodings that clobbers
    the CPU flags, so they can't be used where a flag value is live.

    :param name: Short mnemonic name for this recipe.
    :param format: All encoded instructions must have this
            :py:class:`InstructionFormat`.
    :param size: Number of bytes in the binary encoded instruction.
    :param ins: Tuple of register constraints for value operands.
    :param outs: Tuple of register constraints for results.
    :param branch_range: `(origin, bits)` range for branches.
    :param clobbers_flags: This instruction clobbers `iflags` and `fflags`.
    :param instp: Instruction predicate.
    :param isap: ISA predicate.
    :param emit: Rust code for binary emission.
    """

    def __init__(
            self,
            name,                 # type: str
            format,               # type: InstructionFormat
            size,                 # type: int
            ins,                  # type: ConstraintSeq
            outs,                 # type: ConstraintSeq
            branch_range=None,    # type: BranchRange
            clobbers_flags=True,  # type: bool
            instp=None,           # type: PredNode
            isap=None,            # type: PredNode
            emit=None             # type: str
            ):
        # type: (...) -> None
        assert size >= 0
        self.name = name
        self.format = format
        self.size = size
        self.branch_range = branch_range
        self.clobbers_flags = clobbers_flags
        self.instp = instp
        self.isap = isap
        self.emit = emit
        if instp:
            # The instruction predicate must apply to this recipe's format.
            assert instp.predicate_context() == format
        # Assigned later, when the owning ISA numbers its recipes.
        self.number = None  # type: int

        # Inputs are checked first; tied-output integers in `outs` index
        # into `self.ins`.
        self.ins = self._verify_constraints(ins)
        if not format.has_value_list:
            assert len(self.ins) == format.num_value_operands
        self.outs = self._verify_constraints(outs)

    def __str__(self):
        # type: () -> str
        return self.name

    def _verify_constraints(self, seq):
        # type: (ConstraintSeq) -> Sequence[OperandConstraint]
        if not isinstance(seq, tuple):
            # Allow a bare constraint instead of a singleton tuple.
            seq = (seq,)
        for constraint in seq:
            if isinstance(constraint, int):
                # An integer constraint is bound to a value operand.
                # Check that it is in range.
                assert 0 <= constraint < len(self.ins)
            else:
                assert isinstance(constraint, (RegClass, Register, Stack))
        return seq

    def ties(self):
        # type: () -> Tuple[Dict[int, int], Dict[int, int]]
        """
        Return two dictionaries representing the tied operands.

        The first maps input number to tied output number, the second maps
        output number to tied input number.
        """
        i2o = {}  # type: Dict[int, int]
        o2i = {}  # type: Dict[int, int]
        for out_num, constraint in enumerate(self.outs):
            if isinstance(constraint, int):
                i2o[constraint] = out_num
                o2i[out_num] = constraint
        return (i2o, o2i)

    def fixed_ops(self):
        # type: () -> Tuple[Set[Register], Set[Register]]
        """
        Return two sets of registers representing the fixed input and output
        operands.
        """
        fixed_ins = {c for c in self.ins if isinstance(c, Register)}
        fixed_outs = {c for c in self.outs if isinstance(c, Register)}
        return (fixed_ins, fixed_outs)

    def recipe_pred(self):
        # type: () -> RecipePred
        """
        Get the combined recipe predicate which includes both the ISA predicate
        and the instruction predicate.

        Return `None` if this recipe has neither predicate.
        """
        if self.isap is not None or self.instp is not None:
            return (self.isap, self.instp)
        return None
||||
|
||||
|
||||
class Encoding(object):
    """
    Encoding for a concrete instruction.

    An `Encoding` object ties an instruction opcode with concrete type
    variables together with an encoding recipe and encoding bits.

    The concrete instruction can be in three different forms:

    1. A naked opcode: `trap` for non-polymorphic instructions.
    2. With bound type variables: `iadd.i32` for polymorphic instructions.
    3. With operands providing constraints: `icmp.i32(intcc.eq, x, y)`.

    If the instruction is polymorphic, all type variables must be provided.

    :param cpumode: The CPU mode where the encoding is active.
    :param inst: The :py:class:`Instruction` or :py:class:`BoundInstruction`
                 being encoded.
    :param recipe: The :py:class:`EncRecipe` to use.
    :param encbits: Additional encoding bits to be interpreted by `recipe`.
    :param instp: Instruction predicate, or `None`.
    :param isap: ISA predicate, or `None`.
    """

    def __init__(self, cpumode, inst, recipe, encbits, instp=None, isap=None):
        # type: (CPUMode, InstSpec, EncRecipe, int, PredNode, PredNode) -> None # noqa
        assert isinstance(cpumode, CPUMode)
        assert isinstance(recipe, EncRecipe)

        # Check for possible instruction predicates in `inst`.
        if isinstance(inst, Apply):
            # An `Apply` carries concrete operands, which may imply an extra
            # instruction predicate; fold it into `instp`.
            instp = And.combine(instp, inst.inst_predicate())
            self.inst = inst.inst
            self.typevars = inst.typevars
        else:
            self.inst, self.typevars = inst.fully_bound()

        # Add secondary type variables to the instruction predicate.
        # This is already included by Apply.inst_predicate() above.
        if len(self.typevars) > 1:
            for tv, vt in zip(self.inst.other_typevars, self.typevars[1:]):
                # A None tv is an 'any' wild card: `ishl.i32.any`.
                if vt is None:
                    continue
                typred = TypePredicate.typevar_check(self.inst, tv, vt)
                instp = And.combine(instp, typred)

        self.cpumode = cpumode
        assert self.inst.format == recipe.format, (
                "Format {} must match recipe: {}".format(
                    self.inst.format, recipe.format))

        if self.inst.is_branch:
            # Branches need a relocation range so the encoder can verify that
            # the target is reachable.
            assert recipe.branch_range, (
                    'Recipe {} for {} must have a branch_range'
                    .format(recipe, self.inst.name))

        self.recipe = recipe
        self.encbits = encbits

        # Record specific predicates. Note that the recipe also has predicates.
        self.instp = self.cpumode.isa.unique_pred(instp)
        self.isap = self.cpumode.isa.unique_pred(isap)

    def __str__(self):
        # type: () -> str
        return '[{}#{:02x}]'.format(self.recipe, self.encbits)

    def ctrl_typevar(self):
        # type: () -> ValueType
        """
        Get the controlling type variable for this encoding or `None`.
        """
        if self.typevars:
            return self.typevars[0]
        else:
            return None
251
lib/codegen/meta-python/cdsl/operands.py
Normal file
251
lib/codegen/meta-python/cdsl/operands.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Classes for describing instruction operands."""
|
||||
from __future__ import absolute_import
|
||||
from . import camel_case
|
||||
from .types import ValueType
|
||||
from .typevar import TypeVar
|
||||
|
||||
try:
|
||||
from typing import Union, Dict, TYPE_CHECKING, Iterable # noqa
|
||||
OperandSpec = Union['OperandKind', ValueType, TypeVar]
|
||||
if TYPE_CHECKING:
|
||||
from .ast import Enumerator, ConstantInt, ConstantBits, Literal # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Kinds of operands.
|
||||
#
|
||||
# Each instruction has an opcode and a number of operands. The opcode
|
||||
# determines the instruction format, and the format determines the number of
|
||||
# operands and the kind of each operand.
|
||||
class OperandKind(object):
    """
    A kind of instruction operand.

    Each instruction has an opcode and a number of operands; the opcode
    determines the instruction format, which in turn fixes the number and kind
    of each operand. Every `OperandKind` instance corresponds to one such kind
    and maps to a concrete type in the Rust representation of an instruction.
    """

    def __init__(self, name, doc, default_member=None, rust_type=None):
        # type: (str, str, str, str) -> None
        self.name = name
        self.__doc__ = doc
        self.default_member = default_member
        if not rust_type:
            # By convention, the Rust type representing an operand kind is the
            # camel-cased kind name in the `ir` module.
            rust_type = 'ir::' + camel_case(name)
        self.rust_type = rust_type

    def __str__(self):
        # type: () -> str
        return self.name

    def __repr__(self):
        # type: () -> str
        return 'OperandKind({})'.format(self.name)
|
||||
|
||||
|
||||
#: An SSA value operand. This is a value defined by another instruction.
#: Module-level singleton: other code compares against it with `is`
#: (e.g. `Operand.is_value`), so it must never be re-instantiated.
VALUE = OperandKind(
    'value', """
    An SSA value defined by another instruction.

    This kind of operand can represent any SSA value type, but the
    instruction format may restrict the valid value types for a given
    operand.
    """)

#: A variable-sized list of value operands. Use for Ebb and function call
#: arguments.
#: Also a singleton compared by identity (see `Operand.is_varargs`).
VARIABLE_ARGS = OperandKind(
    'variable_args', """
    A variable size list of `value` operands.

    Use this to represent arguments passed to a function call, arguments
    passed to an extended basic block, or a variable number of results
    returned from an instruction.
    """,
    rust_type='&[Value]')
|
||||
|
||||
|
||||
# Instances of immediate operand types are provided in the
|
||||
# `cranelift.immediates` module.
|
||||
class ImmediateKind(OperandKind):
    """
    The kind of an immediate instruction operand.

    :param default_member: The default member name of this kind in the
        `InstructionData` data structure.
    """

    def __init__(
            self, name, doc,
            default_member='imm',
            rust_type=None,
            values=None):
        # type: (str, str, str, str, Dict[str, str]) -> None
        if rust_type is None:
            # Immediates live in `ir::immediates` rather than `ir` directly.
            rust_type = 'ir::immediates::' + camel_case(name)
        super(ImmediateKind, self).__init__(
            name, doc, default_member, rust_type)
        # For enumerated immediates: maps DSL-side enumerator names to Rust
        # enumerator names. `None` for plain numeric immediates.
        self.values = values

    def __repr__(self):
        # type: () -> str
        return 'ImmediateKind({})'.format(self.name)

    def __getattr__(self, value):
        # type: (str) -> Enumerator
        """
        Enumerated immediate kinds allow the use of dot syntax to produce
        `Enumerator` AST nodes: `icmp.i32(intcc.ult, a, b)`.
        """
        # Deferred import: .ast imports from this module, so a top-level
        # import would create a cycle.
        from .ast import Enumerator  # noqa
        # NOTE(review): unknown attributes raise AssertionError rather than
        # AttributeError; this also fires for implicit dunder probes (e.g.
        # copy/pickle machinery) on non-enumerated kinds -- confirm intended.
        if not self.values:
            raise AssertionError(
                '{n} is not an enumerated operand kind: {n}.{a}'.format(
                    n=self.name, a=value))
        if value not in self.values:
            raise AssertionError(
                'No such {n} enumerator: {n}.{a}'.format(
                    n=self.name, a=value))
        return Enumerator(self, value)

    def __call__(self, value):
        # type: (int) -> ConstantInt
        """
        Create an AST node representing a constant integer:

            iconst(imm64(0))

        Only valid for non-enumerated immediate kinds.
        """
        from .ast import ConstantInt  # noqa
        if self.values:
            raise AssertionError(
                "{}({}): Can't make a constant numeric value for an enum"
                .format(self.name, value))
        return ConstantInt(self, value)

    def bits(self, bits):
        # type: (int) -> ConstantBits
        """
        Create an AST literal node for the given bitwise representation of this
        immediate operand kind.
        """
        from .ast import ConstantBits  # noqa
        return ConstantBits(self, bits)

    def rust_enumerator(self, value):
        # type: (str) -> str
        """
        Get the qualified Rust name of the enumerator value `value`.
        """
        return '{}::{}'.format(self.rust_type, self.values[value])

    def is_enumerable(self):
        # type: () -> bool
        # True when this kind was constructed with an enumerator table.
        return self.values is not None

    def possible_values(self):
        # type: () -> Iterable[Literal]
        """Yield an `Enumerator` literal for each possible enumerator value."""
        from cdsl.ast import Enumerator  # noqa
        assert self.is_enumerable()
        # Iteration order follows the `values` dict's key order.
        for v in self.values.keys():
            yield Enumerator(self, v)
|
||||
|
||||
|
||||
# Instances of entity reference operand types are provided in the
|
||||
# `cranelift.entities` module.
|
||||
class EntityRefKind(OperandKind):
    """
    The kind of an entity reference instruction operand.

    Entity references point at something declared elsewhere in the function,
    such as a preamble entry. The member name defaults to the kind name.
    """

    def __init__(self, name, doc, default_member=None, rust_type=None):
        # type: (str, str, str, str) -> None
        if not default_member:
            # Fall back to the kind name as the member name.
            default_member = name
        super(EntityRefKind, self).__init__(name, doc, default_member,
                                            rust_type)

    def __repr__(self):
        # type: () -> str
        return 'EntityRefKind({})'.format(self.name)
|
||||
|
||||
|
||||
class Operand(object):
    """
    A single instruction operand.

    An operand is an *immediate*, an *SSA value*, or an *entity reference*;
    which one is determined by the type of the `typ` spec:

    1. A :py:class:`ValueType` instance indicates an SSA value operand with a
       concrete type.

    2. A :py:class:`TypeVar` instance indicates an SSA value operand, and the
       instruction is polymorphic over the possible concrete types that the
       type variable can assume.

    3. An :py:class:`ImmediateKind` instance indicates an immediate operand
       whose value is encoded in the instruction itself rather than being
       passed as an SSA value.

    4. An :py:class:`EntityRefKind` instance indicates an operand that
       references another entity in the function, typically something declared
       in the function preamble.
    """
    def __init__(self, name, typ, doc=''):
        # type: (str, OperandSpec, str) -> None
        self.name = name
        self.__doc__ = doc

        # Decode the operand spec and set self.kind.
        # Only VALUE operands carry a `typevar` attribute.
        if isinstance(typ, ValueType):
            self.kind = VALUE
            # A concrete value type becomes a singleton type variable.
            self.typevar = TypeVar.singleton(typ)
        elif isinstance(typ, TypeVar):
            self.kind = VALUE
            self.typevar = typ
        else:
            assert isinstance(typ, OperandKind)
            self.kind = typ

    def get_doc(self):
        # type: () -> str
        """Return this operand's doc, falling back to its type's doc."""
        if self.__doc__:
            return self.__doc__
        fallback = self.typevar if self.kind is VALUE else self.kind
        return fallback.__doc__

    def __str__(self):
        # type: () -> str
        return "`{}`".format(self.name)

    def is_value(self):
        # type: () -> bool
        """Is this an SSA value operand?"""
        return self.kind is VALUE

    def is_varargs(self):
        # type: () -> bool
        """Is this a VARIABLE_ARGS operand?"""
        return self.kind is VARIABLE_ARGS

    def is_immediate(self):
        # type: () -> bool
        """
        Is this an immediate operand?

        Note that this includes both `ImmediateKind` operands *and* entity
        references. It is any operand that doesn't represent a value
        dependency.
        """
        return not (self.kind is VALUE or self.kind is VARIABLE_ARGS)

    def is_cpu_flags(self):
        # type: () -> bool
        """Is this a CPU flags operand?"""
        return self.kind is VALUE and self.typevar.name in ('iflags', 'fflags')
|
||||
447
lib/codegen/meta-python/cdsl/predicates.py
Normal file
447
lib/codegen/meta-python/cdsl/predicates.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
Cranelift predicates.
|
||||
|
||||
A *predicate* is a function that computes a boolean result. The inputs to the
|
||||
function determine the kind of predicate:
|
||||
|
||||
- An *ISA predicate* is evaluated on the current ISA settings together with the
|
||||
shared settings defined in the :py:mod:`settings` module. Once a target ISA
|
||||
has been configured, the value of all ISA predicates is known.
|
||||
|
||||
- An *Instruction predicate* is evaluated on an instruction instance, so it can
|
||||
inspect all the immediate fields and type variables of the instruction.
|
||||
Instruction predicates can be evaluated before register allocation, so they
|
||||
can not depend on specific register assignments to the value operands or
|
||||
outputs.
|
||||
|
||||
Predicates can also be computed from other predicates using the `And`, `Or`,
|
||||
and `Not` combinators defined in this module.
|
||||
|
||||
All predicates have a *context* which determines where they can be evaluated.
|
||||
For an ISA predicate, the context is the ISA settings group. For an instruction
|
||||
predicate, the context is the instruction format.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from functools import reduce
|
||||
from .formats import instruction_context
|
||||
|
||||
try:
|
||||
from typing import Sequence, Tuple, Set, Any, Union, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .formats import InstructionFormat, InstructionContext, FormatField # noqa
|
||||
from .instructions import Instruction # noqa
|
||||
from .settings import BoolSetting, SettingGroup # noqa
|
||||
from .types import ValueType # noqa
|
||||
from .typevar import TypeVar # noqa
|
||||
PredContext = Union[SettingGroup, InstructionFormat,
|
||||
InstructionContext]
|
||||
PredLeaf = Union[BoolSetting, 'FieldPredicate', 'TypePredicate',
|
||||
'CtrlTypePredicate']
|
||||
PredNode = Union[PredLeaf, 'Predicate']
|
||||
# A predicate key is a (recursive) tuple of primitive types that
|
||||
# uniquely describes a predicate. It is used for interning.
|
||||
PredKey = Tuple[Any, ...]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _is_parent(a, b):
|
||||
# type: (PredContext, PredContext) -> bool
|
||||
"""
|
||||
Return true if a is a parent of b, or equal to it.
|
||||
"""
|
||||
while b and a is not b:
|
||||
b = getattr(b, 'parent', None)
|
||||
return a is b
|
||||
|
||||
|
||||
def _descendant(a, b):
|
||||
# type: (PredContext, PredContext) -> PredContext
|
||||
"""
|
||||
If a is a parent of b or b is a parent of a, return the descendant of the
|
||||
two.
|
||||
|
||||
If neither is a parent of the other, return None.
|
||||
"""
|
||||
if _is_parent(a, b):
|
||||
return b
|
||||
if _is_parent(b, a):
|
||||
return a
|
||||
return None
|
||||
|
||||
|
||||
class Predicate(object):
    """
    Superclass for all computed predicates.

    Leaf predicates can have other types, such as `Setting`.

    :param parts: Tuple of components in the predicate expression.
    """

    def __init__(self, parts):
        # type: (Sequence[PredNode]) -> None
        self.parts = parts
        # The context of a combined predicate is the most specific
        # (descendant) of its parts' contexts. Unrelated contexts reduce to
        # None and trip the assertion below.
        self.context = reduce(
            _descendant,
            (p.predicate_context() for p in parts))
        assert self.context, "Incompatible predicate parts"
        # Interning key; computed lazily by predicate_key().
        self.predkey = None  # type: PredKey

    def __str__(self):
        # type: () -> str
        return '{}({})'.format(type(self).__name__,
                               ', '.join(map(str, self.parts)))

    def predicate_context(self):
        # type: () -> PredContext
        return self.context

    def predicate_leafs(self, leafs):
        # type: (Set[PredLeaf]) -> None
        """
        Collect all leaf predicates into the `leafs` set.
        """
        # Composite predicates recurse; leaf types add themselves.
        for part in self.parts:
            part.predicate_leafs(leafs)

    def rust_predicate(self, prec):
        # type: (int) -> str
        raise NotImplementedError("rust_predicate is an abstract method")

    def predicate_key(self):
        # type: () -> PredKey
        """Tuple uniquely identifying a predicate."""
        # Cached on first use; used for interning predicates.
        if not self.predkey:
            p = tuple(p.predicate_key() for p in self.parts)  # type: PredKey
            self.predkey = (type(self).__name__,) + p
        return self.predkey
|
||||
|
||||
|
||||
class And(Predicate):
    """
    Computed predicate that is true if all parts are true.
    """

    precedence = 2

    def __init__(self, *args):
        # type: (*PredNode) -> None
        super(And, self).__init__(args)

    def rust_predicate(self, prec):
        # type: (int) -> str
        """
        Return a Rust expression computing the value of this predicate.

        The surrounding precedence determines whether parentheses are needed:

        0. An `if` statement.
        1. An `||` expression.
        2. An `&&` expression.
        3. A `!` expression.
        """
        expr = ' && '.join(
                p.rust_predicate(And.precedence) for p in self.parts)
        return '({})'.format(expr) if prec > And.precedence else expr

    @staticmethod
    def combine(*args):
        # type: (*PredNode) -> PredNode
        """
        Combine a sequence of predicates, allowing for `None` members.

        Return a predicate that is true when all non-`None` arguments are
        true, or `None` if all of the arguments are `None`.
        """
        live = [p for p in args if p]
        if not live:
            return None
        if len(live) == 1:
            return live[0]
        # Multiple predicates remain: conjoin them.
        return And(*live)
|
||||
|
||||
|
||||
class Or(Predicate):
    """
    Computed predicate that is true if any parts are true.
    """

    precedence = 1

    def __init__(self, *args):
        # type: (*PredNode) -> None
        super(Or, self).__init__(args)

    def rust_predicate(self, prec):
        # type: (int) -> str
        """Join the parts with `||`, parenthesizing when `prec` requires."""
        expr = ' || '.join(
                p.rust_predicate(Or.precedence) for p in self.parts)
        return '({})'.format(expr) if prec > Or.precedence else expr
|
||||
|
||||
|
||||
class Not(Predicate):
    """
    Computed predicate that is true if its single part is false.
    """

    precedence = 3

    def __init__(self, part):
        # type: (PredNode) -> None
        super(Not, self).__init__((part,))

    def rust_predicate(self, prec):
        # type: (int) -> str
        """Negate the single part; `!` binds tightest, so no parens needed."""
        inner = self.parts[0]
        return '!{}'.format(inner.rust_predicate(Not.precedence))
|
||||
|
||||
|
||||
class FieldPredicate(object):
    """
    An instruction predicate that performs a test on a single `FormatField`.

    :param field: The `FormatField` to be tested.
    :param function: Boolean predicate function to call.
    :param args: Additional arguments for the predicate function.
    """

    def __init__(self, field, function, args):
        # type: (FormatField, str, Sequence[Any]) -> None
        self.field = field
        self.function = function
        self.args = args

    def _arg_strings(self):
        # type: () -> Tuple[str, ...]
        # The field expression always comes first, then the extra arguments.
        return (self.field.rust_name(),) + tuple(str(a) for a in self.args)

    def __str__(self):
        # type: () -> str
        return '{}({})'.format(self.function, ', '.join(self._arg_strings()))

    def predicate_context(self):
        # type: () -> PredContext
        """
        This predicate can be evaluated in the context of an instruction
        format.
        """
        iform = self.field.format  # type: InstructionFormat
        return iform

    def predicate_key(self):
        # type: () -> PredKey
        extra = tuple(str(a) for a in self.args)
        return (self.function, str(self.field)) + extra

    def predicate_leafs(self, leafs):
        # type: (Set[PredLeaf]) -> None
        leafs.add(self)

    def rust_predicate(self, prec):
        # type: (int) -> str
        """
        Return a string of Rust code that evaluates this predicate.
        """
        return 'predicates::{}({})'.format(
                self.function, ', '.join(self._arg_strings()))
|
||||
|
||||
|
||||
class IsEqual(FieldPredicate):
    """
    Instruction predicate that checks if an immediate instruction format field
    is equal to a constant value.

    :param field: `FormatField` to be checked.
    :param value: The constant value to compare against.
    """

    def __init__(self, field, value):
        # type: (FormatField, Any) -> None
        self.value = value
        super(IsEqual, self).__init__(field, 'is_equal', (value,))
|
||||
|
||||
|
||||
class IsZero32BitFloat(FieldPredicate):
    """
    Instruction predicate that checks if an immediate 32-bit float field is
    equal to zero.

    :param field: `FormatField` to be checked.
    """

    def __init__(self, field):
        # type: (FormatField) -> None
        super(IsZero32BitFloat, self).__init__(
                field, 'is_zero_32_bit_float', ())
|
||||
|
||||
|
||||
class IsZero64BitFloat(FieldPredicate):
    """
    Instruction predicate that checks if an immediate 64-bit float field is
    equal to zero.

    :param field: `FormatField` to be checked.
    """

    def __init__(self, field):
        # type: (FormatField) -> None
        super(IsZero64BitFloat, self).__init__(
                field, 'is_zero_64_bit_float', ())
|
||||
|
||||
|
||||
class IsSignedInt(FieldPredicate):
    """
    Instruction predicate that checks if an immediate instruction format field
    is representable as an n-bit two's complement integer.

    The predicate is true if the field is in the range
    `-2^(width-1) -- 2^(width-1)-1` and a multiple of `2^scale`.

    :param field: `FormatField` to be checked.
    :param width: Number of bits in the allowed range.
    :param scale: Number of low bits that must be 0.
    """

    def __init__(self, field, width, scale=0):
        # type: (FormatField, int, int) -> None
        super(IsSignedInt, self).__init__(
                field, 'is_signed_int', (width, scale))
        self.width = width
        self.scale = scale
        assert 0 <= width <= 64
        assert 0 <= scale < width
|
||||
|
||||
|
||||
class IsUnsignedInt(FieldPredicate):
    """
    Instruction predicate that checks if an immediate instruction format field
    is representable as an n-bit unsigned integer.

    The predicate is true if the field is in the range `0 -- 2^width - 1` and
    a multiple of `2^scale`.

    :param field: `FormatField` to be checked.
    :param width: Number of bits in the allowed range.
    :param scale: Number of low bits that must be 0.
    """

    def __init__(self, field, width, scale=0):
        # type: (FormatField, int, int) -> None
        super(IsUnsignedInt, self).__init__(
                field, 'is_unsigned_int', (width, scale))
        self.width = width
        self.scale = scale
        assert 0 <= width <= 64
        assert 0 <= scale < width
|
||||
|
||||
|
||||
class TypePredicate(object):
    """
    An instruction predicate that checks the type of an SSA argument value.

    Type predicates are used to implement encodings for instructions with
    multiple type variables. The encoding tables are keyed by the controlling
    type variable; type predicates check any secondary type variables.

    A type predicate is not bound to any specific instruction format.

    :param value_arg: Index of the value argument to type check.
    :param value_type: The required value type.
    """

    def __init__(self, value_arg, value_type):
        # type: (int, ValueType) -> None
        assert value_arg >= 0
        assert value_type is not None
        self.value_arg = value_arg
        self.value_type = value_type

    def __str__(self):
        # type: () -> str
        return 'args[{}]:{}'.format(self.value_arg, self.value_type)

    def predicate_context(self):
        # type: () -> PredContext
        # Not tied to a format: evaluable in any instruction context.
        return instruction_context

    def predicate_key(self):
        # type: () -> PredKey
        return ('typecheck', self.value_arg, self.value_type.name)

    def predicate_leafs(self, leafs):
        # type: (Set[PredLeaf]) -> None
        leafs.add(self)

    @staticmethod
    def typevar_check(inst, typevar, value_type):
        # type: (Instruction, TypeVar, ValueType) -> TypePredicate
        """
        Return a type check predicate for the given type variable in `inst`.

        The type variable must appear directly as the type of one of the
        operands to `inst`, so this is only guaranteed to work for secondary
        type variables.

        Find an `inst` value operand whose type is determined by `typevar` and
        create a `TypePredicate` that checks that the type variable has the
        value `value_type`.
        """
        # Locate the first value operand typed by `typevar`.
        matches = (i for i, opnum in enumerate(inst.value_opnums)
                   if inst.ins[opnum].typevar == typevar)
        return TypePredicate(next(matches), value_type)

    def rust_predicate(self, prec):
        # type: (int) -> str
        """
        Return Rust code for evaluating this predicate.

        It is assumed that the context has `func` and `args` variables.
        """
        return 'func.dfg.value_type(args[{}]) == {}'.format(
                self.value_arg, self.value_type.rust_name())
|
||||
|
||||
|
||||
class CtrlTypePredicate(object):
    """
    An instruction predicate that checks the controlling type variable.

    :param value_type: The required value type.
    """

    def __init__(self, value_type):
        # type: (ValueType) -> None
        assert value_type is not None
        self.value_type = value_type

    def __str__(self):
        # type: () -> str
        return 'ctrl_typevar:{}'.format(self.value_type)

    def predicate_context(self):
        # type: () -> PredContext
        # Not tied to a format: evaluable in any instruction context.
        return instruction_context

    def predicate_key(self):
        # type: () -> PredKey
        return ('ctrltypecheck', self.value_type.name)

    def predicate_leafs(self, leafs):
        # type: (Set[PredLeaf]) -> None
        leafs.add(self)

    def rust_predicate(self, prec):
        # type: (int) -> str
        """
        Return Rust code for evaluating this predicate.

        It is assumed that the context has `func` and `inst` variables.
        """
        return 'func.dfg.ctrl_typevar(inst) == {}'.format(
                self.value_type.rust_name())
|
||||
414
lib/codegen/meta-python/cdsl/registers.py
Normal file
414
lib/codegen/meta-python/cdsl/registers.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
Register set definitions
|
||||
------------------------
|
||||
|
||||
Each ISA defines a separate register set that is used by the register allocator
|
||||
and the final binary encoding of machine code.
|
||||
|
||||
The CPU registers are first divided into disjoint register banks, represented
|
||||
by a `RegBank` instance. Registers in different register banks never interfere
|
||||
with each other. A typical CPU will have a general purpose and a floating point
|
||||
register bank.
|
||||
|
||||
A register bank consists of a number of *register units* which are the smallest
|
||||
indivisible units of allocation and interference. A register unit doesn't
|
||||
necessarily correspond to a particular number of bits in a register, it is more
|
||||
like a placeholder that can be used to determine of a register is taken or not.
|
||||
|
||||
The register allocator works with *register classes* which can allocate one or
|
||||
more register units at a time. A register class allocates more than one
|
||||
register unit at a time when its registers are composed of smaller allocatable
|
||||
units. For example, the ARM double precision floating point registers are
|
||||
composed of two single precision registers.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from . import is_power_of_two, next_power_of_two
|
||||
|
||||
|
||||
try:
|
||||
from typing import Sequence, Tuple, List, Dict, Any, Optional, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .isa import TargetISA # noqa
|
||||
# A tuple uniquely identifying a register class inside a register bank.
|
||||
# (width, bitmask)
|
||||
RCTup = Tuple[int, int]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# The number of 32-bit elements in a register unit mask
MASK_LEN = 3

# The maximum total number of register units allowed.
# This limit can be raised by also adjusting the RegUnitMask type in
# src/isa/registers.rs.
# With MASK_LEN == 3 this is 96 units.
MAX_UNITS = MASK_LEN * 32
|
||||
|
||||
|
||||
class RegBank(object):
    """
    A register bank belonging to an ISA.

    A register bank controls a set of *register units* disjoint from all the
    other register banks in the ISA. The register units are numbered uniquely
    within the target ISA, and the units in a register bank form a contiguous
    sequence starting from a sufficiently aligned point that their low bits can
    be used directly when encoding machine code instructions.

    Register units can be given generated names like `r0`, `r1`, ..., or a
    tuple of special register unit names can be provided.

    :param name: Name of this register bank.
    :param doc: Documentation string.
    :param units: Number of register units.
    :param pressure_tracking: Enable tracking of register pressure.
    :param prefix: Prefix for generated unit names.
    :param names: Special names for the first units. May be shorter than
                  `units`, the remaining units are named using `prefix`.
    """

    def __init__(
            self,
            name,                    # type: str
            isa,                     # type: TargetISA
            doc,                     # type: str
            units,                   # type: int
            pressure_tracking=True,  # type: bool
            prefix='r',              # type: str
            names=()                 # type: Sequence[str]
            ):
        # type: (...) -> None
        self.name = name
        self.isa = isa
        # First unit number in the ISA-wide numbering; recomputed below when
        # earlier banks exist.
        self.first_unit = 0
        self.units = units
        self.pressure_tracking = pressure_tracking
        self.prefix = prefix
        self.names = names
        # Populated as RegClass instances are created; validated and
        # completed by finish_regclasses().
        self.classes = list()  # type: List[RegClass]
        self.toprcs = list()  # type: List[RegClass]
        self.first_toprc_index = None  # type: int

        assert len(names) <= units

        if isa.regbanks:
            # Get the next free unit number.
            last = isa.regbanks[-1]
            u = last.first_unit + last.units
            # Align first_unit to a power of two >= units so the low bits of
            # a unit number identify the register within the bank.
            align = units
            if not is_power_of_two(align):
                align = next_power_of_two(align)
            self.first_unit = (u + align - 1) & -align

        self.index = len(isa.regbanks)
        isa.regbanks.append(self)

    def __repr__(self):
        # type: () -> str
        return ('RegBank({}, units={}, first_unit={})'
                .format(self.name, self.units, self.first_unit))

    def finish_regclasses(self):
        # type: () -> None
        """
        Compute subclasses and the top-level register class.

        Verify that the set of register classes satisfies:

        1. Closed under intersection: The intersection of any two register
           classes in the set is either empty or identical to a member of the
           set.
        2. There are no identical classes under different names.
        3. Classes are sorted topologically such that all subclasses have a
           higher index that the superclass.

        We could reorder classes topologically here instead of just enforcing
        the order, but the ordering tends to fall out naturally anyway.
        """
        # Map from rctup() identity key to the class owning it.
        cmap = dict()  # type: Dict[RCTup, RegClass]

        for rc in self.classes:
            # All register classes must be given a name.
            assert rc.name, "Anonymous register class found"

            # Check for duplicates.
            tup = rc.rctup()
            if tup in cmap:
                raise AssertionError(
                        '{} and {} are identical register classes'
                        .format(rc, cmap[tup]))
            cmap[tup] = rc

        # Check intersections and topological order.
        # Each class starts as its own top-level class; it inherits the
        # toprc of any proper superclass found below.
        for idx, rc1 in enumerate(self.classes):
            rc1.toprc = rc1
            for rc2 in self.classes[0:idx]:
                itup = rc1.intersect(rc2)
                if itup is None:
                    continue
                if itup not in cmap:
                    raise AssertionError(
                            'intersection of {} and {} missing'
                            .format(rc1, rc2))
                irc = cmap[itup]
                # rc1 > rc2, so rc2 can't be the sub-class.
                if irc is rc2:
                    raise AssertionError(
                            'Bad topological order: {}/{}'
                            .format(rc1, rc2))
                if irc is rc1:
                    # The intersection of rc1 and rc2 is rc1, so it must be a
                    # sub-class.
                    rc2.subclasses.append(rc1)
                    rc1.toprc = rc2.toprc

            if rc1.is_toprc():
                self.toprcs.append(rc1)

    def unit_by_name(self, name):
        # type: (str) -> int
        """
        Get a register unit in this bank by name.

        Accepts either a special name from `self.names` or a generated
        `prefix` + number name.
        """
        if name in self.names:
            r = self.names.index(name)
        elif name.startswith(self.prefix):
            r = int(name[len(self.prefix):])
        # NOTE(review): if `name` is neither a special name nor prefixed,
        # `r` is unbound here and this raises NameError instead of the
        # intended assertion message -- consider an explicit `else` branch.
        assert r < self.units, 'Invalid register name: ' + name
        return self.first_unit + r
|
||||
|
||||
|
||||
class RegClass(object):
    """
    A register class is a subset of register units in a RegBank along with a
    strategy for allocating registers.

    The *width* parameter determines how many register units are allocated at
    a time. Usually that is one, but for example the ARM D registers are
    allocated two units at a time. When multiple units are allocated, it is
    always a contiguous set of unit numbers.

    :param bank: The register bank we're allocating from.
    :param count: The maximum number of allocations in this register class. By
                  default, the whole register bank can be allocated.
    :param width: How many units to allocate at a time.
    :param start: The first unit to allocate, relative to `bank.first.unit`.
    :param bitmask: Explicit bit-mask of allocatable units; when given it
                    overrides `count` and `start`.
    """

    def __init__(self, bank, count=0, width=1, start=0, bitmask=None):
        # type: (RegBank, int, int, int, Optional[int]) -> None
        self.name = None  # type: str
        self.index = None  # type: int
        self.bank = bank
        self.width = width
        self.bitmask = 0

        # This is computed later in `finish_regclasses()`.
        self.subclasses = list()  # type: List[RegClass]
        self.toprc = None  # type: RegClass

        assert width > 0

        if bitmask:
            self.bitmask = bitmask
        else:
            assert start >= 0 and start < bank.units
            if count == 0:
                count = bank.units // width
            # One bit per allocation, at the first unit of each allocation.
            for a in range(count):
                u = start + a * self.width
                self.bitmask |= 1 << u

        bank.classes.append(self)

    def __str__(self):
        # type: () -> str
        return self.name

    def is_toprc(self):
        # type: () -> bool
        """
        Is this a top-level register class?

        A top-level register class is not a strict sub-class of any other
        class. This can only be answered after running
        `finish_regclasses()`.
        """
        return self.toprc is self

    def rctup(self):
        # type: () -> RCTup
        """
        Get a tuple that uniquely identifies the registers in this class.

        The tuple can be used as a dictionary key to ensure that there are no
        duplicate register classes.
        """
        return (self.width, self.bitmask)

    def intersect(self, other):
        # type: (RegClass) -> Optional[RCTup]
        """
        Get a tuple representing the intersection of two register classes.

        Returns `None` if the two classes are disjoint.
        """
        if self.width != other.width:
            return None
        intersection = self.bitmask & other.bitmask
        if intersection == 0:
            return None

        return (self.width, intersection)

    def __getitem__(self, sliced):
        # type: (slice) -> RegClass
        """
        Create a sub-class of a register class using slice notation. The slice
        indexes refer to allocations in the parent register class, not register
        units.
        """
        assert isinstance(sliced, slice), "RegClass slicing can't be 1 reg"
        # We could add strided sub-classes if needed.
        assert sliced.step is None, 'Subclass striding not supported'
        # Can't slice a non-contiguous class
        assert self.is_contiguous(), 'Cannot slice non-contiguous RegClass'

        w = self.width
        s = self.start() + sliced.start * w
        c = sliced.stop - sliced.start
        assert c > 1, "Can't have single-register classes"

        return RegClass(self.bank, count=c, width=w, start=s)

    def without(self, *registers):
        # type: (*Register) -> RegClass
        """
        Create a sub-class of a register class excluding a specific set of
        registers.

        For example: GPR.without(GPR.r9)
        """
        bm = self.bitmask
        w = self.width
        fmask = (1 << self.width) - 1
        for reg in registers:
            # NOTE(review): the shift amount is `reg.unit * w`, but `reg.unit`
            # already appears to be an absolute unit number elsewhere in this
            # file — verify the indexing for width > 1 classes.
            bm &= ~(fmask << (reg.unit * w))

        return RegClass(self.bank, bitmask=bm)

    def is_contiguous(self):
        # type: () -> bool
        """
        Returns boolean indicating whether a register class is a contiguous set
        of register units.
        """
        # `x` fills in all zero bits below the highest set bit; the mask is
        # contiguous iff `x + 1` is a power of two (shares no bits with `x`).
        x = self.bitmask | (self.bitmask-1)
        return self.bitmask != 0 and ((x+1) & x) == 0

    def start(self):
        # type: () -> int
        """
        Returns the first valid register unit in this class.
        """
        start = 0
        bm = self.bitmask
        fmask = (1 << self.width) - 1
        # Scan one allocation (width units) at a time until a set bit appears.
        while True:
            if bm & fmask > 0:
                break
            start += 1
            bm >>= self.width

        return start

    def __getattr__(self, attr):
        # type: (str) -> Register
        """
        Get a specific register in the class by name.

        For example: `GPR.r5`.
        """
        reg = Register(self, self.bank.unit_by_name(attr))
        # Save this register so we won't have to create it again.
        setattr(self, attr, reg)
        return reg

    def mask(self):
        # type: () -> List[int]
        """
        Compute a bit-mask of the register units allocated by this register
        class.

        Return as a list of 32-bit integers.
        """
        out_mask = []
        mask32 = (1 << 32) - 1
        # Shift into the bank's absolute unit-number space before splitting.
        bitmask = self.bitmask << self.bank.first_unit
        # MASK_LEN is a module-level constant defined elsewhere in this file.
        for i in range(MASK_LEN):
            out_mask.append((bitmask >> (i * 32)) & mask32)

        return out_mask

    def subclass_mask(self):
        # type: () -> int
        """
        Compute a bit-mask of subclasses, including self.
        """
        m = 1 << self.index
        for rc in self.subclasses:
            m |= 1 << rc.index
        return m

    @staticmethod
    def extract_names(globs):
        # type: (Dict[str, Any]) -> None
        """
        Given a dict mapping name -> object as returned by `globals()`, find
        all the RegClass objects and set their name from the dict key.
        This is used to name a bunch of global values in a module.
        """
        for name, obj in globs.items():
            if isinstance(obj, RegClass):
                assert obj.name is None
                obj.name = name
|
||||
|
||||
|
||||
class Register(object):
    """
    A specific register in a register class.

    A register is identified by the top-level register class it belongs to
    and its first register unit.

    Specific registers are used to describe constraints on instructions where
    some operands must use a fixed register.

    Register instances can be created with the constructor, or accessed as
    attributes on the register class: `GPR.rcx`.
    """
    def __init__(self, rc, unit):
        # type: (RegClass, int) -> None
        # First register unit occupied by this register.
        self.unit = unit
        # The register class this register belongs to.
        self.regclass = rc
|
||||
|
||||
|
||||
class Stack(object):
    """
    An operand that must be in a stack slot.

    A `Stack` object can be used to indicate an operand constraint for a
    value operand that must live in a stack slot.
    """
    def __init__(self, rc):
        # type: (RegClass) -> None
        # Register class describing the constrained value.
        self.regclass = rc

    def stack_base_mask(self):
        # type: () -> str
        """
        Get the StackBaseMask to use for this operand.

        This is a mask of base registers that can be supported by this
        operand.
        """
        # TODO: Make this configurable instead of just using the SP.
        return 'StackBaseMask(1)'
|
||||
407
lib/codegen/meta-python/cdsl/settings.py
Normal file
407
lib/codegen/meta-python/cdsl/settings.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""Classes for describing settings and groups of settings."""
|
||||
from __future__ import absolute_import
|
||||
from collections import OrderedDict
|
||||
from .predicates import Predicate
|
||||
|
||||
try:
|
||||
from typing import Tuple, Set, List, Dict, Any, Union, TYPE_CHECKING # noqa
|
||||
BoolOrPresetOrDict = Union['BoolSetting', 'Preset', Dict['Setting', Any]]
|
||||
if TYPE_CHECKING:
|
||||
from .predicates import PredLeaf, PredNode, PredKey # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class Setting(object):
    """
    A named setting variable that can be configured externally to Cranelift.

    Settings are anonymous at construction time; their `name` attribute is
    filled in later by the `extract_names` method.
    """

    def __init__(self, doc):
        # type: (str) -> None
        # Assigned later by `extract_names()`.
        self.name = None  # type: str
        self.__doc__ = doc
        # Offset of the byte holding this setting in the settings vector.
        self.byte_offset = None  # type: int
        # Index into the generated DESCRIPTORS table.
        self.descriptor_index = None  # type: int
        # Register ourselves with the currently open setting group.
        self.group = SettingGroup.append(self)

    def __str__(self):
        # type: () -> str
        return '{}.{}'.format(self.group.name, self.name)

    def default_byte(self):
        # type: () -> int
        # Subclasses define how the default value is encoded as a byte.
        raise NotImplementedError("default_byte is an abstract method")

    def byte_for_value(self, value):
        # type: (Any) -> int
        """Get the setting byte value that corresponds to `value`"""
        raise NotImplementedError("byte_for_value is an abstract method")

    def byte_mask(self):
        # type: () -> int
        """Get a mask of bits in our byte that are relevant to this setting."""
        # Only BoolSetting has a different mask.
        return 0xff
|
||||
|
||||
|
||||
class BoolSetting(Setting):
    """
    A named setting with a boolean on/off value.

    :param doc: Documentation string.
    :param default: The default value of this setting.
    """

    def __init__(self, doc, default=False):
        # type: (str, bool) -> None
        super(BoolSetting, self).__init__(doc)
        self.default = default
        # Bit position within our byte; assigned by `SettingGroup.layout()`.
        self.bit_offset = None  # type: int

    def default_byte(self):
        # type: () -> int
        """
        Get the default value of this setting, as a byte that can be bitwise
        or'ed with the other booleans sharing the same byte.
        """
        return (1 << self.bit_offset) if self.default else 0

    def byte_for_value(self, value):
        # type: (Any) -> int
        # Truthy values set our single bit; falsy values clear it.
        return (1 << self.bit_offset) if value else 0

    def byte_mask(self):
        # type: () -> int
        # Unlike other settings, only one bit of the byte belongs to us.
        return 1 << self.bit_offset

    def predicate_context(self):
        # type: () -> SettingGroup
        """
        Return the context where this setting can be evaluated as a (leaf)
        predicate.
        """
        return self.group

    def predicate_key(self):
        # type: () -> PredKey
        assert self.name, "Can't compute key before setting is named"
        return ('setting', self.group.name, self.name)

    def predicate_leafs(self, leafs):
        # type: (Set[PredLeaf]) -> None
        # A boolean setting is itself a leaf predicate.
        leafs.add(self)

    def rust_predicate(self, prec):
        # type: (int) -> str
        """
        Return the Rust code to compute the value of this setting.

        The emitted code assumes that the setting group exists as a local
        variable.
        """
        return '{}.{}()'.format(self.group.name, self.name)
|
||||
|
||||
|
||||
class NumSetting(Setting):
    """
    A named setting with an integral value in the range 0--255.

    :param doc: Documentation string.
    :param default: The default value of this setting.
    """

    def __init__(self, doc, default=0):
        # type: (str, int) -> None
        super(NumSetting, self).__init__(doc)
        # The default must be an integer that fits in a single byte.
        assert default == int(default)
        assert 0 <= default <= 255
        self.default = default

    def default_byte(self):
        # type: () -> int
        # A numeric setting is stored directly as its byte value.
        return self.default

    def byte_for_value(self, value):
        # type: (Any) -> int
        assert isinstance(value, int), "NumSetting must be set to an int"
        assert 0 <= value <= 255
        return value
|
||||
|
||||
|
||||
class EnumSetting(Setting):
    """
    A named setting with an enumerated set of possible values.

    The default value is always the first enumerator.

    :param doc: Documentation string.
    :param args: Tuple of unique strings representing the possible values.
    """

    def __init__(self, doc, *args):
        # type: (str, *str) -> None
        super(EnumSetting, self).__init__(doc)
        assert len(args) > 0, "EnumSetting must have at least one value"
        self.values = tuple(map(str, args))
        self.default = self.values[0]

    def default_byte(self):
        # type: () -> int
        # The default is the first enumerator, which encodes as index 0.
        return 0

    def byte_for_value(self, value):
        # type: (Any) -> int
        # Encode a value as its position in the enumerator tuple.
        return self.values.index(value)
|
||||
|
||||
|
||||
class SettingGroup(object):
    """
    A group of settings.

    Whenever a :class:`Setting` object is created, it is added to the currently
    open group. A setting group must be closed explicitly before another can be
    opened.

    :param name: Short mnemonic name for setting group.
    :param parent: Parent settings group.
    """

    # The currently open setting group.
    _current = None  # type: SettingGroup

    def __init__(self, name, parent=None):
        # type: (str, SettingGroup) -> None
        self.name = name
        self.parent = parent
        self.settings = []  # type: List[Setting]
        # Named predicates computed from settings in this group or its
        # parents.
        self.named_predicates = OrderedDict()  # type: OrderedDict[str, Predicate] # noqa
        # All boolean predicates that can be accessed by number. This includes:
        # - All boolean settings in this group.
        # - All named predicates.
        # - Added anonymous predicates, see `number_predicate()`.
        # - Added parent predicates that are replicated in this group.
        # Maps predicate -> number.
        self.predicate_number = OrderedDict()  # type: OrderedDict[PredNode, int] # noqa
        self.presets = []  # type: List[Preset]

        # Fully qualified Rust module name. See gen_settings.py.
        self.qual_mod = None  # type: str

        self.open()

    def open(self):
        # type: () -> None
        """
        Open this setting group such that future new settings are added to this
        group.
        """
        assert SettingGroup._current is None, (
            "Can't open {} since {} is already open"
            .format(self, SettingGroup._current))
        SettingGroup._current = self

    def close(self, globs=None):
        # type: (Dict[str, Any]) -> None
        """
        Close this setting group. This function must be called before opening
        another setting group.

        :param globs: Pass in `globals()` to run `extract_names` on all
            settings defined in the module.
        """
        assert SettingGroup._current is self, (
            "Can't close {}, the open setting group is {}"
            .format(self, SettingGroup._current))
        SettingGroup._current = None
        if globs:
            # Name settings, named predicates, and presets after the global
            # variables they were bound to in the defining module.
            for name, obj in globs.items():
                if isinstance(obj, Setting):
                    assert obj.name is None, obj.name
                    obj.name = name
                if isinstance(obj, Predicate):
                    self.named_predicates[name] = obj
                if isinstance(obj, Preset):
                    assert obj.name is None, obj.name
                    obj.name = name

        self.layout()

    @staticmethod
    def append(setting):
        # type: (Setting) -> SettingGroup
        # Called from `Setting.__init__` to register with the open group.
        g = SettingGroup._current
        assert g, "Open a setting group before defining settings."
        g.settings.append(setting)
        return g

    @staticmethod
    def append_preset(preset):
        # type: (Preset) -> SettingGroup
        # Called from `Preset.__init__` to register with the open group.
        g = SettingGroup._current
        assert g, "Open a setting group before defining presets."
        g.presets.append(preset)
        return g

    def number_predicate(self, pred):
        # type: (PredNode) -> int
        """
        Make sure that `pred` has an assigned number, and will be included in
        this group's bit vector.

        The numbered predicates include:
        - `BoolSetting` settings that belong to this group.
        - `Predicate` instances in `named_predicates`.
        - `Predicate` instances without a name.
        - Settings or computed predicates that belong to the parent group, but
          need to be accessible by number in this group.

        The numbered predicates are referenced by the encoding tables as ISA
        predicates. See the `isap` field on `Encoding`.

        :returns: The assigned predicate number in this group.
        """
        if pred in self.predicate_number:
            return self.predicate_number[pred]
        else:
            # Numbers are assigned densely in first-seen order.
            number = len(self.predicate_number)
            self.predicate_number[pred] = number
            return number

    def layout(self):
        # type: () -> None
        """
        Compute the layout of the byte vector used to represent this settings
        group.

        The byte vector contains the following entries in order:

        1. Byte-sized settings like `NumSetting` and `EnumSetting`.
        2. `BoolSetting` settings.
        3. Precomputed named predicates.
        4. Other numbered predicates, including anonymous predicates and parent
           predicates that need to be accessible by number.

        Set `self.settings_size` to the length of the byte vector prefix that
        contains the settings. All bytes after that are computed, not
        configured.

        Set `self.boolean_offset` to the beginning of the numbered predicates,
        2. in the list above.

        Assign `byte_offset` and `bit_offset` fields in all settings.

        After calling this method, no more settings can be added, but
        additional predicates can be made accessible with `number_predicate()`.
        """
        assert len(self.predicate_number) == 0, "Too late for layout"

        # Assign the non-boolean settings.
        byte_offset = 0
        for s in self.settings:
            if not isinstance(s, BoolSetting):
                s.byte_offset = byte_offset
                byte_offset += 1

        # Then the boolean settings.
        self.boolean_offset = byte_offset
        for s in self.settings:
            if isinstance(s, BoolSetting):
                # Booleans are packed eight to a byte, numbered via the
                # predicate numbering so each gets a stable bit position.
                number = self.number_predicate(s)
                s.byte_offset = byte_offset + number // 8
                s.bit_offset = number % 8

        # This is the end of the settings. Round up to a whole number of bytes.
        self.boolean_settings = len(self.predicate_number)
        self.settings_size = self.byte_size()

        # Now assign numbers to all our named predicates.
        for name, pred in self.named_predicates.items():
            self.number_predicate(pred)

    def byte_size(self):
        # type: () -> int
        """
        Compute the number of bytes required to hold all settings and
        precomputed predicates.

        This is the size of the byte-sized settings plus all the numbered
        predicate bits rounded up to a whole number of bytes.
        """
        return self.boolean_offset + (len(self.predicate_number) + 7) // 8
|
||||
|
||||
|
||||
class Preset(object):
    """
    A collection of setting values that are applied at once.

    A `Preset` represents a shorthand notation for applying a number of
    settings at once. Example:

        nehalem = Preset(has_sse41, has_cmov, has_avx=0)

    Enabling the `nehalem` setting is equivalent to enabling `has_sse41` and
    `has_cmov` while disabling the `has_avx` setting.
    """

    def __init__(self, *args):
        # type: (*BoolOrPresetOrDict) -> None
        # Assigned later by `SettingGroup`.
        self.name = None  # type: str
        # Each tuple provides the value for a setting.
        self.values = list()  # type: List[Tuple[Setting, Any]]

        for arg in args:
            if isinstance(arg, Preset):
                # Any presets in args are immediately expanded.
                self.values.extend(arg.values)
            elif isinstance(arg, dict):
                # A dictionary of key: value pairs.
                self.values.extend(arg.items())
            else:
                # A BoolSetting to enable.
                assert isinstance(arg, BoolSetting)
                self.values.append((arg, True))

        self.group = SettingGroup.append_preset(self)
        # Index into the generated DESCRIPTORS table.
        self.descriptor_index = None  # type: int

    def layout(self):
        # type: () -> List[Tuple[int, int]]
        """
        Compute a list of (mask, byte) pairs that incorporate all values in
        this preset.

        The list will have an entry for each setting byte in the settings
        group.
        """
        lst = [(0, 0)] * self.group.settings_size

        # Apply setting values in order; later values overwrite earlier ones
        # for the same bits.
        for s, v in self.values:
            ofs = s.byte_offset
            s_mask = s.byte_mask()
            s_val = s.byte_for_value(v)
            assert (s_val & ~s_mask) == 0
            l_mask, l_val = lst[ofs]
            # Accumulated mask of modified bits.
            l_mask |= s_mask
            # Overwrite the relevant bits with the new value.
            l_val = (l_val & ~s_mask) | s_val
            lst[ofs] = (l_mask, l_val)

        return lst
|
||||
28
lib/codegen/meta-python/cdsl/test_ast.py
Normal file
28
lib/codegen/meta-python/cdsl/test_ast.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import absolute_import
|
||||
from unittest import TestCase
|
||||
from doctest import DocTestSuite
|
||||
from . import ast
|
||||
from base.instructions import jump, iadd
|
||||
|
||||
|
||||
def load_tests(loader, tests, ignore):
    # Standard unittest `load_tests` protocol hook: add the doctests from the
    # `ast` module to the suite discovered for this test file.
    tests.addTests(DocTestSuite(ast))
    return tests
|
||||
|
||||
|
||||
x = 'x'
|
||||
y = 'y'
|
||||
a = 'a'
|
||||
|
||||
|
||||
class TestPatterns(TestCase):
    """Smoke tests for building instruction ASTs with apply and `<<`."""

    def test_apply(self):
        # Applying an instruction to arguments builds an `Apply` node; its
        # repr names the instruction, optionally with a bound type.
        i = jump(x, y)
        self.assertEqual(repr(i), "Apply(jump, ('x', 'y'))")

        i = iadd.i32(x, y)
        self.assertEqual(repr(i), "Apply(iadd.i32, ('x', 'y'))")

    def test_single_ins(self):
        # `defs << expr` builds a single-instruction pattern whose repr shows
        # the defined values and the applied expression.
        pat = a << iadd.i32(x, y)
        self.assertEqual(repr(pat), "('a',) << Apply(iadd.i32, ('x', 'y'))")
|
||||
8
lib/codegen/meta-python/cdsl/test_package.py
Normal file
8
lib/codegen/meta-python/cdsl/test_package.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from __future__ import absolute_import
|
||||
import doctest
|
||||
import cdsl
|
||||
|
||||
|
||||
def load_tests(loader, tests, ignore):
    # Standard unittest `load_tests` protocol hook: run the doctests in the
    # `cdsl` package's `__init__` as part of this suite.
    tests.addTests(doctest.DocTestSuite(cdsl))
    return tests
|
||||
605
lib/codegen/meta-python/cdsl/test_ti.py
Normal file
605
lib/codegen/meta-python/cdsl/test_ti.py
Normal file
@@ -0,0 +1,605 @@
|
||||
from __future__ import absolute_import
|
||||
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
|
||||
b1, icmp, iadd_cout, iadd_cin, uextend, sextend, ireduce, fpromote, \
|
||||
fdemote
|
||||
from base.legalize import narrow, expand
|
||||
from base.immediates import intcc
|
||||
from base.types import i32, i8
|
||||
from .typevar import TypeVar
|
||||
from .ast import Var, Def
|
||||
from .xform import Rtl, XForm
|
||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env, TypesEqual, WiderOrEq
|
||||
from unittest import TestCase
|
||||
from functools import reduce
|
||||
|
||||
try:
|
||||
from .ti import TypeMap, ConstraintList, VarTyping, TypingOrError # noqa
|
||||
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
|
||||
|
||||
def agree(me, other):
    # type: (TypeEnv, TypeEnv) -> bool
    """
    Given TypeEnvs me and other, check if they agree. As part of that build
    a map m from TVs in me to their corresponding TVs in other.
    Specifically:

        1. Check that all TVs that are keys in me.type_map are also defined
           in other.type_map

        2. For any tv in me.type_map check that:
            me[tv].get_typeset() == other[tv].get_typeset()

        3. Set m[me[tv]] = other[tv] in the substitution m

        4. If we find another tv1 such that me[tv1] == me[tv], assert that
           other[tv1] == m[me[tv1]] == m[me[tv]] = other[tv]

        5. Check that me and other have the same constraints under the
           substitution m
    """
    m = {}  # type: TypeMap
    # Check that our type map and other's agree and build the substitution m
    for tv in me.type_map:
        if (me[tv] not in m):
            # First time we see this representative: record the mapping and
            # require matching typesets.
            m[me[tv]] = other[tv]
            if me[tv].get_typeset() != other[tv].get_typeset():
                return False
        else:
            # Seen before: the mapping must be consistent.
            if m[me[tv]] != other[tv]:
                return False

    # Translate our constraints using m, and sort
    me_equiv_constr = sorted([constr.translate(m)
                              for constr in me.constraints], key=repr)
    # Sort other's constraints
    other_equiv_constr = sorted([constr.translate(other)
                                 for constr in other.constraints], key=repr)
    return me_equiv_constr == other_equiv_constr
|
||||
|
||||
|
||||
def check_typing(got_or_err, expected, symtab=None):
    # type: (TypingOrError, Tuple[VarTyping, ConstraintList], Dict[str, Var]) -> None # noqa
    """
    Check that the typing we received (got_or_err) complies with the
    expected typing (expected). If symtab is specified, substitute the Vars in
    expected using symtab first (used when checking type inference on XForms)
    """
    (m, c) = expected
    # `get_type_env` raises/asserts if `got_or_err` is an error message.
    got = get_type_env(got_or_err)

    if (symtab is not None):
        # For xforms we first need to re-write our TVs in terms of the tvs
        # stored internally in the XForm. Use the symtab passed
        subst_m = {k.get_typevar(): symtab[str(k)].get_typevar()
                   for k in m.keys()}
        # Convert m from a Var->TypeVar map to TypeVar->TypeVar map where
        # the key TypeVar is re-written to its XForm internal version
        tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
        # Rewrite the TVs in the input constraints to their XForm internal
        # versions
        c = [constr.translate(subst_m) for constr in c]
    else:
        # If no symtab, just convert m from Var->TypeVar map to a
        # TypeVar->TypeVar map
        tv_m = {k.get_typevar(): v for (k, v) in m.items()}

    expected_typ = TypeEnv((tv_m, c))
    assert agree(expected_typ, got), \
        "typings disagree:\n {} \n {}".format(got.dot(),
                                              expected_typ.dot())
|
||||
|
||||
|
||||
def check_concrete_typing_rtl(var_types, rtl):
    # type: (VarTyping, Rtl) -> None
    """
    Check that a concrete type assignment var_types (Dict[Var, TypeVar]) is
    valid for an Rtl rtl. Specifically check that:

    1) For each Var v in rtl, v is defined in var_types

    2) For all v, var_types[v] is a singleton type

    3) For each v, and each location u, where v is used with expected type
       tv_u, var_types[v].get_typeset() is a subset of
       subst(tv_u, m).get_typeset() where m is the substitution of
       formals->actuals we are building so far.

    4) If tv_u is non-derived and not in m, set m[tv_u] = var_types[v]
    """
    for d in rtl.rtl:
        assert isinstance(d, Def)
        inst = d.expr.inst
        # Accumulate all actual TVs for value defs/opnums in actual_tvs
        actual_tvs = [var_types[d.defs[i]] for i in inst.value_results]
        for v in [d.expr.args[i] for i in inst.value_opnums]:
            assert isinstance(v, Var)
            actual_tvs.append(var_types[v])

        # Accumulate all formal TVs for value defs/opnums in formal_tvs
        formal_tvs = [inst.outs[i].typevar for i in inst.value_results] +\
            [inst.ins[i].typevar for i in inst.value_opnums]
        # Per-instruction substitution of formal TVs to actual TVs.
        m = {}  # type: TypeMap

        # For each actual/formal pair check that they agree
        for (actual_tv, formal_tv) in zip(actual_tvs, formal_tvs):
            # actual should be a singleton
            assert actual_tv.singleton_type() is not None
            formal_tv = subst(formal_tv, m)
            # actual should agree with the concretized formal
            assert actual_tv.get_typeset().issubset(formal_tv.get_typeset())

            if formal_tv not in m and not formal_tv.is_derived:
                m[formal_tv] = actual_tv
|
||||
|
||||
|
||||
def check_concrete_typing_xform(var_types, xform):
    # type: (VarTyping, XForm) -> None
    """
    Verify that the concrete type assignment `var_types` is valid for both
    the source and destination patterns of the transform `xform`.
    """
    for pattern in (xform.src, xform.dst):
        check_concrete_typing_rtl(var_types, pattern)
|
||||
|
||||
|
||||
class TypeCheckingBaseTest(TestCase):
    """Shared fixture: fresh `Var`s and common `TypeVar`s for the ti tests."""

    def setUp(self):
        # type: () -> None
        self.v0 = Var("v0")
        self.v1 = Var("v1")
        self.v2 = Var("v2")
        self.v3 = Var("v3")
        self.v4 = Var("v4")
        self.v5 = Var("v5")
        self.v6 = Var("v6")
        self.v7 = Var("v7")
        self.v8 = Var("v8")
        self.v9 = Var("v9")
        self.imm0 = Var("imm0")
        # Integer SIMD vectors only (no scalar lanes).
        self.IxN_nonscalar = TypeVar("IxN", "", ints=True, scalars=False,
                                     simd=True)
        # Any int/bool/float SIMD vector (no scalar lanes).
        self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
                           scalars=False, simd=True)
        self.b1 = TypeVar.singleton(b1)
|
||||
|
||||
|
||||
class TestRTL(TypeCheckingBaseTest):
|
||||
def test_bad_rtl1(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
(self.v0, self.v1) << vsplit(self.v2),
|
||||
self.v3 << vconcat(self.v0, self.v2),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
self.assertEqual(ti_rtl(r, ti),
|
||||
"On line 1: fail ti on `typeof_v2` <: `1`: " +
|
||||
"Error: empty type created when unifying " +
|
||||
"`typeof_v2` and `half_vector(typeof_v2)`")
|
||||
|
||||
def test_vselect(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v0 << vselect(self.v1, self.v2, self.v3),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
txn = self.TxN.get_fresh_copy("TxN1")
|
||||
check_typing(typing, ({
|
||||
self.v0: txn,
|
||||
self.v1: txn.as_bool(),
|
||||
self.v2: txn,
|
||||
self.v3: txn
|
||||
}, []))
|
||||
|
||||
def test_vselect_icmpimm(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v0 << iconst(self.imm0),
|
||||
self.v1 << icmp(intcc.eq, self.v2, self.v0),
|
||||
self.v5 << vselect(self.v1, self.v3, self.v4),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
ixn = self.IxN_nonscalar.get_fresh_copy("IxN1")
|
||||
txn = self.TxN.get_fresh_copy("TxN1")
|
||||
check_typing(typing, ({
|
||||
self.v0: ixn,
|
||||
self.v1: ixn.as_bool(),
|
||||
self.v2: ixn,
|
||||
self.v3: txn,
|
||||
self.v4: txn,
|
||||
self.v5: txn,
|
||||
}, [TypesEqual(ixn.as_bool(), txn.as_bool())]))
|
||||
|
||||
def test_vselect_vsplits(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v3 << vselect(self.v0, self.v1, self.v2),
|
||||
(self.v4, self.v5) << vsplit(self.v3),
|
||||
(self.v6, self.v7) << vsplit(self.v4),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
t = TypeVar("t", "", ints=True, bools=True, floats=True,
|
||||
simd=(4, 256))
|
||||
check_typing(typing, ({
|
||||
self.v0: t.as_bool(),
|
||||
self.v1: t,
|
||||
self.v2: t,
|
||||
self.v3: t,
|
||||
self.v4: t.half_vector(),
|
||||
self.v5: t.half_vector(),
|
||||
self.v6: t.half_vector().half_vector(),
|
||||
self.v7: t.half_vector().half_vector(),
|
||||
}, []))
|
||||
|
||||
def test_vselect_vconcats(self):
    # type: () -> None
    """Two vconcats after a vselect cap the base lane count at 64."""
    rtl = Rtl(
        self.v3 << vselect(self.v0, self.v1, self.v2),
        self.v8 << vconcat(self.v3, self.v3),
        self.v9 << vconcat(self.v8, self.v8),
    )
    typing = ti_rtl(rtl, TypeEnv())
    base = TypeVar("t", "", ints=True, bools=True, floats=True,
                   simd=(2, 64))
    double = base.double_vector()
    check_typing(typing, ({
        self.v0: base.as_bool(),
        self.v1: base,
        self.v2: base,
        self.v3: base,
        self.v8: double,
        self.v9: double.double_vector(),
    }, []))
|
||||
|
||||
def test_vselect_vsplits_vconcats(self):
    # type: () -> None
    """Combining splits and concats narrows the lane range to [4, 64]."""
    rtl = Rtl(
        self.v3 << vselect(self.v0, self.v1, self.v2),
        (self.v4, self.v5) << vsplit(self.v3),
        (self.v6, self.v7) << vsplit(self.v4),
        self.v8 << vconcat(self.v3, self.v3),
        self.v9 << vconcat(self.v8, self.v8),
    )
    typing = ti_rtl(rtl, TypeEnv())
    base = TypeVar("t", "", ints=True, bools=True, floats=True,
                   simd=(4, 64))
    half = base.half_vector()
    quarter = half.half_vector()
    double = base.double_vector()
    check_typing(typing, ({
        self.v0: base.as_bool(),
        self.v1: base,
        self.v2: base,
        self.v3: base,
        self.v4: half,
        self.v5: half,
        self.v6: quarter,
        self.v7: quarter,
        self.v8: double,
        self.v9: double.double_vector(),
    }, []))
|
||||
|
||||
def test_bint(self):
    # type: () -> None
    """bint's result must unify with the integer type used by iadd."""
    rtl = Rtl(
        self.v4 << iadd(self.v1, self.v2),
        self.v5 << bint(self.v3),
        self.v0 << iadd(self.v4, self.v5)
    )
    typing = ti_rtl(rtl, TypeEnv())
    int_tv = TypeVar("t", "", ints=True, simd=(1, 256))
    bool_tv = TypeVar("b", "", bools=True, simd=True)

    # self.v5 must get the same integer type as all the iadd operands.
    # TODO: Add constraint nlanes(v3) == nlanes(v1) when we
    # add that type constraint to bint
    check_typing(typing, ({
        self.v0: int_tv,
        self.v1: int_tv,
        self.v2: int_tv,
        self.v4: int_tv,
        self.v5: int_tv,
        self.v3: bool_tv,
    }, []))
|
||||
|
||||
def test_fully_bound_inst_inference_bad(self):
    # type: () -> None
    """Incompatible fully-bound instructions must fail type inference.

    `uextend.i32` and `uextend.i16` pin v3 and v4 to different concrete
    types, so the final `iadd` cannot unify them; ti_rtl returns the
    error string below instead of a typing.
    """
    r = Rtl(
        self.v3 << uextend.i32(self.v1),
        self.v4 << uextend.i16(self.v2),
        self.v5 << iadd(self.v3, self.v4),
    )
    ti = TypeEnv()
    typing = ti_rtl(r, ti)

    self.assertEqual(typing,
                     "On line 2: fail ti on `typeof_v4` <: `4`: " +
                     "Error: empty type created when unifying " +
                     "`i16` and `i32`")
|
||||
|
||||
def test_extend_reduce(self):
    # type: () -> None
    """uextend/ireduce/sextend chains produce WiderOrEq constraints."""
    rtl = Rtl(
        self.v1 << uextend(self.v0),
        self.v2 << ireduce(self.v1),
        self.v3 << sextend(self.v2),
    )
    typing = ti_rtl(rtl, TypeEnv()).extract()

    tvs = [TypeVar(name, "", ints=True, simd=(1, 256))
           for name in ("t", "t1", "t2", "t3")]
    (itype0, itype1, itype2, itype3) = tvs

    check_typing(typing, ({
        self.v0: itype0,
        self.v1: itype1,
        self.v2: itype2,
        self.v3: itype3,
    }, [WiderOrEq(itype1, itype0),
        WiderOrEq(itype1, itype2),
        WiderOrEq(itype3, itype2)]))
|
||||
|
||||
def test_extend_reduce_enumeration(self):
    # type: () -> None
    """Enumerate all concrete typings of each width-changing int op."""
    for op in (uextend, sextend, ireduce):
        rtl = Rtl(
            self.v1 << op(self.v0),
        )
        typing = ti_rtl(rtl, TypeEnv()).extract()

        # The number of possible typings is 9 * (3+ 2*2 + 3) = 90
        pairs = [(t[self.v0], t[self.v1])
                 for t in typing.concrete_typings()]
        assert len(pairs) == 90 and len(set(pairs)) == len(pairs)
        for (tv_in, tv_out) in pairs:
            in_type = tv_in.singleton_type()
            out_type = tv_out.singleton_type()
            if op == ireduce:
                assert in_type.wider_or_equal(out_type)
            else:
                assert out_type.wider_or_equal(in_type)
|
||||
|
||||
def test_fpromote_fdemote(self):
    # type: () -> None
    """fpromote/fdemote produce WiderOrEq constraints on float types."""
    rtl = Rtl(
        self.v1 << fpromote(self.v0),
        self.v2 << fdemote(self.v1),
    )
    typing = ti_rtl(rtl, TypeEnv()).extract()

    ftype0 = TypeVar("t", "", floats=True, simd=(1, 256))
    ftype1 = TypeVar("t1", "", floats=True, simd=(1, 256))
    ftype2 = TypeVar("t2", "", floats=True, simd=(1, 256))

    check_typing(typing, ({
        self.v0: ftype0,
        self.v1: ftype1,
        self.v2: ftype2,
    }, [WiderOrEq(ftype1, ftype0),
        WiderOrEq(ftype1, ftype2)]))
|
||||
|
||||
def test_fpromote_fdemote_enumeration(self):
    # type: () -> None
    """Enumerate all concrete typings of fpromote and fdemote."""
    for op in (fpromote, fdemote):
        rtl = Rtl(
            self.v1 << op(self.v0),
        )
        typing = ti_rtl(rtl, TypeEnv()).extract()

        # The number of possible typings is 9*(2 + 1) = 27
        pairs = [(t[self.v0], t[self.v1])
                 for t in typing.concrete_typings()]
        assert len(pairs) == 27 and len(set(pairs)) == len(pairs)
        for (tv_in, tv_out) in pairs:
            in_type = tv_in.singleton_type()
            out_type = tv_out.singleton_type()
            if op == fdemote:
                assert in_type.wider_or_equal(out_type)
            else:
                assert out_type.wider_or_equal(in_type)
|
||||
|
||||
|
||||
class TestXForm(TypeCheckingBaseTest):
    """Type inference over whole XForms (source + destination patterns)."""

    def test_iadd_cout(self):
        # type: () -> None
        """iadd_cout legalization: scalar int add plus an unsigned compare."""
        xform = XForm(
            Rtl((self.v0, self.v1) << iadd_cout(self.v2, self.v3)),
            Rtl(
                self.v0 << iadd(self.v2, self.v3),
                self.v1 << icmp(intcc.ult, self.v0, self.v2)
            ))
        scalar_int = TypeVar("t", "", ints=True, simd=(1, 1))

        check_typing(xform.ti, ({
            self.v0: scalar_int,
            self.v2: scalar_int,
            self.v3: scalar_int,
            self.v1: scalar_int.as_bool(),
        }, []), xform.symtab)

    def test_iadd_cin(self):
        # type: () -> None
        """iadd_cin legalization: the carry-in is always a scalar b1."""
        xform = XForm(
            Rtl(self.v0 << iadd_cin(self.v1, self.v2, self.v3)),
            Rtl(
                self.v4 << iadd(self.v1, self.v2),
                self.v5 << bint(self.v3),
                self.v0 << iadd(self.v4, self.v5)
            ))
        scalar_int = TypeVar("t", "", ints=True, simd=(1, 1))

        check_typing(xform.ti, ({
            self.v0: scalar_int,
            self.v1: scalar_int,
            self.v2: scalar_int,
            self.v4: scalar_int,
            self.v5: scalar_int,
            self.v3: self.b1,
        }, []), xform.symtab)

    def test_enumeration_with_constraints(self):
        # type: () -> None
        """Constraints must prune the enumeration of concrete typings."""
        def pattern():
            # type: () -> Rtl
            # Build a fresh copy of the identical src/dst pattern.
            return Rtl(
                self.v0 << iconst(self.imm0),
                self.v1 << icmp(intcc.eq, self.v2, self.v0),
                self.v5 << vselect(self.v1, self.v3, self.v4)
            )

        xform = XForm(pattern(), pattern())

        # Check all var assigns are correct
        assert len(xform.ti.constraints) > 0
        concrete_var_assigns = list(xform.ti.concrete_typings())

        (v0, v1, v2, v3, v4, v5) = \
            [xform.symtab[str(v)] for v in
             (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5)]

        for var_m in concrete_var_assigns:
            assert var_m[v0] == var_m[v2]
            assert var_m[v3] == var_m[v4]
            assert var_m[v5] == var_m[v3]
            assert var_m[v1] == var_m[v2].as_bool()
            assert (var_m[v1].get_typeset() ==
                    var_m[v3].as_bool().get_typeset())
            check_concrete_typing_xform(var_m, xform)

        # The number of possible typings here is:
        # 8 cases for v0 = i8xN times 2 options for v3 - i8, b8 = 16
        # 8 cases for v0 = i16xN times 2 options for v3 - i16, b16 = 16
        # 8 cases for v0 = i32xN times 3 options for v3 - i32, b32, f32 = 24
        # 8 cases for v0 = i64xN times 3 options for v3 - i64, b64, f64 = 24
        #
        # (Note we have 8 cases for lanes since vselect prevents scalars)
        # Total: 2*16 + 2*24 = 80
        assert len(concrete_var_assigns) == 80

    def test_base_legalizations_enumeration(self):
        # type: () -> None
        """Sanity-check every defined narrow/expand legalization pattern."""
        for xform in narrow.xforms + expand.xforms:
            # Any legalization pattern we defined should have at least 1
            # concrete typing.
            typings = list(xform.ti.concrete_typings())
            assert len(typings) > 0

            # A pattern with no free typevars is non-polymorphic and must
            # have exactly one concrete typing.
            if len(xform.ti.free_typevars()) == 0:
                assert len(typings) == 1
                continue

            # When the type env carries constraints, at least one
            # "theoretically possible" assignment must be ruled out by
            # them (i.e. we never emit unnecessary constraints). Verify
            # by comparing against the unconstrained product of typeset
            # sizes.
            if len(xform.ti.constraints) > 0:
                upper_bound = reduce(
                    lambda acc, tv: acc * tv.get_typeset().size(),
                    xform.ti.free_typevars(), 1)
                assert len(typings) < upper_bound

            # Each surviving concrete typing must check out against the
            # xform itself.
            for concrete_typing in typings:
                check_concrete_typing_xform(concrete_typing, xform)

    def test_bound_inst_inference(self):
        # type: () -> None
        """First example from issue #26: bound uextend fixes i32."""
        xform = XForm(
            Rtl(
                self.v0 << iadd(self.v1, self.v2),
            ),
            Rtl(
                self.v3 << uextend.i32(self.v1),
                self.v4 << uextend.i32(self.v2),
                self.v5 << iadd(self.v3, self.v4),
                self.v0 << ireduce(self.v5)
            ))
        int_tv = TypeVar("t", "", ints=True, simd=True)
        i32_tv = TypeVar.singleton(i32)

        check_typing(xform.ti, ({
            self.v0: int_tv,
            self.v1: int_tv,
            self.v2: int_tv,
            self.v3: i32_tv,
            self.v4: i32_tv,
            self.v5: i32_tv,
        }, [WiderOrEq(i32_tv, int_tv)]), xform.symtab)

    def test_bound_inst_inference1(self):
        # type: () -> None
        """Second example from issue #26: bound iadd fixes i32."""
        xform = XForm(
            Rtl(
                self.v0 << iadd(self.v1, self.v2),
            ),
            Rtl(
                self.v3 << uextend(self.v1),
                self.v4 << uextend(self.v2),
                self.v5 << iadd.i32(self.v3, self.v4),
                self.v0 << ireduce(self.v5)
            ))
        int_tv = TypeVar("t", "", ints=True, simd=True)
        i32_tv = TypeVar.singleton(i32)

        check_typing(xform.ti, ({
            self.v0: int_tv,
            self.v1: int_tv,
            self.v2: int_tv,
            self.v3: i32_tv,
            self.v4: i32_tv,
            self.v5: i32_tv,
        }, [WiderOrEq(i32_tv, int_tv)]), xform.symtab)

    def test_fully_bound_inst_inference(self):
        # type: () -> None
        """Issue #26 example with complete bounds: everything concrete."""
        xform = XForm(
            Rtl(
                self.v0 << iadd(self.v1, self.v2),
            ),
            Rtl(
                self.v3 << uextend.i32.i8(self.v1),
                self.v4 << uextend.i32.i8(self.v2),
                self.v5 << iadd(self.v3, self.v4),
                self.v0 << ireduce(self.v5)
            ))
        i8_tv = TypeVar.singleton(i8)
        i32_tv = TypeVar.singleton(i32)

        # Note no constraints here since they are all trivial
        check_typing(xform.ti, ({
            self.v0: i8_tv,
            self.v1: i8_tv,
            self.v2: i8_tv,
            self.v3: i32_tv,
            self.v4: i32_tv,
            self.v5: i32_tv,
        }, []), xform.symtab)

    def test_fully_bound_inst_inference_bad(self):
        # type: () -> None
        """A mistyped XForm can't be forced through bound instructions."""
        with self.assertRaises(AssertionError):
            XForm(
                Rtl(
                    self.v0 << iadd(self.v1, self.v2),
                ),
                Rtl(
                    self.v3 << uextend.i32.i8(self.v1),
                    self.v4 << uextend.i32.i16(self.v2),
                    self.v5 << iadd(self.v3, self.v4),
                    self.v0 << ireduce(self.v5)
                ))
|
||||
266
lib/codegen/meta-python/cdsl/test_typevar.py
Normal file
266
lib/codegen/meta-python/cdsl/test_typevar.py
Normal file
@@ -0,0 +1,266 @@
|
||||
from __future__ import absolute_import
|
||||
from unittest import TestCase
|
||||
from doctest import DocTestSuite
|
||||
from . import typevar
|
||||
from .typevar import TypeSet, TypeVar
|
||||
from base.types import i32, i16, b1, f64
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
|
||||
|
||||
def load_tests(loader, tests, ignore):
    """Attach the typevar module's doctests to this test suite."""
    suite = DocTestSuite(typevar)
    tests.addTests(suite)
    return tests
|
||||
|
||||
|
||||
class TestTypeSet(TestCase):
    """Unit tests for TypeSet construction, hashing and image maps."""

    def test_invalid(self):
        # Reversed or out-of-range intervals must be rejected.
        for bad_kwargs in ({'lanes': (2, 1)},
                           {'ints': (32, 16)},
                           {'floats': (32, 16)},
                           {'bools': (32, 16)},
                           {'ints': (32, 33)}):
            with self.assertRaises(AssertionError):
                TypeSet(**bad_kwargs)

    def test_hash(self):
        ts_a = TypeSet(lanes=True, ints=True, floats=True)
        ts_b = TypeSet(lanes=True, ints=True, floats=True)
        ts_c = TypeSet(lanes=True, ints=(8, 16), floats=True)
        self.assertEqual(ts_a, ts_b)
        self.assertNotEqual(ts_a, ts_c)
        # Equal typesets must collide in a set; unequal ones must not.
        seen = set()
        seen.add(ts_a)
        self.assertTrue(ts_a in seen)
        self.assertTrue(ts_b in seen)
        self.assertFalse(ts_c in seen)

    def test_hash_modified(self):
        ts = TypeSet(lanes=True, ints=True, floats=True)
        seen = set()
        seen.add(ts)
        ts.ints.remove(64)
        # Can't rehash after modification.
        with self.assertRaises(AssertionError):
            ts in seen

    def test_forward_images(self):
        ts_a = TypeSet(lanes=(2, 8), ints=(8, 8), floats=(32, 32))
        ts_b = TypeSet(lanes=(1, 8), ints=(8, 8), floats=(32, 32))
        self.assertEqual(ts_a.lane_of(), TypeSet(ints=(8, 8),
                                                 floats=(32, 32)))

        bool_img = TypeSet(lanes=(2, 8))
        bool_img.bools = set([8, 32])

        # as_bool with disjoint intervals.
        self.assertEqual(ts_a.as_bool(), bool_img)

        # For as_bool check b1 is present when 1 \in lanes
        bool_img1 = TypeSet(lanes=(1, 8))
        bool_img1.bools = set([1, 8, 32])
        self.assertEqual(ts_b.as_bool(), bool_img1)

        self.assertEqual(TypeSet(lanes=(1, 32)).half_vector(),
                         TypeSet(lanes=(1, 16)))

        self.assertEqual(TypeSet(lanes=(1, 32)).double_vector(),
                         TypeSet(lanes=(2, 64)))

        self.assertEqual(TypeSet(lanes=(128, 256)).double_vector(),
                         TypeSet(lanes=(256, 256)))

        self.assertEqual(TypeSet(ints=(8, 32)).half_width(),
                         TypeSet(ints=(8, 16)))

        self.assertEqual(TypeSet(ints=(8, 32)).double_width(),
                         TypeSet(ints=(16, 64)))

        self.assertEqual(TypeSet(ints=(32, 64)).double_width(),
                         TypeSet(ints=(64, 64)))

        # Should produce an empty ts
        self.assertEqual(TypeSet(floats=(32, 32)).half_width(),
                         TypeSet())

        self.assertEqual(TypeSet(floats=(32, 64)).half_width(),
                         TypeSet(floats=(32, 32)))

        self.assertEqual(TypeSet(floats=(32, 32)).double_width(),
                         TypeSet(floats=(64, 64)))

        self.assertEqual(TypeSet(floats=(32, 64)).double_width(),
                         TypeSet(floats=(64, 64)))

        # Bools have trickier behavior around b1 (since b2, b4 don't exist)
        self.assertEqual(TypeSet(bools=(1, 8)).half_width(),
                         TypeSet())

        expected = TypeSet()
        expected.bools = set([8, 16])
        self.assertEqual(TypeSet(bools=(1, 32)).half_width(), expected)

        # double_width() of bools={1, 8, 16} must not include 2 or 8
        expected.bools = set([16, 32])
        self.assertEqual(TypeSet(bools=(1, 16)).double_width(), expected)

        self.assertEqual(TypeSet(bools=(32, 64)).double_width(),
                         TypeSet(bools=(64, 64)))

    def test_get_singleton(self):
        # Raise error when calling get_singleton() on non-singleton TS
        non_singleton = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
        with self.assertRaises(AssertionError):
            non_singleton.get_singleton()

        non_singleton = TypeSet(lanes=(1, 2), floats=(32, 32))
        with self.assertRaises(AssertionError):
            non_singleton.get_singleton()

        self.assertEqual(TypeSet(ints=(16, 16)).get_singleton(), i16)
        self.assertEqual(TypeSet(floats=(64, 64)).get_singleton(), f64)
        self.assertEqual(TypeSet(bools=(1, 1)).get_singleton(), b1)
        self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
                         i32.by(4))

    def test_preimage(self):
        ts = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))

        # LANEOF
        self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
                         ts.preimage(TypeVar.LANEOF))
        # Inverse of empty set is still empty across LANEOF
        self.assertEqual(TypeSet(),
                         TypeSet().preimage(TypeVar.LANEOF))

        # ASBOOL
        ts = TypeSet(lanes=(1, 4), bools=(1, 64))
        self.assertEqual(ts.preimage(TypeVar.ASBOOL),
                         TypeSet(lanes=(1, 4), ints=True, bools=True,
                                 floats=True))

        # Half/Double Vector: preimage at the lane-count extremes is empty.
        ts = TypeSet(lanes=(1, 1), ints=(8, 8))
        ts1 = TypeSet(lanes=(256, 256), ints=(8, 8))
        self.assertEqual(ts.preimage(TypeVar.DOUBLEVECTOR).size(), 0)
        self.assertEqual(ts1.preimage(TypeVar.HALFVECTOR).size(), 0)

        ts = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32))
        ts1 = TypeSet(lanes=(64, 256), bools=(1, 32))

        self.assertEqual(ts.preimage(TypeVar.DOUBLEVECTOR),
                         TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32)))
        self.assertEqual(ts1.preimage(TypeVar.HALFVECTOR),
                         TypeSet(lanes=(128, 256), bools=(1, 32)))

        # Half/Double Width: preimage at the bit-width extremes is empty.
        ts = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8))
        ts1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64))
        self.assertEqual(ts.preimage(TypeVar.DOUBLEWIDTH).size(), 0)
        self.assertEqual(ts1.preimage(TypeVar.HALFWIDTH).size(), 0)

        ts = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64))
        ts1 = TypeSet(lanes=(64, 256), bools=(1, 64))

        self.assertEqual(ts.preimage(TypeVar.DOUBLEWIDTH),
                         TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32)))
        self.assertEqual(ts1.preimage(TypeVar.HALFWIDTH),
                         TypeSet(lanes=(64, 256), bools=(16, 64)))
|
||||
|
||||
|
||||
def has_non_bijective_derived_f(iterable):
    """Return True if any derived function in *iterable* is not a bijection."""
    for func in iterable:
        if not TypeVar.is_bijection(func):
            return True
    return False
|
||||
|
||||
|
||||
class TestTypeVar(TestCase):
    """Unit tests for TypeVar derivation and constraint propagation."""

    def test_functions(self):
        # Width changes are only legal when the result set is non-empty.
        all_ints = TypeVar('x', 'all ints', ints=True)
        with self.assertRaises(AssertionError):
            all_ints.double_width()
        with self.assertRaises(AssertionError):
            all_ints.half_width()

        from_i16 = TypeVar('x2', 'i16 and up', ints=(16, 64))
        with self.assertRaises(AssertionError):
            from_i16.double_width()
        self.assertEqual(str(from_i16.half_width()), '`half_width(x2)`')
        self.assertEqual(from_i16.half_width().rust_expr(),
                         'x2.half_width()')
        self.assertEqual(
            from_i16.half_width().double_width().rust_expr(),
            'x2.half_width().double_width()')

        to_i32 = TypeVar('x3', 'up to i32', ints=(8, 32))
        self.assertEqual(str(to_i32.double_width()), '`double_width(x3)`')
        with self.assertRaises(AssertionError):
            to_i32.half_width()

    def test_singleton(self):
        # Scalar singleton.
        tv = TypeVar.singleton(i32)
        self.assertEqual(str(tv), '`i32`')
        self.assertEqual(min(tv.type_set.ints), 32)
        self.assertEqual(max(tv.type_set.ints), 32)
        self.assertEqual(min(tv.type_set.lanes), 1)
        self.assertEqual(max(tv.type_set.lanes), 1)
        self.assertEqual(len(tv.type_set.floats), 0)
        self.assertEqual(len(tv.type_set.bools), 0)

        # Vector singleton.
        tv = TypeVar.singleton(i32.by(4))
        self.assertEqual(str(tv), '`i32x4`')
        self.assertEqual(min(tv.type_set.ints), 32)
        self.assertEqual(max(tv.type_set.ints), 32)
        self.assertEqual(min(tv.type_set.lanes), 4)
        self.assertEqual(max(tv.type_set.lanes), 4)
        self.assertEqual(len(tv.type_set.floats), 0)
        self.assertEqual(len(tv.type_set.bools), 0)

    def test_stress_constrain_types(self):
        # Get all 43 possible derived vars of length up to 2
        funcs = [TypeVar.LANEOF,
                 TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
                 TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
        chains = [()] + [(f,) for f in funcs] + list(product(funcs, funcs))

        # For each pair of derived variables
        for (chain1, chain2) in product(chains, chains):
            # Compute the derived sets for each starting with a full typeset
            full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True)
            ts1 = reduce(lambda ts, func: ts.image(func), chain1, full_ts)
            ts2 = reduce(lambda ts, func: ts.image(func), chain2, full_ts)

            # Compute intersection
            intersect = ts1.copy()
            intersect &= ts2

            # Propagate the intersection backwards through each chain.
            ts1_src = reduce(lambda ts, func: ts.preimage(func),
                             reversed(chain1),
                             intersect)
            ts2_src = reduce(lambda ts, func: ts.preimage(func),
                             reversed(chain2),
                             intersect)

            # If the intersection or its propagated forms are empty, then
            # these two variables can never overlap. For example
            # x.double_vector and x.lane_of.
            if (intersect.size() == 0 or ts1_src.size() == 0 or
                    ts2_src.size() == 0):
                continue

            # Should be safe to create derived tvs from ts1_src and ts2_src
            tv1 = reduce(lambda tv, func: TypeVar.derived(tv, func),
                         chain1,
                         TypeVar.from_typeset(ts1_src))

            tv2 = reduce(lambda tv, func: TypeVar.derived(tv, func),
                         chain2,
                         TypeVar.from_typeset(ts2_src))

            # In the absence of AS_BOOL image(preimage(f)) == f so the
            # typesets of tv1 and tv2 should be exactly intersection
            assert (tv1.get_typeset() == intersect or
                    has_non_bijective_derived_f(chain1))

            assert (tv2.get_typeset() == intersect or
                    has_non_bijective_derived_f(chain2))
|
||||
131
lib/codegen/meta-python/cdsl/test_xform.py
Normal file
131
lib/codegen/meta-python/cdsl/test_xform.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import absolute_import
|
||||
from unittest import TestCase
|
||||
from doctest import DocTestSuite
|
||||
from base.instructions import iadd, iadd_imm, iconst, icmp
|
||||
from base.immediates import intcc
|
||||
from . import xform
|
||||
from .ast import Var
|
||||
from .xform import Rtl, XForm
|
||||
|
||||
|
||||
def load_tests(loader, tests, ignore):
    """Attach the xform module's doctests to this test suite."""
    suite = DocTestSuite(xform)
    tests.addTests(suite)
    return tests
|
||||
|
||||
|
||||
# Pattern variables shared by all the test cases below.
x = Var('x')
y = Var('y')
z = Var('z')
u = Var('u')
a = Var('a')
b = Var('b')
c = Var('c')

# Condition-code (enum) pattern variables.
CC1 = Var('CC1')
CC2 = Var('CC2')
|
||||
|
||||
|
||||
class TestXForm(TestCase):
    """Structural checks on XForm construction and Rtl substitution."""

    def test_macro_pattern(self):
        pat = Rtl(a << iadd_imm(x, y))
        repl = Rtl(
            c << iconst(y),
            a << iadd(x, c))
        # Construction itself is the assertion: must not raise.
        XForm(pat, repl)

    def test_def_input(self):
        # Src pattern has a def which is an input in dst.
        pat = Rtl(a << iadd_imm(x, 1))
        repl = Rtl(y << iadd_imm(a, 1))
        with self.assertRaisesRegexp(
                AssertionError,
                "'a' used as both input and def"):
            XForm(pat, repl)

    def test_input_def(self):
        # Converse of the above.
        pat = Rtl(y << iadd_imm(a, 1))
        repl = Rtl(a << iadd_imm(x, 1))
        with self.assertRaisesRegexp(
                AssertionError,
                "'a' used as both input and def"):
            XForm(pat, repl)

    def test_extra_input(self):
        # The replacement may not consume inputs the pattern doesn't bind.
        pat = Rtl(a << iadd_imm(x, 1))
        repl = Rtl(a << iadd(x, y))
        with self.assertRaisesRegexp(AssertionError, "extra inputs in dst"):
            XForm(pat, repl)

    def test_double_def(self):
        pat = Rtl(
            a << iadd_imm(x, 1),
            a << iadd(x, y))
        repl = Rtl(a << iadd(x, y))
        with self.assertRaisesRegexp(AssertionError, "'a' multiply defined"):
            XForm(pat, repl)

    def test_subst_imm(self):
        pat = Rtl(a << iconst(x))
        repl = Rtl(c << iconst(y))
        assert pat.substitution(repl, {}) == {a: c, x: y}

    def test_subst_enum_var(self):
        pat = Rtl(a << icmp(CC1, x, y))
        repl = Rtl(b << icmp(CC2, z, u))
        assert pat.substitution(repl, {}) == {a: b, CC1: CC2, x: z, y: u}

    def test_subst_enum_const(self):
        pat = Rtl(a << icmp(intcc.eq, x, y))
        repl = Rtl(b << icmp(intcc.eq, z, u))
        assert pat.substitution(repl, {}) == {a: b, x: z, y: u}

    def test_subst_enum_var_const(self):
        # An enum var on one side may match a constant on the other.
        pat = Rtl(a << icmp(CC1, x, y))
        repl = Rtl(b << icmp(intcc.eq, z, u))
        result = pat.substitution(repl, {})
        expected = {CC1: intcc.eq, x: z, y: u, a: b}
        assert result == expected, "{} != {}".format(result, expected)

        pat = Rtl(a << icmp(intcc.eq, x, y))
        repl = Rtl(b << icmp(CC1, z, u))
        assert pat.substitution(repl, {}) == {CC1: intcc.eq, x: z, y: u,
                                              a: b}

    def test_subst_enum_bad(self):
        pat = Rtl(a << icmp(intcc.eq, x, y))
        repl = Rtl(b << icmp(intcc.sge, z, u))
        assert pat.substitution(repl, {}) is None

    def test_subst_enum_bad_var_const(self):
        a1 = Var('a1')
        x1 = Var('x1')
        y1 = Var('y1')

        b1 = Var('b1')
        z1 = Var('z1')
        u1 = Var('u1')

        # Var mapping to 2 different constants
        pat = Rtl(a << icmp(CC1, x, y),
                  a1 << icmp(CC1, x1, y1))
        repl = Rtl(b << icmp(intcc.eq, z, u),
                   b1 << icmp(intcc.sge, z1, u1))

        assert pat.substitution(repl, {}) is None

        # 2 different constants mapping to the same var
        pat = Rtl(a << icmp(intcc.eq, x, y),
                  a1 << icmp(intcc.sge, x1, y1))
        repl = Rtl(b << icmp(CC1, z, u),
                   b1 << icmp(CC1, z1, u1))

        assert pat.substitution(repl, {}) is None

        # Var mapping to var and constant - note that full unification would
        # have allowed this.
        pat = Rtl(a << icmp(CC1, x, y),
                  a1 << icmp(CC1, x1, y1))
        repl = Rtl(b << icmp(CC2, z, u),
                   b1 << icmp(intcc.sge, z1, u1))

        assert pat.substitution(repl, {}) is None
|
||||
886
lib/codegen/meta-python/cdsl/ti.py
Normal file
886
lib/codegen/meta-python/cdsl/ti.py
Normal file
@@ -0,0 +1,886 @@
|
||||
"""
|
||||
Type Inference
|
||||
"""
|
||||
from .typevar import TypeVar
|
||||
from .ast import Def, Var
|
||||
from copy import copy
|
||||
from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable, List, Any, TypeVar as MTypeVar # noqa
|
||||
from typing import cast
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
from .typevar import TypeSet # noqa
|
||||
if TYPE_CHECKING:
|
||||
T = MTypeVar('T')
|
||||
TypeMap = Dict[TypeVar, TypeVar]
|
||||
VarTyping = Dict[Var, TypeVar]
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
pass
|
||||
|
||||
|
||||
class TypeConstraint(object):
    """
    Base class for all runtime-emittable type constraints.

    Subclasses must implement _args(), is_trivial() and eval().
    """
    def translate(self, m):
        # type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
        """
        Translate any TypeVars in the constraint according to the map or
        TypeEnv m
        """
        def translate_one(a):
            # type: (Any) -> Any
            # Only TypeVar arguments are translated; everything else
            # (e.g. TypeSets) passes through unchanged.
            if not isinstance(a, TypeVar):
                return a
            if isinstance(m, TypeEnv):
                return m[a]
            return subst(a, m)

        new_args = tuple(translate_one(arg) for arg in self._args())
        res = self.__class__(*new_args)  # type: TypeConstraint
        return res

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, self.__class__):
            return False

        assert isinstance(other, TypeConstraint)  # help MyPy figure out other
        return self._args() == other._args()

    def is_concrete(self):
        # type: () -> bool
        """
        Return true iff all typevars in the constraint are singletons.
        """
        return all(tv.singleton_type() is not None for tv in self.tvs())

    def __hash__(self):
        # type: () -> int
        return hash(self._args())

    def _args(self):
        # type: () -> Tuple[Any,...]
        """
        Return a tuple with the exact arguments passed to __init__ to create
        this object.
        """
        assert False, "Abstract"

    def tvs(self):
        # type: () -> Iterable[TypeVar]
        """
        Return the typevars contained in this constraint.
        """
        return filter(lambda a: isinstance(a, TypeVar), self._args())

    def is_trivial(self):
        # type: () -> bool
        """
        Return true if this constraint is statically decidable.
        """
        assert False, "Abstract"

    def eval(self):
        # type: () -> bool
        """
        Evaluate this constraint. Should only be called when the constraint
        has been translated to concrete types.
        """
        assert False, "Abstract"

    def __repr__(self):
        # type: () -> str
        arg_text = ', '.join(str(arg) for arg in self._args())
        return '{}({})'.format(self.__class__.__name__, arg_text)
|
||||
|
||||
|
||||
class TypesEqual(TypeConstraint):
    """
    Constraint specifying that two derived type vars must have the same
    runtime type.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        # Store the operands in a canonical (repr-sorted) order so that
        # equality and hashing are insensitive to argument order.
        (self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        # Syntactically identical tvs are trivially equal; otherwise the
        # constraint is decidable only when both sides are concrete.
        if self.tv1 == self.tv2:
            return True
        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||
|
||||
|
||||
class InTypeset(TypeConstraint):
    """
    Constraint specifying that a type var must belong to some typeset.
    """
    def __init__(self, tv, ts):
        # type: (TypeVar, TypeSet) -> None
        # Only plain (non-derived) typevars introduced by type inference
        # may appear here.
        assert not tv.is_derived and tv.name.startswith("typeof_")
        self.tv = tv
        self.ts = ts

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv, self.ts)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        tv_ts = self.tv.get_typeset().copy()

        # Statically true: the tv's typeset already lies inside ts.
        if tv_ts.issubset(self.ts):
            return True

        # Statically false: the two typesets are disjoint.
        tv_ts &= self.ts
        if tv_ts.size() == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        return self.tv.get_typeset().issubset(self.ts)
|
||||
|
||||
|
||||
class WiderOrEq(TypeConstraint):
    """
    Constraint specifying that a type var tv1 must be wider than or equal to
    type var tv2 at runtime. This requires that:
        1) They have the same number of lanes
        2) In a lane tv1 has at least as many bits as tv2.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        # Trivially true: a tv is always wider-or-equal to itself.
        if self.tv1 == self.tv2:
            return True

        ts1 = self.tv1.get_typeset()
        ts2 = self.tv2.get_typeset()

        def always_wider(s1, s2):
            # type: (Set[int], Set[int]) -> bool
            # Every width in s1 is >= every width in s2 (both non-empty).
            return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)

        # Statically true: all of ints/floats/bools are pairwise wider.
        if (always_wider(ts1.ints, ts2.ints) and
                always_wider(ts1.floats, ts2.floats) and
                always_wider(ts1.bools, ts2.bools)):
            return True

        def may_be_narrower(s1, s2):
            # type: (Set[int], Set[int]) -> bool
            # Some width in s1 is < some width in s2 (both non-empty).
            return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)

        # Statically false: every class admits a narrower assignment.
        if (may_be_narrower(ts1.ints, ts2.ints) and
                may_be_narrower(ts1.floats, ts2.floats) and
                may_be_narrower(ts1.bools, ts2.bools)):
            return True

        # Statically false: no common lane count is possible.
        if len(ts1.lanes.intersection(ts2.lanes)) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        typ1 = self.tv1.singleton_type()
        typ2 = self.tv2.singleton_type()
        return typ1.wider_or_equal(typ2)
|
||||
|
||||
|
||||
class SameWidth(TypeConstraint):
    """
    Constraint specifying that two types have the same width. E.g. i32x2 has
    the same width as i64x1, i16x4, f32x2, f64, b1x64 etc.
    """

    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        # Trivially true: same TV on both sides.
        if self.tv1 == self.tv2:
            return True

        widths1 = self.tv1.get_typeset().widths()
        widths2 = self.tv2.get_typeset().widths()

        # Trivially false: no width is achievable by both typesets.
        if len(widths1.intersection(widths2)) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        first = self.tv1.singleton_type()
        second = self.tv2.singleton_type()
        return first.width() == second.width()
|
||||
|
||||
|
||||
class TypeEnv(object):
    """
    Class encapsulating the necessary book keeping for type inference.

    :attribute type_map: dict holding the equivalence relations between tvs
    :attribute constraints: a list of accumulated constraints - tuples
                            (tv1, tv2)) where tv1 and tv2 are equal
    :attribute ranks: dictionary recording the (optional) ranks for tvs.
                      'rank' is a partial ordering on TVs based on their
                      origin. See comments in rank() and register().
    :attribute vars: a set containing all known Vars
    :attribute idx: counter used to get fresh ids
    """

    # Ranks impose a partial order on TVs; unify() always merges the
    # lower-ranked TV into the higher-ranked one.
    RANK_SINGLETON = 5
    RANK_INPUT = 4
    RANK_INTERMEDIATE = 3
    RANK_OUTPUT = 2
    RANK_TEMP = 1
    RANK_INTERNAL = 0

    def __init__(self, arg=None):
        # type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
        self.ranks = {}  # type: Dict[TypeVar, int]
        self.vars = set()  # type: Set[Var]

        if arg is None:
            self.type_map = {}  # type: TypeMap
            self.constraints = []  # type: List[TypeConstraint]
        else:
            # Adopt an existing (type_map, constraints) pair.
            self.type_map, self.constraints = arg

        self.idx = 0

    def __getitem__(self, arg):
        # type: (Union[TypeVar, Var]) -> TypeVar
        """
        Lookup the canonical representative for a Var/TypeVar.
        """
        if (isinstance(arg, Var)):
            assert arg in self.vars
            tv = arg.get_typevar()
        else:
            assert (isinstance(arg, TypeVar))
            tv = arg

        # Chase the chain of equivalences (union-find style) to the root.
        while tv in self.type_map:
            tv = self.type_map[tv]

        if tv.is_derived:
            # Canonicalize the base of a derived TV as well.
            tv = TypeVar.derived(self[tv.base], tv.derived_func)
        return tv

    def equivalent(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        """
        Record that the free tv1 is part of the same equivalence class as
        tv2. The canonical representative of the merged class is tv2's
        canonical representative.
        """
        assert not tv1.is_derived
        assert self[tv1] == tv1

        # Make sure we don't create cycles
        if tv2.is_derived:
            assert self[tv2.base] != tv1

        self.type_map[tv1] = tv2

    def add_constraint(self, constr):
        # type: (TypeConstraint) -> None
        """
        Add a new constraint
        """
        if (constr in self.constraints):
            return

        # InTypeset constraints can be expressed by constraining the typeset
        # of a variable. No need to add them to self.constraints
        if (isinstance(constr, InTypeset)):
            self[constr.tv].constrain_types_by_ts(constr.ts)
            return

        self.constraints.append(constr)

    def get_uid(self):
        # type: () -> str
        """Return a fresh unique id (used to name fresh TV copies)."""
        r = str(self.idx)
        self.idx += 1
        return r

    def __repr__(self):
        # type: () -> str
        return self.dot()

    def rank(self, tv):
        # type: (TypeVar) -> int
        """
        Get the rank of tv in the partial order. TVs directly associated with
        a Var get their rank from the Var (see register()). Internally
        generated non-derived TVs implicitly get the lowest rank (0). Derived
        variables get their rank from their free typevar. Singletons have the
        highest rank. TVs associated with vars in a source pattern have a
        higher rank than TVs associated with temporary vars.
        """
        default_rank = TypeEnv.RANK_INTERNAL if tv.singleton_type() is None \
            else TypeEnv.RANK_SINGLETON

        if tv.is_derived:
            tv = tv.free_typevar()

        return self.ranks.get(tv, default_rank)

    def register(self, v):
        # type: (Var) -> None
        """
        Register a new Var v. This computes a rank for the associated TypeVar
        for v, which is used to impose a partial order on type variables.
        """
        self.vars.add(v)

        if v.is_input():
            r = TypeEnv.RANK_INPUT
        elif v.is_intermediate():
            r = TypeEnv.RANK_INTERMEDIATE
        elif v.is_output():
            r = TypeEnv.RANK_OUTPUT
        else:
            assert(v.is_temp())
            r = TypeEnv.RANK_TEMP

        self.ranks[v.get_typevar()] = r

    def free_typevars(self):
        # type: () -> List[TypeVar]
        """
        Get the free typevars in the current type env.
        """
        tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
        tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
        # Filter out None here due to singleton type vars
        # Sort by name for a deterministic result.
        return sorted(filter(lambda x: x is not None, tvs),
                      key=lambda x: x.name)

    def normalize(self):
        # type: () -> None
        """
        Normalize by:
            - collapsing any roots that don't correspond to a concrete TV AND
              have a single TV derived from them or equivalent to them

        E.g. if we have a root of the tree that looks like:

          typeof_a   typeof_b
                 \\  /
              typeof_x
                 |
            half_width(1)
                 |
                 1

        we want to collapse the linear path between 1 and typeof_x. The
        resulting graph is:

          typeof_a   typeof_b
                 \\  /
              typeof_x
        """
        source_tvs = set([v.get_typevar() for v in self.vars])
        # children[t] = set of TVs derived from or equivalent to t.
        children = {}  # type: Dict[TypeVar, Set[TypeVar]]
        for v in self.type_map.values():
            if not v.is_derived:
                continue

            t = v.free_typevar()
            s = children.get(t, set())
            s.add(v)
            children[t] = s

        for (a, b) in self.type_map.items():
            s = children.get(b, set())
            s.add(a)
            children[b] = s

        for r in self.free_typevars():
            # Walk down single-child chains that don't touch a real
            # source TV, removing the equivalence links as we go.
            while (r not in source_tvs and r in children and
                   len(children[r]) == 1):
                child = list(children[r])[0]
                if child in self.type_map:
                    assert self.type_map[child] == r
                    del self.type_map[child]

                r = child

    def extract(self):
        # type: () -> TypeEnv
        """
        Extract a clean type environment from self, that only mentions
        TVs associated with real variables
        """
        vars_tvs = set([v.get_typevar() for v in self.vars])
        new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}

        new_constraints = []  # type: List[TypeConstraint]
        for constr in self.constraints:
            # Rewrite the constraint in terms of canonical representatives.
            constr = constr.translate(self)

            if constr.is_trivial() or constr in new_constraints:
                continue

            # Sanity: translated constraints should refer to only real vars
            for arg in constr._args():
                if (not isinstance(arg, TypeVar)):
                    continue

                arg_free_tv = arg.free_typevar()
                assert arg_free_tv is None or arg_free_tv in vars_tvs

            new_constraints.append(constr)

        # Sanity: translated typemap should refer to only real vars
        for (k, v) in new_type_map.items():
            assert k in vars_tvs
            assert v.free_typevar() is None or v.free_typevar() in vars_tvs

        t = TypeEnv()
        t.type_map = new_type_map
        t.constraints = new_constraints
        # ranks and vars contain only TVs associated with real vars
        t.ranks = copy(self.ranks)
        t.vars = copy(self.vars)
        return t

    def concrete_typings(self):
        # type: () -> Iterable[VarTyping]
        """
        Return an iterable over all possible concrete typings permitted by
        this TypeEnv.
        """
        free_tvs = self.free_typevars()
        free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs]
        # Enumerate the cartesian product of all free TVs' concrete types.
        for concrete_types in product(*free_tv_iters):
            # Build type substitutions for all free vars
            m = {tv: TypeVar.singleton(typ)
                 for (tv, typ) in zip(free_tvs, concrete_types)}

            concrete_var_map = {v: subst(self[v.get_typevar()], m)
                                for v in self.vars}

            # Check if constraints are satisfied for this typing
            failed = None
            for constr in self.constraints:
                concrete_constr = constr.translate(m)
                if not concrete_constr.eval():
                    failed = concrete_constr
                    break

            if (failed is not None):
                continue

            yield concrete_var_map

    def permits(self, concrete_typing):
        # type: (VarTyping) -> bool
        """
        Return true iff this TypeEnv permits the (possibly partial) concrete
        variable type mapping concrete_typing.
        """
        # Each variable has a concrete type, that is a subset of its inferred
        # typeset.
        for (v, typ) in concrete_typing.items():
            assert typ.singleton_type() is not None
            if not typ.get_typeset().issubset(self[v].get_typeset()):
                return False

        m = {self[v]: typ for (v, typ) in concrete_typing.items()}

        # Constraints involving vars in concrete_typing are satisfied
        for constr in self.constraints:
            try:
                # If the constraint includes only vars in concrete_typing, we
                # can translate it using m. Otherwise we encounter a KeyError
                # and ignore it
                constr = constr.translate(m)
                if not constr.eval():
                    return False
            except KeyError:
                pass

        return True

    def dot(self):
        # type: () -> str
        """
        Return a representation of self as a graph in dot format.
            Nodes correspond to TypeVariables.
            Dotted edges correspond to equivalences between TVS
            Solid edges correspond to derivation relations between TVs.
            Dashed edges correspond to equivalence constraints.
        """
        def label(s):
            # type: (TypeVar) -> str
            return "\"" + str(s) + "\""

        # Add all registered TVs (as some of them may be singleton nodes not
        # appearing in the graph
        nodes = set()  # type: Set[TypeVar]
        edges = set()  # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa

        def add_nodes(*args):
            # type: (*TypeVar) -> None
            # Add each TV and the whole chain of bases it derives from.
            for tv in args:
                nodes.add(tv)
                while (tv.is_derived):
                    nodes.add(tv.base)
                    edges.add((tv, tv.base, "solid", "forward",
                               tv.derived_func))
                    tv = tv.base

        for v in self.vars:
            add_nodes(v.get_typevar())

        for (tv1, tv2) in self.type_map.items():
            # Add all intermediate TVs appearing in edges
            add_nodes(tv1, tv2)
            edges.add((tv1, tv2, "dotted", "forward", None))

        for constr in self.constraints:
            if isinstance(constr, TypesEqual):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
            elif isinstance(constr, WiderOrEq):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
            elif isinstance(constr, SameWidth):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "none",
                           "same_width"))
            else:
                assert False, "Can't display constraint {}".format(constr)

        root_nodes = set([x for x in nodes
                          if x not in self.type_map and not x.is_derived])

        r = "digraph {\n"
        for n in nodes:
            r += label(n)
            if n in root_nodes:
                # Roots additionally show their current typeset.
                r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
            r += ";\n"

        for (n1, n2, style, direction, elabel) in edges:
            e = label(n1) + "->" + label(n2)
            e += "[style={},dir={}".format(style, direction)

            if elabel is not None:
                e += ",label=\"{}\"".format(elabel)
            e += "];\n"

            r += e
        r += "}"

        return r
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # Static-analysis-only aliases: a typing failure is reported to callers
    # as a human-readable error string instead of a TypeEnv.
    TypingError = str
    TypingOrError = Union[TypeEnv, TypingError]
|
||||
|
||||
|
||||
def get_error(typing_or_err):
    # type: (TypingOrError) -> Optional[TypingError]
    """
    Return the error string from a typing result, or None on success.

    Helper function to appease mypy when checking the result of typing.
    """
    # Successful typing results are TypeEnv instances, not strings.
    if not isinstance(typing_or_err, str):
        return None

    if TYPE_CHECKING:
        # Dead at runtime (TYPE_CHECKING is False); narrows for mypy only.
        return cast(TypingError, typing_or_err)
    return typing_or_err
|
||||
|
||||
|
||||
def get_type_env(typing_or_err):
    # type: (TypingOrError) -> TypeEnv
    """
    Return the TypeEnv from a typing result, asserting typing succeeded.

    Helper function to appease mypy when checking the result of typing.
    """
    assert isinstance(typing_or_err, TypeEnv), \
        "Unexpected error: {}".format(typing_or_err)

    if TYPE_CHECKING:
        # Dead at runtime (TYPE_CHECKING is False); narrows for mypy only.
        return cast(TypeEnv, typing_or_err)
    return typing_or_err
|
||||
|
||||
|
||||
def subst(tv, tv_map):
    # type: (TypeVar, TypeMap) -> TypeVar
    """
    Perform substitution on the input tv using the TypeMap tv_map.
    """
    # Direct hit: the map supplies the replacement.
    if tv in tv_map:
        return tv_map[tv]

    # A free TV not mentioned in the map is left untouched.
    if not tv.is_derived:
        return tv

    # Derived TV: substitute in the base and re-apply the derivation.
    return TypeVar.derived(subst(tv.base, tv_map), tv.derived_func)
|
||||
|
||||
|
||||
def normalize_tv(tv):
    # type: (TypeVar) -> TypeVar
    """
    Normalize a (potentially derived) TV using the following rules:
        - vector and width derived functions commute
            {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
            {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))

        - half/double pairs collapse
            {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
            {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base
    """
    vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR]
    width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]

    # Free TVs are already in normal form.
    if not tv.is_derived:
        return tv

    df = tv.derived_func

    if (tv.base.is_derived):
        base_df = tv.base.derived_func

        # Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR,
        # DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR
        if df in vector_derives and base_df in width_derives:
            return normalize_tv(
                TypeVar.derived(
                    TypeVar.derived(tv.base.base, df), base_df))

        # Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
        # cancel each other. Note: This doesn't hide any over/underflows,
        # since we 1) assert the safety of each TV in the chain upon its
        # creation, and 2) the base typeset is only allowed to shrink.

        if (df, base_df) in \
                [(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
                 (TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
                 (TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
                 (TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH)]:
            return normalize_tv(tv.base.base)

    # No rule fired at this level; recurse into the base.
    return TypeVar.derived(normalize_tv(tv.base), df)
|
||||
|
||||
|
||||
def constrain_fixpoint(tv1, tv2):
    # type: (TypeVar, TypeVar) -> None
    """
    Given typevars tv1 and tv2 (which could be derived from one another)
    constrain their typesets to be the same. When one is derived from the
    other, repeat the constrain process until fixpoint.
    """
    # Constrain tv2's typeset as long as tv1's typeset is changing.
    while True:
        old_tv1_ts = tv1.get_typeset().copy()
        tv2.constrain_types(tv1)
        # Stop once tv1's typeset has stabilized.
        if tv1.get_typeset() == old_tv1_ts:
            break

    # One final pass the other way; it must not change tv2 any more,
    # otherwise the loop above didn't reach a true fixpoint.
    old_tv2_ts = tv2.get_typeset().copy()
    tv1.constrain_types(tv2)
    assert old_tv2_ts == tv2.get_typeset()
|
||||
|
||||
|
||||
def unify(tv1, tv2, typ):
    # type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError
    """
    Unify tv1 and tv2 in the current type environment typ, and return an
    updated type environment or error.
    """
    # Work with canonical, normalized representatives.
    tv1 = normalize_tv(typ[tv1])
    tv2 = normalize_tv(typ[tv2])

    # Already unified
    if tv1 == tv2:
        return typ

    # Always merge the lower-ranked TV into the higher-ranked one.
    if typ.rank(tv2) < typ.rank(tv1):
        return unify(tv2, tv1, typ)

    constrain_fixpoint(tv1, tv2)

    if (tv1.get_typeset().size() == 0 or tv2.get_typeset().size() == 0):
        return "Error: empty type created when unifying {} and {}"\
            .format(tv1, tv2)

    # Free -> Derived(Free)
    if not tv1.is_derived:
        typ.equivalent(tv1, tv2)
        return typ

    # If tv1's derivation is invertible, peel it off and unify the base
    # with the inverse-derived tv2 instead.
    if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
        inv_f = TypeVar.inverse_func(tv1.derived_func)
        return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)

    # Otherwise record the equality as an explicit constraint.
    typ.add_constraint(TypesEqual(tv1, tv2))
    return typ
|
||||
|
||||
|
||||
def move_first(l, i):
    # type: (List[T], int) -> List[T]
    """Return a copy of the list with the element at index i moved to
    the front; the relative order of the other elements is preserved."""
    rest = l[:i] + l[i+1:]
    return [l[i]] + rest
|
||||
|
||||
|
||||
def ti_def(definition, typ):
    # type: (Def, TypeEnv) -> TypingOrError
    """
    Perform type inference on one Def in the current type environment typ and
    return an updated type environment or error.

    At a high level this works by creating fresh copies of each formal type
    var in the Def's instruction's signature, and unifying the formal tv with
    the corresponding actual tv.
    """
    expr = definition.expr
    inst = expr.inst

    # Create a dict m mapping each free typevar in the signature of definition
    # to a fresh copy of itself.
    free_formal_tvs = inst.all_typevars()
    m = {tv: tv.get_fresh_copy(str(typ.get_uid())) for tv in free_formal_tvs}

    # Update m with any explicitly bound type vars
    for (idx, bound_typ) in enumerate(expr.typevars):
        m[free_formal_tvs[idx]] = TypeVar.singleton(bound_typ)

    # Get fresh copies for each typevar in the signature (both free and
    # derived)
    fresh_formal_tvs = \
        [subst(inst.outs[i].typevar, m) for i in inst.value_results] +\
        [subst(inst.ins[i].typevar, m) for i in inst.value_opnums]

    # Get the list of actual Vars
    actual_vars = []  # type: List[Expr]
    actual_vars += [definition.defs[i] for i in inst.value_results]
    actual_vars += [expr.args[i] for i in inst.value_opnums]

    # Get the list of the actual TypeVars
    actual_tvs = []
    for v in actual_vars:
        assert(isinstance(v, Var))
        # Register with TypeEnv that this typevar corresponds to variable v,
        # and thus has a given rank
        typ.register(v)
        actual_tvs.append(v.get_typevar())

    # Make sure we unify the control typevar first.
    if inst.is_polymorphic:
        idx = fresh_formal_tvs.index(m[inst.ctrl_typevar])
        fresh_formal_tvs = move_first(fresh_formal_tvs, idx)
        actual_tvs = move_first(actual_tvs, idx)

    # Unify each actual typevar with the corresponding fresh formal tv
    for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
        typ_or_err = unify(actual_tv, formal_tv, typ)
        err = get_error(typ_or_err)
        if (err):
            return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err

        typ = get_type_env(typ_or_err)

    # Add any instruction specific constraints
    for constr in inst.constraints:
        typ.add_constraint(constr.translate(m))

    return typ
|
||||
|
||||
|
||||
def ti_rtl(rtl, typ):
    # type: (Rtl, TypeEnv) -> TypingOrError
    """
    Perform type inference on an Rtl in a starting type env typ.

    Each Def is typed in order, threading the environment through. Returns
    the final TypeEnv, or an error string prefixed with the failing line.
    """
    for (line_no, defn) in enumerate(rtl.rtl):
        assert isinstance(defn, Def)
        typ_or_err = ti_def(defn, typ)
        err = get_error(typ_or_err)  # type: Optional[TypingError]
        if err:
            return "On line {}: ".format(line_no) + err

        typ = get_type_env(typ_or_err)

    return typ
|
||||
|
||||
|
||||
def ti_xform(xform, typ):
    # type: (XForm, TypeEnv) -> TypingOrError
    """
    Perform type inference on an XForm in a starting type env typ. Return an
    updated type environment or error.

    The source pattern is typed first; the resulting environment is then
    used to type the destination pattern, so dst sees all the TVs inferred
    for src.
    """
    typ_or_err = ti_rtl(xform.src, typ)
    err = get_error(typ_or_err)  # type: Optional[TypingError]
    if (err):
        return "In src pattern: " + err

    typ = get_type_env(typ_or_err)

    typ_or_err = ti_rtl(xform.dst, typ)
    err = get_error(typ_or_err)
    if (err):
        return "In dst pattern: " + err

    # Original code re-derived the TypeEnv from typ_or_err twice (a dead
    # store followed by a second get_type_env call); one extraction suffices.
    return get_type_env(typ_or_err)
|
||||
348
lib/codegen/meta-python/cdsl/types.py
Normal file
348
lib/codegen/meta-python/cdsl/types.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""Cranelift ValueType hierarchy"""
|
||||
from __future__ import absolute_import
|
||||
import math
|
||||
|
||||
try:
|
||||
from typing import Dict, List, cast, TYPE_CHECKING # noqa
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
pass
|
||||
|
||||
|
||||
# Numbering scheme for value types:
#
#    0: Void
# 0x01-0x6f: Special types
# 0x70-0x7f: Lane types
# 0x80-0xff: Vector types
#
# Vector types are encoded with the lane type in the low 4 bits and log2(lanes)
# in the high 4 bits, giving a range of 2-256 lanes.
# First number available to lane types (see LaneType.__init__).
LANE_BASE = 0x70
|
||||
|
||||
|
||||
# ValueType instances (i8, i32, ...) are provided in the `base.types` module.
|
||||
# ValueType instances (i8, i32, ...) are provided in the `base.types` module.
class ValueType(object):
    """
    A concrete SSA value type.

    All SSA values have a type that is described by an instance of `ValueType`
    or one of its subclasses.
    """

    # Map name -> ValueType. Populated as a side effect of every
    # ValueType construction; names must be globally unique.
    _registry = dict()  # type: Dict[str, ValueType]

    # List of all the lane types.
    all_lane_types = list()  # type: List[LaneType]

    # List of all the special types (neither lanes nor vectors).
    all_special_types = list()  # type: List[SpecialType]

    def __init__(self, name, membytes, doc):
        # type: (str, int, str) -> None
        self.name = name
        # Filled in by the subclass constructors.
        self.number = None  # type: int
        self.membytes = membytes
        self.__doc__ = doc
        assert name not in ValueType._registry
        ValueType._registry[name] = self

    def __str__(self):
        # type: () -> str
        return self.name

    def rust_name(self):
        # type: () -> str
        # Rust-side constant for this type, e.g. `ir::types::I32`.
        return 'ir::types::' + self.name.upper()

    @staticmethod
    def by_name(name):
        # type: (str) -> ValueType
        # Look up a previously constructed type; raises AttributeError if
        # no type with that name has been registered.
        if name in ValueType._registry:
            return ValueType._registry[name]
        else:
            raise AttributeError("No type named '{}'".format(name))

    def lane_bits(self):
        # type: () -> int
        """Return the number of bits in a lane."""
        assert False, "Abstract"

    def lane_count(self):
        # type: () -> int
        """Return the number of lanes."""
        assert False, "Abstract"

    def width(self):
        # type: () -> int
        """Return the total number of bits of an instance of this type."""
        return self.lane_count() * self.lane_bits()

    def wider_or_equal(self, other):
        # type: (ValueType) -> bool
        """
        Return true iff:
            1. self and other have equal number of lanes
            2. each lane in self has at least as many bits as a lane in other
        """
        return (self.lane_count() == other.lane_count() and
                self.lane_bits() >= other.lane_bits())
|
||||
|
||||
|
||||
class LaneType(ValueType):
    """
    A concrete scalar type that can appear as a vector lane too.

    Also tracks a unique set of :py:class:`VectorType` instances with this
    type as the lane type.
    """

    def __init__(self, name, membytes, doc):
        # type: (str, int, str) -> None
        super(LaneType, self).__init__(name, membytes, doc)
        # Cache of vector types keyed by lane count; filled lazily by by().
        self._vectors = dict()  # type: Dict[int, VectorType]
        # Assign numbers starting from LANE_BASE.
        n = len(ValueType.all_lane_types)
        ValueType.all_lane_types.append(self)
        # The vector encoding keeps the lane type in 4 bits (see the
        # numbering scheme comment above).
        assert n < 16, 'Too many lane types'
        self.number = LANE_BASE + n

    def __repr__(self):
        # type: () -> str
        return 'LaneType({})'.format(self.name)

    def by(self, lanes):
        # type: (int) -> VectorType
        """
        Get a vector type with this type as the lane type.

        For example, ``i32.by(4)`` returns the :obj:`i32x4` type.
        """
        # Memoized so each (lane type, lane count) pair yields one object.
        if lanes in self._vectors:
            return self._vectors[lanes]
        else:
            v = VectorType(self, lanes)
            self._vectors[lanes] = v
            return v

    def lane_count(self):
        # type: () -> int
        """Return the number of lanes."""
        return 1
|
||||
|
||||
|
||||
class VectorType(ValueType):
    """
    A concrete SIMD vector type.

    A vector type has a lane type which is an instance of :class:`LaneType`,
    and a positive number of lanes.
    """

    def __init__(self, base, lanes):
        # type: (LaneType, int) -> None
        super(VectorType, self).__init__(
                name='{}x{}'.format(base.name, lanes),
                membytes=lanes*base.membytes,
                doc="""
                A SIMD vector with {} lanes containing a `{}` each.
                """.format(lanes, base.name))
        assert lanes <= 256, "Too many lanes"
        self.base = base
        self.lanes = lanes
        # Encoding per the module's numbering scheme: log2(lanes) in the
        # high 4 bits, lane type number in the low bits.
        # NOTE(review): assumes `lanes` is a power of two so that
        # int(math.log(lanes, 2)) is exact — confirm at call sites.
        self.number = 16*int(math.log(lanes, 2)) + base.number

    def __repr__(self):
        # type: () -> str
        return ('VectorType(base={}, lanes={})'
                .format(self.base.name, self.lanes))

    def lane_count(self):
        # type: () -> int
        """Return the number of lanes."""
        return self.lanes

    def lane_bits(self):
        # type: () -> int
        """Return the number of bits in a lane."""
        return self.base.lane_bits()
|
||||
|
||||
|
||||
class SpecialType(ValueType):
    """
    A concrete scalar type that is neither a vector nor a lane type.

    Special types cannot be used to form vectors.
    """

    def __init__(self, name, membytes, doc):
        # type: (str, int, str) -> None
        super(SpecialType, self).__init__(name, membytes, doc)
        # Assign numbers starting from 1. (0 is VOID)
        ValueType.all_special_types.append(self)
        self.number = len(ValueType.all_special_types)
        # Special types must stay below the lane-type number range.
        assert self.number < LANE_BASE, 'Too many special types'

    def __repr__(self):
        # type: () -> str
        return 'SpecialType({})'.format(self.name)

    def lane_count(self):
        # type: () -> int
        """Return the number of lanes."""
        return 1
|
||||
|
||||
|
||||
class IntType(LaneType):
    """A concrete scalar integer type."""

    def __init__(self, bits):
        # type: (int) -> None
        assert bits > 0, 'IntType must have positive number of bits'
        # Narrow integer arithmetic is still a work in progress; surface
        # that in the generated documentation.
        if bits < 32:
            warning = ("\nWARNING: "
                       "arithmetic on {}bit integers is incomplete"
                       .format(bits))
        else:
            warning = ""
        super(IntType, self).__init__(
                name='i{:d}'.format(bits),
                membytes=bits // 8,
                doc="An integer type with {} bits.{}".format(bits, warning))
        self.bits = bits

    def __repr__(self):
        # type: () -> str
        return 'IntType(bits={})'.format(self.bits)

    @staticmethod
    def with_bits(bits):
        # type: (int) -> IntType
        """Look up the previously created integer type with this width."""
        typ = ValueType.by_name('i{:d}'.format(bits))
        return cast(IntType, typ) if TYPE_CHECKING else typ

    def lane_bits(self):
        # type: () -> int
        """Return the number of bits in a lane."""
        return self.bits
|
||||
|
||||
|
||||
class FloatType(LaneType):
    """A concrete scalar floating point type."""

    def __init__(self, bits, doc):
        # type: (int, str) -> None
        assert bits > 0, 'FloatType must have positive number of bits'
        super(FloatType, self).__init__(
                name='f{:d}'.format(bits), membytes=bits // 8, doc=doc)
        self.bits = bits

    def __repr__(self):
        # type: () -> str
        return 'FloatType(bits={})'.format(self.bits)

    @staticmethod
    def with_bits(bits):
        # type: (int) -> FloatType
        """Look up the previously created float type with this width."""
        typ = ValueType.by_name('f{:d}'.format(bits))
        return cast(FloatType, typ) if TYPE_CHECKING else typ

    def lane_bits(self):
        # type: () -> int
        """Return the number of bits in a lane."""
        return self.bits
|
||||
|
||||
|
||||
class BoolType(LaneType):
    """A concrete scalar boolean type."""

    def __init__(self, bits):
        # type: (int) -> None
        assert bits > 0, 'BoolType must have positive number of bits'
        super(BoolType, self).__init__(
                name='b{:d}'.format(bits),
                membytes=bits // 8,
                doc="A boolean type with {} bits.".format(bits))
        self.bits = bits

    def __repr__(self):
        # type: () -> str
        return 'BoolType(bits={})'.format(self.bits)

    @staticmethod
    def with_bits(bits):
        # type: (int) -> BoolType
        """Look up the previously created boolean type with this width."""
        typ = ValueType.by_name('b{:d}'.format(bits))
        return cast(BoolType, typ) if TYPE_CHECKING else typ

    def lane_bits(self):
        # type: () -> int
        """Return the number of bits in a lane."""
        return self.bits
|
||||
|
||||
|
||||
class FlagsType(SpecialType):
    """
    A type representing CPU flags.

    Flags can't be stored in memory.
    """

    def __init__(self, name, doc):
        # type: (str, str) -> None
        # membytes=0: flags are register-only and never stored to memory.
        membytes = 0
        super(FlagsType, self).__init__(name, membytes, doc)

    def __repr__(self):
        # type: () -> str
        return 'FlagsType({})'.format(self.name)
|
||||
|
||||
|
||||
class BVType(ValueType):
    """A flat bitvector type. Used for semantics description only."""

    def __init__(self, bits):
        # type: (int) -> None
        """Create a `bvN` type that is `bits` wide."""
        assert bits > 0, 'Must have positive number of bits'
        type_name = 'bv{:d}'.format(bits)
        super(BVType, self).__init__(
                name=type_name,
                membytes=bits // 8,
                doc="A bitvector type with {} bits.".format(bits))
        self.bits = bits

    def __repr__(self):
        # type: () -> str
        return 'BVType(bits={})'.format(self.bits)

    @staticmethod
    def with_bits(bits):
        # type: (int) -> BVType
        """
        Return the bitvector type of the given width, creating and
        registering it on first use.
        """
        type_name = 'bv{:d}'.format(bits)
        # Unlike the other scalar types, bitvectors are created lazily.
        if type_name not in ValueType._registry:
            return BVType(bits)

        typ = ValueType.by_name(type_name)
        if not TYPE_CHECKING:
            return typ
        return cast(BVType, typ)

    def lane_bits(self):
        # type: () -> int
        """Return the number of bits in a lane."""
        return self.bits

    def lane_count(self):
        # type: () -> int
        """Return the number of lane. For BVtypes always 1."""
        return 1
|
||||
906
lib/codegen/meta-python/cdsl/typevar.py
Normal file
906
lib/codegen/meta-python/cdsl/typevar.py
Normal file
@@ -0,0 +1,906 @@
|
||||
"""
|
||||
Type variables for Parametric polymorphism.
|
||||
|
||||
Cranelift instructions and instruction transformations can be specified to be
|
||||
polymorphic by using type variables.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
import math
|
||||
from . import types, is_power_of_two
|
||||
from copy import copy
|
||||
|
||||
try:
|
||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from srcgen import Formatter # noqa
|
||||
Interval = Tuple[int, int]
|
||||
# An Interval where `True` means 'everything'
|
||||
BoolInterval = Union[bool, Interval]
|
||||
# Set of special types: None, False, True, or iterable.
|
||||
SpecialSpec = Union[bool, Iterable[types.SpecialType]]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Upper bound (inclusive) on the number of SIMD lanes in a vector type.
MAX_LANES = 256
# Upper bound (inclusive) on the bit width of a scalar lane type.
MAX_BITS = 64
# Widest representable flat bitvector: every lane at maximal width.
MAX_BITVEC = MAX_BITS * MAX_LANES
|
||||
|
||||
|
||||
def int_log2(x):
    # type: (int) -> int
    """
    Return the binary logarithm of `x` as an integer (floor(log2(x))).

    Callers pass powers of two (lane counts and bit widths). Computed with
    integer arithmetic (`bit_length`) rather than `math.log(x, 2)`, whose
    float rounding can be off by one for large inputs.

    :param x: A positive integer.
    """
    assert x > 0, 'int_log2 requires a positive integer'
    return x.bit_length() - 1
|
||||
|
||||
|
||||
def intersect(a, b):
    # type: (Interval, Interval) -> Interval
    """
    Given two `(min, max)` inclusive intervals, compute their intersection.

    Use `(None, None)` to represent the empty interval on input and output.
    """
    # Intersecting with the empty interval yields the empty interval.
    if a[0] is None or b[0] is None:
        return (None, None)
    lo = max(a[0], b[0])
    hi = min(a[1], b[1])
    assert lo is not None and hi is not None
    # Disjoint intervals intersect to the empty interval.
    return (lo, hi) if lo <= hi else (None, None)
|
||||
|
||||
|
||||
def is_empty(intv):
    # type: (Interval) -> bool
    """Return True iff `intv` denotes an empty interval specification."""
    # Identity checks: `None` and `False` both spell "no interval".
    # (Identity, not equality, so 0 does not count as False.)
    if intv is None or intv is False:
        return True
    return intv == (None, None)
|
||||
|
||||
|
||||
def encode_bitset(vals, size):
    # type: (Iterable[int], int) -> int
    """
    Encode a set of values (each between 0 and size) as a bitset of width size.
    """
    assert is_power_of_two(size) and size <= 64
    result = 0
    for val in vals:
        assert 0 <= val < size
        result |= 1 << val
    return result
|
||||
|
||||
|
||||
def pp_set(s):
    # type: (Iterable[Any]) -> str
    """
    Return a consistent string representation of a set (ordering is fixed)
    """
    # Sort first so the output is deterministic across runs.
    inner = ', '.join(repr(item) for item in sorted(s))
    return '{' + inner + '}'
|
||||
|
||||
|
||||
def decode_interval(intv, full_range, default=None):
    # type: (BoolInterval, Interval, int) -> Interval
    """
    Decode an interval specification which can take the following values:

    True
        Use the `full_range`.
    `False` or `None`
        An empty interval
    (lo, hi)
        An explicit interval
    """
    if isinstance(intv, tuple):
        # mypy bug here: 'builtins.None' object is not iterable
        lo, hi = intv
        # Explicit bounds must be powers of two inside `full_range`.
        assert is_power_of_two(lo) and is_power_of_two(hi)
        assert full_range[0] <= lo <= hi <= full_range[1]
        return intv

    return full_range if intv else (default, default)
|
||||
|
||||
|
||||
def interval_to_set(intv):
    # type: (Interval) -> Set
    """Expand a `(lo, hi)` power-of-two interval into the set it denotes."""
    if is_empty(intv):
        return set()

    lo, hi = intv
    assert is_power_of_two(lo) and is_power_of_two(hi) and lo <= hi
    # All powers of two from lo up to and including hi.
    return set(2 ** exp for exp in range(int_log2(lo), int_log2(hi) + 1))
|
||||
|
||||
|
||||
def legal_bool(bits):
    # type: (int) -> bool
    """
    True iff bits is a legal bit width for a bool type.
    bits == 1 || bits \in { 8, 16, .. MAX_BITS }
    """
    if bits == 1:
        return True
    return 8 <= bits <= MAX_BITS and is_power_of_two(bits)
|
||||
|
||||
|
||||
class TypeSet(object):
    """
    A set of types.

    We don't allow arbitrary subsets of types, but use a parametrized approach
    instead.

    Objects of this class can be used as dictionary keys.

    Parametrized type sets are specified in terms of ranges:

    - The permitted range of vector lanes, where 1 indicates a scalar type.
    - The permitted range of integer types.
    - The permitted range of floating point types, and
    - The permitted range of boolean types.

    The ranges are inclusive from smallest bit-width to largest bit-width.

    A typeset representing scalar integer types `i8` through `i32`:

    >>> TypeSet(ints=(8, 32))
    TypeSet(lanes={1}, ints={8, 16, 32})

    Passing `True` instead of a range selects all available scalar types:

    >>> TypeSet(ints=True)
    TypeSet(lanes={1}, ints={8, 16, 32, 64})
    >>> TypeSet(floats=True)
    TypeSet(lanes={1}, floats={32, 64})
    >>> TypeSet(bools=True)
    TypeSet(lanes={1}, bools={1, 8, 16, 32, 64})

    Similarly, passing `True` for the lanes selects all possible scalar and
    vector types:

    >>> TypeSet(lanes=True, ints=True)
    TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={8, 16, 32, 64})

    Finally, a type set can contain special types (derived from `SpecialType`)
    which can't appear as lane types.

    :param lanes: `(min, max)` inclusive range of permitted vector lane counts.
    :param ints: `(min, max)` inclusive range of permitted scalar integer
                 widths.
    :param floats: `(min, max)` inclusive range of permitted scalar floating
                   point widths.
    :param bools: `(min, max)` inclusive range of permitted scalar boolean
                  widths.
    :param bitvecs: `(min, max)` inclusive range of permitted bitvector
                    widths.
    :param specials: Sequence of special types to appear in the set.
    """

    def __init__(
            self,
            lanes=None,     # type: BoolInterval
            ints=None,      # type: BoolInterval
            floats=None,    # type: BoolInterval
            bools=None,     # type: BoolInterval
            bitvecs=None,   # type: BoolInterval
            specials=None   # type: SpecialSpec
            ):
        # type: (...) -> None
        self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
        self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
        self.floats = interval_to_set(decode_interval(floats, (32, 64)))
        self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
        # Only a subset of bool widths are legal (1, or a power of two
        # from 8 to MAX_BITS); drop the rest.
        self.bools = set(filter(legal_bool, self.bools))
        self.bitvecs = interval_to_set(decode_interval(bitvecs,
                                                       (1, MAX_BITVEC)))
        # Allow specials=None, specials=True, specials=(...)
        self.specials = set()  # type: Set[types.SpecialType]
        if isinstance(specials, bool):
            if specials:
                self.specials = set(types.ValueType.all_special_types)
        elif specials:
            self.specials = set(specials)

    def copy(self):
        # type: (TypeSet) -> TypeSet
        """
        Return a copy of our self.
        """
        n = TypeSet()
        n.lanes = copy(self.lanes)
        n.ints = copy(self.ints)
        n.floats = copy(self.floats)
        n.bools = copy(self.bools)
        n.bitvecs = copy(self.bitvecs)
        n.specials = copy(self.specials)
        return n

    def typeset_key(self):
        # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple]
        """Key tuple used for hashing and equality."""
        return (tuple(sorted(list(self.lanes))),
                tuple(sorted(list(self.ints))),
                tuple(sorted(list(self.floats))),
                tuple(sorted(list(self.bools))),
                tuple(sorted(list(self.bitvecs))),
                tuple(sorted(s.name for s in self.specials)))

    def __hash__(self):
        # type: () -> int
        h = hash(self.typeset_key())
        # A TypeSet used as a dict key must not be mutated afterwards;
        # remember the first hash and assert it never changes.
        assert h == getattr(self, 'prev_hash', h), "TypeSet changed"
        self.prev_hash = h
        return h

    def __eq__(self, other):
        # type: (object) -> bool
        if isinstance(other, TypeSet):
            return self.typeset_key() == other.typeset_key()
        else:
            return False

    def __ne__(self, other):
        # type: (object) -> bool
        return not self.__eq__(other)

    def __repr__(self):
        # type: () -> str
        # Empty component sets are omitted from the representation.
        s = 'TypeSet(lanes={}'.format(pp_set(self.lanes))
        if len(self.ints) > 0:
            s += ', ints={}'.format(pp_set(self.ints))
        if len(self.floats) > 0:
            s += ', floats={}'.format(pp_set(self.floats))
        if len(self.bools) > 0:
            s += ', bools={}'.format(pp_set(self.bools))
        if len(self.bitvecs) > 0:
            s += ', bitvecs={}'.format(pp_set(self.bitvecs))
        if len(self.specials) > 0:
            s += ', specials=[{}]'.format(pp_set(self.specials))
        return s + ')'

    def emit_fields(self, fmt):
        # type: (Formatter) -> None
        """Emit field initializers for this typeset."""
        assert len(self.bitvecs) == 0, "Bitvector types are not emitable."
        fmt.comment(repr(self))

        # (field name, bitset width in the generated Rust code)
        fields = (('lanes', 16),
                  ('ints', 8),
                  ('floats', 8),
                  ('bools', 8))

        for (field, bits) in fields:
            vals = [int_log2(x) for x in getattr(self, field)]
            fmt.line('{}: BitSet::<u{}>({}),'
                     .format(field, bits, encode_bitset(vals, bits)))

    def __iand__(self, other):
        # type: (TypeSet) -> TypeSet
        """
        Intersect self with other type set.

        >>> a = TypeSet(lanes=True, ints=(16, 32))
        >>> a
        TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={16, 32})
        >>> b = TypeSet(lanes=(4, 16), ints=True)
        >>> a &= b
        >>> a
        TypeSet(lanes={4, 8, 16}, ints={16, 32})

        >>> a = TypeSet(lanes=True, bools=(1, 8))
        >>> b = TypeSet(lanes=True, bools=(16, 32))
        >>> a &= b
        >>> a
        TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256})
        """
        self.lanes.intersection_update(other.lanes)
        self.ints.intersection_update(other.ints)
        self.floats.intersection_update(other.floats)
        self.bools.intersection_update(other.bools)
        self.bitvecs.intersection_update(other.bitvecs)
        self.specials.intersection_update(other.specials)

        return self

    def issubset(self, other):
        # type: (TypeSet) -> bool
        """
        Return true iff self is a subset of other
        """
        return self.lanes.issubset(other.lanes) and \
            self.ints.issubset(other.ints) and \
            self.floats.issubset(other.floats) and \
            self.bools.issubset(other.bools) and \
            self.bitvecs.issubset(other.bitvecs) and \
            self.specials.issubset(other.specials)

    def lane_of(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across lane_of
        """
        new = self.copy()
        new.lanes = set([1])
        new.bitvecs = set()
        return new

    def as_bool(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across as_bool
        """
        new = self.copy()
        new.ints = set()
        new.floats = set()
        new.bitvecs = set()

        # Vector types map lane-for-lane: a boolean of each scalar width.
        if len(self.lanes.difference(set([1]))) > 0:
            new.bools = self.ints.union(self.floats).union(self.bools)

        # Scalars always map to b1.
        if 1 in self.lanes:
            new.bools.add(1)
        return new

    def half_width(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across halfwidth
        """
        new = self.copy()
        # Widths at the minimum for their kind have no half-width image.
        new.ints = set([x//2 for x in self.ints if x > 8])
        new.floats = set([x//2 for x in self.floats if x > 32])
        new.bools = set([x//2 for x in self.bools if x > 8])
        new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
        new.specials = set()

        return new

    def double_width(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across doublewidth
        """
        new = self.copy()
        new.ints = set([x*2 for x in self.ints if x < MAX_BITS])
        new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
        new.bools = set(filter(legal_bool,
                               set([x*2 for x in self.bools if x < MAX_BITS])))
        new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
        new.specials = set()

        return new

    def half_vector(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across halfvector
        """
        new = self.copy()
        new.bitvecs = set()
        # Scalars (1 lane) cannot be halved further.
        new.lanes = set([x//2 for x in self.lanes if x > 1])
        new.specials = set()

        return new

    def double_vector(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across doublevector
        """
        new = self.copy()
        new.bitvecs = set()
        new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
        new.specials = set()

        return new

    def to_bitvec(self):
        # type: () -> TypeSet
        """
        Return a TypeSet describing the image of self across to_bitvec
        """
        assert len(self.bitvecs) == 0
        all_scalars = self.ints.union(self.floats.union(self.bools))

        new = self.copy()
        new.lanes = set([1])
        new.ints = set()
        new.bools = set()
        new.floats = set()
        # Each (scalar width, lane count) pair flattens to one bitvector
        # of the combined width.
        new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
                           for nlanes in self.lanes])
        new.specials = set()

        return new

    def image(self, func):
        # type: (str) -> TypeSet
        """
        Return the image of self across the derived function func
        """
        if (func == TypeVar.LANEOF):
            return self.lane_of()
        elif (func == TypeVar.ASBOOL):
            return self.as_bool()
        elif (func == TypeVar.HALFWIDTH):
            return self.half_width()
        elif (func == TypeVar.DOUBLEWIDTH):
            return self.double_width()
        elif (func == TypeVar.HALFVECTOR):
            return self.half_vector()
        elif (func == TypeVar.DOUBLEVECTOR):
            return self.double_vector()
        elif (func == TypeVar.TOBITVEC):
            return self.to_bitvec()
        else:
            assert False, "Unknown derived function: " + func

    def preimage(self, func):
        # type: (str) -> TypeSet
        """
        Return the inverse image of self across the derived function func
        """
        # The inverse of the empty set is always empty
        if (self.size() == 0):
            return self

        if (func == TypeVar.LANEOF):
            new = self.copy()
            new.bitvecs = set()
            # Any lane count could have produced these scalar lanes.
            new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
            return new
        elif (func == TypeVar.ASBOOL):
            new = self.copy()
            new.bitvecs = set()

            if 1 not in self.bools:
                new.ints = self.bools.difference(set([1]))
                new.floats = self.bools.intersection(set([32, 64]))
                # If b1 is not in our typeset, then lanes=1 cannot be in the
                # pre-image, as as_bool() of scalars is always b1.
                new.lanes = self.lanes.difference(set([1]))
            else:
                # b1 present: any scalar kind/width maps onto it.
                new.ints = set([2**x for x in range(3, 7)])
                new.floats = set([32, 64])

            return new
        elif (func == TypeVar.HALFWIDTH):
            # half_width and double_width are mutual inverses.
            return self.double_width()
        elif (func == TypeVar.DOUBLEWIDTH):
            return self.half_width()
        elif (func == TypeVar.HALFVECTOR):
            # half_vector and double_vector are mutual inverses.
            return self.double_vector()
        elif (func == TypeVar.DOUBLEVECTOR):
            return self.half_vector()
        elif (func == TypeVar.TOBITVEC):
            new = TypeSet()

            # Start with all possible lanes/ints/floats/bools
            lanes = interval_to_set(decode_interval(True, (1, MAX_LANES), 1))
            ints = interval_to_set(decode_interval(True, (8, MAX_BITS)))
            floats = interval_to_set(decode_interval(True, (32, 64)))
            bools = interval_to_set(decode_interval(True, (1, MAX_BITS)))

            # See which combinations have a size that appears in self.bitvecs
            has_t = set()  # type: Set[Tuple[str, int, int]]
            for l in lanes:
                for i in ints:
                    if i * l in self.bitvecs:
                        has_t.add(('i', i, l))
                for i in bools:
                    if i * l in self.bitvecs:
                        has_t.add(('b', i, l))
                for i in floats:
                    if i * l in self.bitvecs:
                        has_t.add(('f', i, l))

            for (t, width, lane) in has_t:
                new.lanes.add(lane)
                if (t == 'i'):
                    new.ints.add(width)
                elif (t == 'b'):
                    new.bools.add(width)
                else:
                    assert t == 'f'
                    new.floats.add(width)

            return new
        else:
            assert False, "Unknown derived function: " + func

    def size(self):
        # type: () -> int
        """
        Return the number of concrete types represented by this typeset
        """
        # Every lane count combines with every lane type; specials are
        # standalone and add linearly.
        return (len(self.lanes) * (len(self.ints) + len(self.floats) +
                                   len(self.bools) + len(self.bitvecs)) +
                len(self.specials))

    def concrete_types(self):
        # type: () -> Iterable[types.ValueType]
        """Generate every concrete ValueType this typeset contains."""
        def by(scalar, lanes):
            # type: (types.LaneType, int) -> types.ValueType
            # Wrap a scalar into a vector type unless lanes == 1.
            if (lanes == 1):
                return scalar
            else:
                return scalar.by(lanes)

        for nlanes in self.lanes:
            for bits in self.ints:
                yield by(types.IntType.with_bits(bits), nlanes)
            for bits in self.floats:
                yield by(types.FloatType.with_bits(bits), nlanes)
            for bits in self.bools:
                yield by(types.BoolType.with_bits(bits), nlanes)
            for bits in self.bitvecs:
                # Bitvectors are always flat (scalar).
                assert nlanes == 1
                yield types.BVType.with_bits(bits)

        for spec in self.specials:
            yield spec

    def get_singleton(self):
        # type: () -> types.ValueType
        """
        Return the singleton type represented by self. Can only call on
        typesets containing 1 type.
        """
        # NOTE: this local deliberately shadows the module-level `types`
        # import for the remainder of the method.
        types = list(self.concrete_types())
        assert len(types) == 1
        return types[0]

    def widths(self):
        # type: () -> Set[int]
        """ Return a set of the widths of all possible types in self"""
        scalar_w = self.ints.union(self.floats.union(self.bools))
        scalar_w = scalar_w.union(self.bitvecs)
        return set(w * l for l in self.lanes for w in scalar_w)
|
||||
|
||||
|
||||
class TypeVar(object):
    """
    Type variables can be used in place of concrete types when defining
    instructions. This makes the instructions *polymorphic*.

    A type variable is restricted to vary over a subset of the value types.
    This subset is specified by a set of flags that control the permitted base
    types and whether the type variable can assume scalar or vector types, or
    both.

    :param name: Short name of type variable used in instruction descriptions.
    :param doc: Documentation string.
    :param ints: Allow all integer base types, or `(min, max)` bit-range.
    :param floats: Allow all floating point base types, or `(min, max)`
                   bit-range.
    :param bools: Allow all boolean base types, or `(min, max)` bit-range.
    :param scalars: Allow type variable to assume scalar types.
    :param simd: Allow type variable to assume vector types, or `(min, max)`
                 lane count range.
    :param bitvecs: Allow all BitVec base types, or `(min, max)` bit-range.
    """

    def __init__(
            self,
            name,               # type: str
            doc,                # type: str
            ints=False,         # type: BoolInterval
            floats=False,       # type: BoolInterval
            bools=False,        # type: BoolInterval
            scalars=True,       # type: bool
            simd=False,         # type: BoolInterval
            bitvecs=False,      # type: BoolInterval
            base=None,          # type: TypeVar
            derived_func=None,  # type: str
            specials=None       # type: SpecialSpec
            ):
        # type: (...) -> None
        self.name = name
        self.__doc__ = doc
        self.is_derived = isinstance(base, TypeVar)
        if base:
            # Derived type variable: its typeset is computed on demand from
            # the base's typeset (see get_typeset), so none is stored here.
            assert self.is_derived
            assert derived_func
            self.base = base
            self.derived_func = derived_func
            self.name = '{}({})'.format(derived_func, base.name)
        else:
            # Free type variable: build the explicit typeset.
            min_lanes = 1 if scalars else 2
            lanes = decode_interval(simd, (min_lanes, MAX_LANES), 1)
            self.type_set = TypeSet(
                    lanes=lanes,
                    ints=ints,
                    floats=floats,
                    bools=bools,
                    bitvecs=bitvecs,
                    specials=specials)

    @staticmethod
    def singleton(typ):
        # type: (types.ValueType) -> TypeVar
        """Create a type variable that can only assume a single type."""
        scalar = None  # type: types.ValueType
        if isinstance(typ, types.VectorType):
            scalar = typ.base
            lanes = (typ.lanes, typ.lanes)
        elif isinstance(typ, types.LaneType):
            scalar = typ
            lanes = (1, 1)
        elif isinstance(typ, types.SpecialType):
            return TypeVar(typ.name, typ.__doc__, specials=[typ])
        else:
            assert isinstance(typ, types.BVType)
            scalar = typ
            lanes = (1, 1)

        ints = None
        floats = None
        bools = None
        bitvecs = None

        # Exactly one of the ranges below becomes a degenerate
        # single-width interval matching the scalar's kind.
        if isinstance(scalar, types.IntType):
            ints = (scalar.bits, scalar.bits)
        elif isinstance(scalar, types.FloatType):
            floats = (scalar.bits, scalar.bits)
        elif isinstance(scalar, types.BoolType):
            bools = (scalar.bits, scalar.bits)
        elif isinstance(scalar, types.BVType):
            bitvecs = (scalar.bits, scalar.bits)

        tv = TypeVar(
                typ.name, typ.__doc__,
                ints=ints, floats=floats, bools=bools,
                bitvecs=bitvecs, simd=lanes)
        return tv

    def __str__(self):
        # type: () -> str
        return "`{}`".format(self.name)

    def __repr__(self):
        # type: () -> str
        if self.is_derived:
            return (
                    'TypeVar({}, base={}, derived_func={})'
                    .format(self.name, self.base, self.derived_func))
        else:
            return (
                    'TypeVar({}, {})'
                    .format(self.name, self.type_set))

    def __hash__(self):
        # type: () -> int
        # Free typevars hash by identity; derived ones hash structurally
        # so equal derivations of the same base collide (matches __eq__).
        if (not self.is_derived):
            return object.__hash__(self)

        return hash((self.derived_func, self.base))

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, TypeVar):
            return False
        if self.is_derived and other.is_derived:
            return (
                self.derived_func == other.derived_func and
                self.base == other.base)
        else:
            return self is other

    def __ne__(self, other):
        # type: (object) -> bool
        return not self.__eq__(other)

    # Supported functions for derived type variables.
    # The names here must match the method names on `ir::types::Type`.
    # The camel_case of the names must match `enum OperandConstraint` in
    # `instructions.rs`.
    LANEOF = 'lane_of'
    ASBOOL = 'as_bool'
    HALFWIDTH = 'half_width'
    DOUBLEWIDTH = 'double_width'
    HALFVECTOR = 'half_vector'
    DOUBLEVECTOR = 'double_vector'
    TOBITVEC = 'to_bitvec'

    @staticmethod
    def is_bijection(func):
        # type: (str) -> bool
        """Whether `func` has an exact inverse (see inverse_func)."""
        return func in [
            TypeVar.HALFWIDTH,
            TypeVar.DOUBLEWIDTH,
            TypeVar.HALFVECTOR,
            TypeVar.DOUBLEVECTOR]

    @staticmethod
    def inverse_func(func):
        # type: (str) -> str
        """Return the inverse of a bijective derived function."""
        return {
            TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
            TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
            TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
            TypeVar.DOUBLEVECTOR: TypeVar.HALFVECTOR
        }[func]

    @staticmethod
    def derived(base, derived_func):
        # type: (TypeVar, str) -> TypeVar
        """Create a type variable that is a function of another."""

        # Safety checks to avoid over/underflows.
        ts = base.get_typeset()

        assert len(ts.specials) == 0, "Can't derive from special types"

        if derived_func == TypeVar.HALFWIDTH:
            if len(ts.ints) > 0:
                assert min(ts.ints) > 8, "Can't halve all integer types"
            if len(ts.floats) > 0:
                assert min(ts.floats) > 32, "Can't halve all float types"
            if len(ts.bools) > 0:
                assert min(ts.bools) > 8, "Can't halve all boolean types"
        elif derived_func == TypeVar.DOUBLEWIDTH:
            if len(ts.ints) > 0:
                assert max(ts.ints) < MAX_BITS,\
                    "Can't double all integer types."
            if len(ts.floats) > 0:
                assert max(ts.floats) < MAX_BITS,\
                    "Can't double all float types."
            if len(ts.bools) > 0:
                assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
        elif derived_func == TypeVar.HALFVECTOR:
            assert min(ts.lanes) > 1, "Can't halve a scalar type"
        elif derived_func == TypeVar.DOUBLEVECTOR:
            assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."

        return TypeVar(None, None, base=base, derived_func=derived_func)

    @staticmethod
    def from_typeset(ts):
        # type: (TypeSet) -> TypeVar
        """ Create a type variable from a type set."""
        tv = TypeVar(None, None)
        tv.type_set = ts
        return tv

    def lane_of(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that is the scalar lane type of this
        type variable.

        When this type variable assumes a scalar type, the derived type will be
        the same scalar type.
        """
        return TypeVar.derived(self, self.LANEOF)

    def as_bool(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that has the same vector geometry as
        this type variable, but with boolean lanes. Scalar types map to `b1`.
        """
        return TypeVar.derived(self, self.ASBOOL)

    def half_width(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that has the same number of vector lanes
        as this one, but the lanes are half the width.
        """
        return TypeVar.derived(self, self.HALFWIDTH)

    def double_width(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that has the same number of vector lanes
        as this one, but the lanes are double the width.
        """
        return TypeVar.derived(self, self.DOUBLEWIDTH)

    def half_vector(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that has half the number of vector lanes
        as this one, with the same lane type.
        """
        return TypeVar.derived(self, self.HALFVECTOR)

    def double_vector(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that has twice the number of vector
        lanes as this one, with the same lane type.
        """
        return TypeVar.derived(self, self.DOUBLEVECTOR)

    def to_bitvec(self):
        # type: () -> TypeVar
        """
        Return a derived type variable that represent a flat bitvector with
        the same size as self
        """
        return TypeVar.derived(self, self.TOBITVEC)

    def singleton_type(self):
        # type: () -> types.ValueType
        """
        If the associated typeset has a single type return it. Otherwise return
        None
        """
        ts = self.get_typeset()
        if ts.size() != 1:
            return None

        return ts.get_singleton()

    def free_typevar(self):
        # type: () -> TypeVar
        """
        Get the free type variable controlling this one.
        """
        if self.is_derived:
            return self.base.free_typevar()
        elif self.singleton_type() is not None:
            # A singleton type variable is not a proper free variable.
            return None
        else:
            return self

    def rust_expr(self):
        # type: () -> str
        """
        Get a Rust expression that computes the type of this type variable.
        """
        if self.is_derived:
            return '{}.{}()'.format(
                    self.base.rust_expr(), self.derived_func)
        elif self.singleton_type():
            return self.singleton_type().rust_name()
        else:
            return self.name

    def constrain_types_by_ts(self, ts):
        # type: (TypeSet) -> None
        """
        Constrain the range of types this variable can assume to a subset of
        those in the typeset ts.
        """
        if not self.is_derived:
            self.type_set &= ts
        else:
            # Constrain the base by the pre-image of ts under our function.
            self.base.constrain_types_by_ts(ts.preimage(self.derived_func))

    def constrain_types(self, other):
        # type: (TypeVar) -> None
        """
        Constrain the range of types this variable can assume to a subset of
        those `other` can assume.
        """
        if self is other:
            return

        self.constrain_types_by_ts(other.get_typeset())

    def get_typeset(self):
        # type: () -> TypeSet
        """
        Returns the typeset for this TV. If the TV is derived, computes it
        recursively from the derived function and the base's typeset.
        """
        if not self.is_derived:
            return self.type_set
        else:
            return self.base.get_typeset().image(self.derived_func)

    def get_fresh_copy(self, name):
        # type: (str) -> TypeVar
        """
        Get a fresh copy of self. Can only be called on free typevars.
        """
        assert not self.is_derived
        tv = TypeVar.from_typeset(self.type_set.copy())
        tv.name = name
        return tv
|
||||
423
lib/codegen/meta-python/cdsl/xform.py
Normal file
423
lib/codegen/meta-python/cdsl/xform.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Instruction transformations.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from .ast import Def, Var, Apply
|
||||
from .ti import ti_xform, TypeEnv, get_type_env, TypeConstraint
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
|
||||
try:
|
||||
from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa
|
||||
from typing import Optional, Set # noqa
|
||||
from .ast import Expr, VarAtomMap # noqa
|
||||
from .isa import TargetISA # noqa
|
||||
from .typevar import TypeVar # noqa
|
||||
from .instructions import ConstrList, Instruction # noqa
|
||||
DefApply = Union[Def, Apply]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def canonicalize_defapply(node):
    # type: (DefApply) -> Def
    """
    Canonicalize a `Def` or `Apply` node into a `Def`.

    An `Apply` becomes a `Def` with an empty tuple of defs; a `Def` is
    returned unchanged.
    """
    return Def((), node) if isinstance(node, Apply) else node
|
||||
|
||||
|
||||
class Rtl(object):
    """
    Register Transfer Language list.

    An RTL object contains a list of register assignments in the form of `Def`
    objects.

    An RTL list can represent both a source pattern to be matched, or a
    destination pattern to be inserted.
    """

    def __init__(self, *args):
        # type: (*DefApply) -> None
        # Canonicalize every entry to a `Def`; bare `Apply` nodes become
        # `Def`s with an empty defs tuple.
        self.rtl = tuple(map(canonicalize_defapply, args))

    def copy(self, m):
        # type: (VarAtomMap) -> Rtl
        """
        Return a copy of this rtl with all Vars substituted with copies or
        according to m. Update m as necessary.
        """
        return Rtl(*[d.copy(m) for d in self.rtl])

    def vars(self):
        # type: () -> Set[Var]
        """Return the set of all Vars in self that correspond to SSA values"""
        return reduce(lambda x, y: x.union(y),
                      [d.vars() for d in self.rtl],
                      set([]))

    def definitions(self):
        # type: () -> Set[Var]
        """ Return the set of all Vars defined in self"""
        return reduce(lambda x, y: x.union(y),
                      [d.definitions() for d in self.rtl],
                      set([]))

    def free_vars(self):
        # type: () -> Set[Var]
        """Return the set of free Vars corresp. to SSA vals used in self"""
        def flow_f(s, d):
            # type: (Set[Var], Def) -> Set[Var]
            """Compute the change in the set of free vars across a Def"""
            # Values defined by `d` are not free before it.
            s = s.difference(set(d.defs))
            # Only value operands (not immediates) are SSA uses.
            uses = set(d.expr.args[i] for i in d.expr.inst.value_opnums)
            for v in uses:
                assert isinstance(v, Var)
                s.add(v)

            return s

        # Walk the Defs backwards so each Def's uses are added after its
        # defs have been removed from the running set.
        return reduce(flow_f, reversed(self.rtl), set([]))

    def substitution(self, other, s):
        # type: (Rtl, VarAtomMap) -> Optional[VarAtomMap]
        """
        If the Rtl self agrees structurally with the Rtl other, return a
        substitution to transform self to other. Two Rtls agree structurally if
        they have the same sequence of Defs, that agree structurally.
        """
        if len(self.rtl) != len(other.rtl):
            return None

        for i in range(len(self.rtl)):
            # Accumulate the substitution Def by Def; any structural
            # mismatch aborts with None.
            s = self.rtl[i].substitution(other.rtl[i], s)

            if s is None:
                return None

        return s

    def is_concrete(self):
        # type: () -> bool
        """Return True iff every Var in the self has a singleton type."""
        return all(v.get_typevar().singleton_type() is not None
                   for v in self.vars())

    def cleanup_concrete_rtl(self):
        # type: () -> None
        r"""
        Given that there is only 1 possible concrete typing T for self, assign
        a singleton TV with type t=T[v] for each Var v \in self. It's an error
        to call this on an Rtl with more than 1 possible typing. This modifies
        the Rtl in-place.
        """
        # Imported locally to avoid a circular import with .ti.
        from .ti import ti_rtl, TypeEnv
        # 1) Infer the types of all vars in res
        typenv = get_type_env(ti_rtl(self, TypeEnv()))
        typenv.normalize()
        typenv = typenv.extract()

        # 2) Make sure there is only one possible type assignment
        typings = list(typenv.concrete_typings())
        assert len(typings) == 1
        typing = typings[0]

        # 3) Assign the only possible type to each variable.
        for v in typenv.vars:
            assert typing[v].singleton_type() is not None
            v.set_typevar(typing[v])

    def __str__(self):
        # type: () -> str
        return "\n".join(map(str, self.rtl))
|
||||
|
||||
|
||||
class XForm(object):
    """
    An instruction transformation consists of a source and destination pattern.

    Patterns are expressed in *register transfer language* as tuples of
    `ast.Def` or `ast.Expr` nodes. A pattern may optionally have a sequence of
    TypeConstraints, that additionally limit the set of cases when it applies.

    A legalization pattern must have a source pattern containing only a single
    instruction.

    >>> from base.instructions import iconst, iadd, iadd_imm
    >>> a = Var('a')
    >>> c = Var('c')
    >>> v = Var('v')
    >>> x = Var('x')
    >>> XForm(
    ...     Rtl(c << iconst(v),
    ...         a << iadd(x, c)),
    ...     Rtl(a << iadd_imm(x, v)))
    XForm(inputs=[Var(v), Var(x)], defs=[Var(c, src), Var(a, src, dst)],
      c << iconst(v)
      a << iadd(x, c)
    =>
      a << iadd_imm(x, v)
    )
    """

    def __init__(self, src, dst, constraints=None):
        # type: (Rtl, Rtl, Optional[ConstrList]) -> None
        self.src = src
        self.dst = dst
        # Variables that are inputs to the source pattern.
        self.inputs = list()  # type: List[Var]
        # Variables defined in either src or dst.
        self.defs = list()  # type: List[Var]

        # Rewrite variables in src and dst RTL lists to our own copies.
        # Map name -> private Var.
        symtab = dict()  # type: Dict[str, Var]
        self._rewrite_rtl(src, symtab, Var.SRCCTX)
        # Remember how many inputs came from src so extra dst inputs can be
        # detected below.
        num_src_inputs = len(self.inputs)
        self._rewrite_rtl(dst, symtab, Var.DSTCTX)
        # Needed for testing type inference on XForms
        self.symtab = symtab

        # Check for inconsistently used inputs.
        for i in self.inputs:
            if not i.is_input():
                raise AssertionError(
                    "'{}' used as both input and def".format(i))

        # Check for spurious inputs in dst.
        if len(self.inputs) > num_src_inputs:
            raise AssertionError(
                "extra inputs in dst RTL: {}".format(
                    self.inputs[num_src_inputs:]))

        # Perform type inference and cleanup
        raw_ti = get_type_env(ti_xform(self, TypeEnv()))
        raw_ti.normalize()
        self.ti = raw_ti.extract()

        def interp_tv(tv):
            # type: (TypeVar) -> TypeVar
            """ Convert typevars according to symtab """
            # External constraints refer to vars as "typeof_<name>"; map
            # those onto our private Vars' typevars.
            if not tv.name.startswith("typeof_"):
                return tv
            return symtab[tv.name[len("typeof_"):]].get_typevar()

        self.constraints = []  # type: List[TypeConstraint]
        if constraints is not None:
            # Accept either a single constraint or a sequence of them.
            if isinstance(constraints, TypeConstraint):
                constr_list = [constraints]  # type: Sequence[TypeConstraint]
            else:
                constr_list = constraints

            for c in constr_list:
                # Translate the constraint onto our private typevars before
                # registering it with the type environment.
                type_m = {tv: interp_tv(tv) for tv in c.tvs()}
                inner_c = c.translate(type_m)
                self.constraints.append(inner_c)
                self.ti.add_constraint(inner_c)

        # Sanity: The set of inferred free typevars should be a subset of the
        # TVs corresponding to Vars appearing in src
        free_typevars = set(self.ti.free_typevars())
        src_vars = set(self.inputs).union(
            [x for x in self.defs if not x.is_temp()])
        src_tvs = set([v.get_typevar() for v in src_vars])
        if (not free_typevars.issubset(src_tvs)):
            raise AssertionError(
                "Some free vars don't appear in src - {}"
                .format(free_typevars.difference(src_tvs)))

        # Update the type vars for each Var to their inferred values
        for v in self.inputs + self.defs:
            v.set_typevar(self.ti[v.get_typevar()])

    def __repr__(self):
        # type: () -> str
        s = "XForm(inputs={}, defs={},\n ".format(self.inputs, self.defs)
        s += '\n '.join(str(n) for n in self.src.rtl)
        s += '\n=>\n '
        s += '\n '.join(str(n) for n in self.dst.rtl)
        s += '\n)'
        return s

    def _rewrite_rtl(self, rtl, symtab, context):
        # type: (Rtl, Dict[str, Var], int) -> None
        # Rewrite every Def (and its expression tree) in `rtl` so that all
        # Vars are our own private copies registered in `symtab`.
        for line in rtl.rtl:
            if isinstance(line, Def):
                line.defs = tuple(
                    self._rewrite_defs(line, symtab, context))
                expr = line.expr
            else:
                expr = line
            self._rewrite_expr(expr, symtab, context)

    def _rewrite_expr(self, expr, symtab, context):
        # type: (Apply, Dict[str, Var], int) -> None
        """
        Find all uses of variables in `expr` and replace them with our own
        local symbols.
        """

        # Accept a whole expression tree.
        stack = [expr]
        while len(stack) > 0:
            expr = stack.pop()
            # _rewrite_uses pushes nested Apply args back onto `stack`.
            expr.args = tuple(
                self._rewrite_uses(expr, stack, symtab, context))

    def _rewrite_defs(self, line, symtab, context):
        # type: (Def, Dict[str, Var], int) -> Iterable[Var]
        """
        Given a tuple of symbols defined in a Def, rewrite them to local
        symbols. Yield the new locals.
        """
        for sym in line.defs:
            name = str(sym)
            if name in symtab:
                var = symtab[name]
                # A var may be defined once per context (src/dst), not twice.
                if var.get_def(context):
                    raise AssertionError("'{}' multiply defined".format(name))
            else:
                var = Var(name)
                symtab[name] = var
                self.defs.append(var)
            var.set_def(context, line)
            yield var

    def _rewrite_uses(self, expr, stack, symtab, context):
        # type: (Apply, List[Apply], Dict[str, Var], int) -> Iterable[Expr]
        """
        Given an `Apply` expr, rewrite all uses in its arguments to local
        variables. Yield a sequence of new arguments.

        Append any `Apply` arguments to `stack`.
        """
        for arg, operand in zip(expr.args, expr.inst.ins):
            # Nested instructions are allowed. Visit recursively.
            if isinstance(arg, Apply):
                stack.append(arg)
                yield arg
                continue
            if not isinstance(arg, Var):
                # Non-Var args (e.g. immediates) pass through unchanged, but
                # only in non-value operand positions.
                assert not operand.is_value(), "Value arg must be `Var`"
                yield arg
                continue
            # This is supposed to be a symbolic value reference.
            name = str(arg)
            if name in symtab:
                var = symtab[name]
                # The variable must be used consistently as a def or input.
                if not var.is_input() and not var.get_def(context):
                    raise AssertionError(
                        "'{}' used as both input and def"
                        .format(name))
            else:
                # First time use of variable.
                var = Var(name)
                symtab[name] = var
                self.inputs.append(var)
            yield var

    def verify_legalize(self):
        # type: () -> None
        """
        Verify that this is a valid legalization XForm.

        - The source pattern must describe a single instruction.
        - All values defined in the output pattern must be defined in the
          destination pattern.
        """
        assert len(self.src.rtl) == 1, "Legalize needs single instruction."
        for d in self.src.rtl[0].defs:
            if not d.is_output():
                raise AssertionError(
                    '{} not defined in dest pattern'.format(d))

    def apply(self, r, suffix=None):
        # type: (Rtl, Optional[str]) -> Rtl
        """
        Given a concrete Rtl r s.t. r matches self.src, return the
        corresponding concrete self.dst. If suffix is provided, any temporary
        defs are renamed with '.suffix' appended to their old name.
        """
        assert r.is_concrete()
        s = self.src.substitution(r, {})  # type: VarAtomMap
        assert s is not None

        if (suffix is not None):
            # Rename temporaries so repeated applications don't collide.
            for v in self.dst.vars():
                if v.is_temp():
                    assert v not in s
                    s[v] = Var(v.name + '.' + suffix)

        dst = self.dst.copy(s)
        dst.cleanup_concrete_rtl()
        return dst
|
||||
|
||||
|
||||
class XFormGroup(object):
    """
    A group of related transformations.

    :param isa: A target ISA whose instructions are allowed.
    :param chain: A next level group to try if this one doesn't match.
    """

    def __init__(self, name, doc, isa=None, chain=None):
        # type: (str, str, TargetISA, XFormGroup) -> None
        # Legalization patterns registered via `legalize()`.
        self.xforms = list()  # type: List[XForm]
        # Custom Rust legalization functions, keyed by instruction.
        self.custom = OrderedDict()  # type: OrderedDict[Instruction, str]
        self.name = name
        self.__doc__ = doc
        self.isa = isa
        self.chain = chain

    def __str__(self):
        # type: () -> str
        # ISA-specific groups are qualified by the ISA name.
        return '{}.{}'.format(self.isa.name, self.name) if self.isa \
            else self.name

    def rust_name(self):
        # type: () -> str
        """
        Get the Rust name of this function implementing this transform.
        """
        if self.isa:
            # This is a function in the same module as the LEGALIZE_ACTION
            # table referring to it.
            return self.name
        # Shared groups live in the top-level legalizer module.
        return '::legalizer::{}'.format(self.name)

    def legalize(self, src, dst):
        # type: (Union[Def, Apply], Rtl) -> None
        """
        Add a legalization pattern to this group.

        :param src: Single `Def` or `Apply` to be legalized.
        :param dst: `Rtl` list of replacement instructions.
        """
        pattern = XForm(Rtl(src), dst)
        pattern.verify_legalize()
        self.xforms.append(pattern)

    def custom_legalize(self, inst, funcname):
        # type: (Instruction, str) -> None
        """
        Add a custom legalization action for `inst`.

        The `funcname` parameter is the fully qualified name of a Rust function
        which takes the same arguments as the `isa::Legalize` actions.

        The custom function will be called to legalize `inst` and any return
        value is ignored.
        """
        assert inst not in self.custom, "Duplicate custom_legalize"
        self.custom[inst] = funcname
|
||||
Reference in New Issue
Block a user