Add better type inference and encapsulate it in its own file (#110)
* Add more rigorous type inference and encapsulate the type inferece code in its own file (ti.py). Add constraints accumulation during type inference, to represent constraints that cannot be expressed using bijective derivation functions between typevars. Add testing for new type inference code. * Additional annotations to appease mypy
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
f867ddbf0c
commit
a5c96ef6bf
@@ -100,8 +100,8 @@ class Var(Expr):
|
||||
# TypeVar representing the type of this variable.
|
||||
self.typevar = None # type: TypeVar
|
||||
# The original 'typeof(x)' type variable that was created for this Var.
|
||||
# This one doesn't change. `self.typevar` above may be joined with
|
||||
# other typevars.
|
||||
# This one doesn't change. `self.typevar` above may be changed to
|
||||
# another typevar by type inference.
|
||||
self.original_typevar = None # type: TypeVar
|
||||
|
||||
def __str__(self):
|
||||
@@ -180,16 +180,9 @@ class Var(Expr):
|
||||
self.typevar = tv
|
||||
return self.typevar
|
||||
|
||||
def link_typevar(self, base, derived_func):
|
||||
# type: (TypeVar, str) -> None
|
||||
"""
|
||||
Link the type variable on this Var to the type variable `base` using
|
||||
`derived_func`.
|
||||
"""
|
||||
self.original_typevar = None
|
||||
self.typevar.change_to_derived(base, derived_func)
|
||||
# Possibly eliminate redundant SAMEAS links.
|
||||
self.typevar = self.typevar.strip_sameas()
|
||||
def set_typevar(self, tv):
|
||||
# type: (TypeVar) -> None
|
||||
self.typevar = tv
|
||||
|
||||
def has_free_typevar(self):
|
||||
# type: () -> bool
|
||||
@@ -213,136 +206,6 @@ class Var(Expr):
|
||||
"""
|
||||
return self.typevar.rust_expr()
|
||||
|
||||
def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var):
|
||||
# type: (TypeVar, TypeVar, Var) -> None
|
||||
"""
|
||||
Constrain the set of allowed types for this variable.
|
||||
|
||||
Merge type variables for the involved variables to minimize the set for
|
||||
free type variables.
|
||||
|
||||
Suppose we're looking at an instruction defined like this:
|
||||
|
||||
c = Operand('c', TxN.as_bool())
|
||||
x = Operand('x', TxN)
|
||||
y = Operand('y', TxN)
|
||||
a = Operand('a', TxN)
|
||||
vselect = Instruction('vselect', ins=(c, x, y), outs=a)
|
||||
|
||||
And suppose the instruction is used in a pattern like this:
|
||||
|
||||
v0 << vselect(v1, v2, v3)
|
||||
|
||||
We want to reconcile the types of the variables v0-v3 with the
|
||||
constraints from the definition of vselect. This means that v0, v2, and
|
||||
v3 must all have the same type, and v1 must have the type
|
||||
`typeof(v2).as_bool()`.
|
||||
|
||||
The types are reconciled by calling this function once for each
|
||||
input/output operand on the instruction in the pattern with these
|
||||
arguments.
|
||||
|
||||
:param sym_typevar: Symbolic type variable constraining this variable
|
||||
in the definition of the instruction.
|
||||
:param sym_ctrl: Controlling type variable of `sym_typevar` in the
|
||||
definition of the instruction.
|
||||
:param ctrl_var: Variable determining the type of `sym_ctrl`.
|
||||
|
||||
When processing `v1` as used in the pattern above, we would get:
|
||||
|
||||
- self: v1
|
||||
- sym_typevar: TxN.as_bool()
|
||||
- sym_ctrl: TxN
|
||||
- ctrl_var: v2
|
||||
|
||||
Here, 'v2' represents the controlling variable because of how the
|
||||
`Ternary` instruction format is defined with `typevar_operand=1`.
|
||||
"""
|
||||
# First check if sym_typevar is tied to the controlling type variable
|
||||
# in the instruction definition. We also allow free type variables on
|
||||
# instruction inputs that can't be tied to anything else.
|
||||
#
|
||||
# This also covers non-polymorphic instructions and other cases where
|
||||
# we don't have a Var representing the controlling type variable.
|
||||
sym_free_var = sym_typevar.free_typevar()
|
||||
if not sym_free_var or sym_free_var is not sym_ctrl or not ctrl_var:
|
||||
# Just constrain our type to be compatible with the required
|
||||
# typeset.
|
||||
self.get_typevar().constrain_types(sym_typevar)
|
||||
return
|
||||
|
||||
# Now sym_typevar is known to be tied to (or identical to) the
|
||||
# controlling type variable.
|
||||
|
||||
if not self.typevar:
|
||||
# If this variable is not yet constrained, just infer its type and
|
||||
# link it to the controlling type variable.
|
||||
if not sym_typevar.is_derived:
|
||||
assert sym_typevar is sym_ctrl
|
||||
# Identity mapping.
|
||||
# Note that `self == ctrl_var` is both possible and common.
|
||||
self.typevar = ctrl_var.get_typevar()
|
||||
else:
|
||||
assert self is not ctrl_var, (
|
||||
'Impossible type constraints for {}: {}'
|
||||
.format(self, sym_typevar))
|
||||
# Create a derived type variable identical to sym_typevar, but
|
||||
# with a different base.
|
||||
self.typevar = TypeVar.derived(
|
||||
ctrl_var.get_typevar(),
|
||||
sym_typevar.derived_func)
|
||||
# Match the type set constraints of the instruction.
|
||||
self.typevar.constrain_types(sym_typevar)
|
||||
return
|
||||
|
||||
# We already have a self.typevar describing our constraints. We need to
|
||||
# reconcile with the additional constraints.
|
||||
|
||||
# It's likely that ctrl_var and self already share a type
|
||||
# variable. (Often because `ctrl_var == self`).
|
||||
if ctrl_var.typevar == self.typevar:
|
||||
return
|
||||
|
||||
if not sym_typevar.is_derived:
|
||||
assert sym_typevar is sym_ctrl
|
||||
# sym_typevar is a direct use of sym_ctrl, so we need to reconcile
|
||||
# self with ctrl_var.
|
||||
assert not sym_typevar.is_derived
|
||||
self.typevar.constrain_types(sym_typevar)
|
||||
|
||||
# It's possible that ctrl_var has not yet been assigned a type
|
||||
# variable.
|
||||
if not ctrl_var.typevar:
|
||||
ctrl_var.typevar = self.typevar
|
||||
return
|
||||
|
||||
# We can also bind variables with a free type variable to another
|
||||
# variable. Prefer to do this to temps because they aren't allowed
|
||||
# to be free,
|
||||
if self.is_temp() and self.has_free_typevar():
|
||||
self.link_typevar(ctrl_var.typevar, TypeVar.SAMEAS)
|
||||
return
|
||||
if ctrl_var.is_temp() and ctrl_var.has_free_typevar():
|
||||
ctrl_var.link_typevar(self.typevar, TypeVar.SAMEAS)
|
||||
return
|
||||
if self.has_free_typevar():
|
||||
self.link_typevar(ctrl_var.typevar, TypeVar.SAMEAS)
|
||||
return
|
||||
if ctrl_var.has_free_typevar():
|
||||
ctrl_var.link_typevar(self.typevar, TypeVar.SAMEAS)
|
||||
return
|
||||
|
||||
# TODO: Other cases are harder to handle.
|
||||
#
|
||||
# - If either variable is an independent free type variable, it
|
||||
# should be changed to be linked to the other.
|
||||
# - If both variable are free, we should pick one to link to the
|
||||
# other. In particular, if one is a temp, it should be linked.
|
||||
else:
|
||||
# sym_typevar is derived from sym_ctrl.
|
||||
# TODO: Other cases are harder to handle.
|
||||
pass
|
||||
|
||||
|
||||
class Apply(Expr):
|
||||
"""
|
||||
|
||||
432
lib/cretonne/meta/cdsl/test_ti.py
Normal file
432
lib/cretonne/meta/cdsl/test_ti.py
Normal file
@@ -0,0 +1,432 @@
|
||||
from __future__ import absolute_import
|
||||
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
|
||||
b1, icmp, iadd_cout, iadd_cin
|
||||
from base.legalize import narrow, expand
|
||||
from base.immediates import intcc
|
||||
from .typevar import TypeVar
|
||||
from .ast import Var, Def
|
||||
from .xform import Rtl, XForm
|
||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env
|
||||
from unittest import TestCase
|
||||
from functools import reduce
|
||||
|
||||
try:
|
||||
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
|
||||
from .ti import Constraint
|
||||
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
|
||||
|
||||
def sort_constr(c):
|
||||
# type: (Constraint) -> Constraint
|
||||
"""
|
||||
Sort the 2 typevars in a constraint by name for comparison
|
||||
"""
|
||||
r = tuple(sorted(c, key=lambda y: y.name))
|
||||
if TYPE_CHECKING:
|
||||
return cast(Constraint, r)
|
||||
else:
|
||||
return r
|
||||
|
||||
|
||||
def agree(me, other):
|
||||
# type: (TypeEnv, TypeEnv) -> bool
|
||||
"""
|
||||
Given TypeEnvs me and other, check if they agree. As part of that build
|
||||
a map m from TVs in me to their corresponding TVs in other.
|
||||
Specifically:
|
||||
|
||||
1. Check that all TVs that are keys in me.type_map are also defined
|
||||
in other.type_map
|
||||
|
||||
2. For any tv in me.type_map check that:
|
||||
me[tv].get_typeset() == other[tv].get_typeset()
|
||||
|
||||
3. Set m[me[tv]] = other[tv] in the substitution m
|
||||
|
||||
4. If we find another tv1 such that me[tv1] == me[tv], assert that
|
||||
other[tv1] == m[me[tv1]] == m[me[tv]] = other[tv]
|
||||
|
||||
5. Check that me and other have the same constraints under the
|
||||
substitution m
|
||||
"""
|
||||
m = {} # type: TypeMap
|
||||
# Check that our type map and other's agree and built substitution m
|
||||
for tv in me.type_map:
|
||||
if (me[tv] not in m):
|
||||
m[me[tv]] = other[tv]
|
||||
if me[tv].get_typeset() != other[tv].get_typeset():
|
||||
return False
|
||||
else:
|
||||
if m[me[tv]] != other[tv]:
|
||||
return False
|
||||
|
||||
# Tranlsate our constraints using m, and sort
|
||||
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
|
||||
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
|
||||
|
||||
# Sort other's constraints
|
||||
other_equiv_constr = sorted([sort_constr(x) for x in other.constraints],
|
||||
key=lambda y: y[0].name)
|
||||
|
||||
return me_equiv_constr == other_equiv_constr
|
||||
|
||||
|
||||
def check_typing(got_or_err, expected, symtab=None):
|
||||
# type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa
|
||||
"""
|
||||
Check that a the typying we received (got_or_err) complies with the
|
||||
expected typing (expected). If symtab is specified, substitute the Vars in
|
||||
expected using symtab first (used when checking type inference on XForms)
|
||||
"""
|
||||
(m, c) = expected
|
||||
got = get_type_env(got_or_err)
|
||||
|
||||
if (symtab is not None):
|
||||
# For xforms we first need to re-write our TVs in terms of the tvs
|
||||
# stored internally in the XForm. Use the symtab passed
|
||||
subst_m = {k.get_typevar(): symtab[str(k)].get_typevar()
|
||||
for k in m.keys()}
|
||||
# Convert m from a Var->TypeVar map to TypeVar->TypeVar map where
|
||||
# the key TypeVar is re-written to its XForm internal version
|
||||
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
|
||||
# Rewrite the TVs in the input constraints to their XForm internal
|
||||
# versions
|
||||
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c]
|
||||
else:
|
||||
# If no symtab, just convert m from Var->TypeVar map to a
|
||||
# TypeVar->TypeVar map
|
||||
tv_m = {k.get_typevar(): v for (k, v) in m.items()}
|
||||
|
||||
expected_typ = TypeEnv((tv_m, c))
|
||||
assert agree(expected_typ, got), \
|
||||
"typings disagree:\n {} \n {}".format(got.dot(),
|
||||
expected_typ.dot())
|
||||
|
||||
|
||||
def check_concrete_typing_rtl(var_types, rtl):
|
||||
# type: (VarMap, Rtl) -> None
|
||||
"""
|
||||
Check that a concrete type assignment var_types (Dict[Var, TypeVar]) is
|
||||
valid for an Rtl rtl. Specifically check that:
|
||||
|
||||
1) For each Var v \in rtl, v is defined in var_types
|
||||
|
||||
2) For all v, var_types[v] is a singleton type
|
||||
|
||||
3) For each v, and each location u, where v is used with expected type
|
||||
tv_u, var_types[v].get_typeset() is a subset of
|
||||
subst(tv_u, m).get_typeset() where m is the substitution of
|
||||
formals->actuals we are building so far.
|
||||
|
||||
4) If tv_u is non-derived and not in m, set m[tv_u]= var_types[v]
|
||||
"""
|
||||
for d in rtl.rtl:
|
||||
assert isinstance(d, Def)
|
||||
inst = d.expr.inst
|
||||
# Accumulate all actual TVs for value defs/opnums in actual_tvs
|
||||
actual_tvs = [var_types[d.defs[i]] for i in inst.value_results]
|
||||
for v in [d.expr.args[i] for i in inst.value_opnums]:
|
||||
assert isinstance(v, Var)
|
||||
actual_tvs.append(var_types[v])
|
||||
|
||||
# Accumulate all formal TVs for value defs/opnums in actual_tvs
|
||||
formal_tvs = [inst.outs[i].typevar for i in inst.value_results] +\
|
||||
[inst.ins[i].typevar for i in inst.value_opnums]
|
||||
m = {} # type: TypeMap
|
||||
|
||||
# For each actual/formal pair check that they agree
|
||||
for (actual_tv, formal_tv) in zip(actual_tvs, formal_tvs):
|
||||
# actual should be a singleton
|
||||
assert actual_tv.singleton_type() is not None
|
||||
formal_tv = subst(formal_tv, m)
|
||||
# actual should agree with the concretized formal
|
||||
assert actual_tv.get_typeset().issubset(formal_tv.get_typeset())
|
||||
|
||||
if formal_tv not in m and not formal_tv.is_derived:
|
||||
m[formal_tv] = actual_tv
|
||||
|
||||
|
||||
def check_concrete_typing_xform(var_types, xform):
|
||||
# type: (VarMap, XForm) -> None
|
||||
"""
|
||||
Check a concrete type assignment var_types for an XForm xform
|
||||
"""
|
||||
check_concrete_typing_rtl(var_types, xform.src)
|
||||
check_concrete_typing_rtl(var_types, xform.dst)
|
||||
|
||||
|
||||
class TypeCheckingBaseTest(TestCase):
|
||||
def setUp(self):
|
||||
# type: () -> None
|
||||
self.v0 = Var("v0")
|
||||
self.v1 = Var("v1")
|
||||
self.v2 = Var("v2")
|
||||
self.v3 = Var("v3")
|
||||
self.v4 = Var("v4")
|
||||
self.v5 = Var("v5")
|
||||
self.v6 = Var("v6")
|
||||
self.v7 = Var("v7")
|
||||
self.v8 = Var("v8")
|
||||
self.v9 = Var("v9")
|
||||
self.imm0 = Var("imm0")
|
||||
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True,
|
||||
scalars=False, simd=True)
|
||||
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
|
||||
scalars=False, simd=True)
|
||||
self.b1 = TypeVar.singleton(b1)
|
||||
|
||||
|
||||
class TestRTL(TypeCheckingBaseTest):
|
||||
def test_bad_rtl1(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
(self.v0, self.v1) << vsplit(self.v2),
|
||||
self.v3 << vconcat(self.v0, self.v2),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
self.assertEqual(ti_rtl(r, ti),
|
||||
"On line 1: fail ti on `typeof_v2` <: `2`: " +
|
||||
"Error: empty type created when unifying " +
|
||||
"`typeof_v2` and `half_vector(typeof_v2)`")
|
||||
|
||||
def test_vselect(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v0 << vselect(self.v1, self.v2, self.v3),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
txn = self.TxN.get_fresh_copy("TxN1")
|
||||
check_typing(typing, ({
|
||||
self.v0: txn,
|
||||
self.v1: txn.as_bool(),
|
||||
self.v2: txn,
|
||||
self.v3: txn
|
||||
}, []))
|
||||
|
||||
def test_vselect_icmpimm(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v0 << iconst(self.imm0),
|
||||
self.v1 << icmp(intcc.eq, self.v2, self.v0),
|
||||
self.v5 << vselect(self.v1, self.v3, self.v4),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
ixn = self.IxN_nonscalar.get_fresh_copy("IxN1")
|
||||
txn = self.TxN.get_fresh_copy("TxN1")
|
||||
check_typing(typing, ({
|
||||
self.v0: ixn,
|
||||
self.v1: ixn.as_bool(),
|
||||
self.v2: ixn,
|
||||
self.v3: txn,
|
||||
self.v4: txn,
|
||||
self.v5: txn,
|
||||
}, [(ixn.as_bool(), txn.as_bool())]))
|
||||
|
||||
def test_vselect_vsplits(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v3 << vselect(self.v0, self.v1, self.v2),
|
||||
(self.v4, self.v5) << vsplit(self.v3),
|
||||
(self.v6, self.v7) << vsplit(self.v4),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
t = TypeVar("t", "", ints=True, bools=True, floats=True,
|
||||
simd=(4, 256))
|
||||
check_typing(typing, ({
|
||||
self.v0: t.as_bool(),
|
||||
self.v1: t,
|
||||
self.v2: t,
|
||||
self.v3: t,
|
||||
self.v4: t.half_vector(),
|
||||
self.v5: t.half_vector(),
|
||||
self.v6: t.half_vector().half_vector(),
|
||||
self.v7: t.half_vector().half_vector(),
|
||||
}, []))
|
||||
|
||||
def test_vselect_vconcats(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v3 << vselect(self.v0, self.v1, self.v2),
|
||||
self.v8 << vconcat(self.v3, self.v3),
|
||||
self.v9 << vconcat(self.v8, self.v8),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
t = TypeVar("t", "", ints=True, bools=True, floats=True,
|
||||
simd=(2, 64))
|
||||
check_typing(typing, ({
|
||||
self.v0: t.as_bool(),
|
||||
self.v1: t,
|
||||
self.v2: t,
|
||||
self.v3: t,
|
||||
self.v8: t.double_vector(),
|
||||
self.v9: t.double_vector().double_vector(),
|
||||
}, []))
|
||||
|
||||
def test_vselect_vsplits_vconcats(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v3 << vselect(self.v0, self.v1, self.v2),
|
||||
(self.v4, self.v5) << vsplit(self.v3),
|
||||
(self.v6, self.v7) << vsplit(self.v4),
|
||||
self.v8 << vconcat(self.v3, self.v3),
|
||||
self.v9 << vconcat(self.v8, self.v8),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
t = TypeVar("t", "", ints=True, bools=True, floats=True,
|
||||
simd=(4, 64))
|
||||
check_typing(typing, ({
|
||||
self.v0: t.as_bool(),
|
||||
self.v1: t,
|
||||
self.v2: t,
|
||||
self.v3: t,
|
||||
self.v4: t.half_vector(),
|
||||
self.v5: t.half_vector(),
|
||||
self.v6: t.half_vector().half_vector(),
|
||||
self.v7: t.half_vector().half_vector(),
|
||||
self.v8: t.double_vector(),
|
||||
self.v9: t.double_vector().double_vector(),
|
||||
}, []))
|
||||
|
||||
def test_bint(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v4 << iadd(self.v1, self.v2),
|
||||
self.v5 << bint(self.v3),
|
||||
self.v0 << iadd(self.v4, self.v5)
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
itype = TypeVar("t", "", ints=True, simd=(1, 256))
|
||||
btype = TypeVar("b", "", bools=True, simd=True)
|
||||
|
||||
# Check that self.v5 gets the same integer type as
|
||||
# the rest of them
|
||||
# TODO: Add constraint nlanes(v3) == nlanes(v1) when we
|
||||
# add that type constraint to bint
|
||||
check_typing(typing, ({
|
||||
self.v1: itype,
|
||||
self.v2: itype,
|
||||
self.v4: itype,
|
||||
self.v5: itype,
|
||||
self.v3: btype,
|
||||
self.v0: itype,
|
||||
}, []))
|
||||
|
||||
|
||||
class TestXForm(TypeCheckingBaseTest):
|
||||
def test_iadd_cout(self):
|
||||
# type: () -> None
|
||||
x = XForm(Rtl((self.v0, self.v1) << iadd_cout(self.v2, self.v3),),
|
||||
Rtl(
|
||||
self.v0 << iadd(self.v2, self.v3),
|
||||
self.v1 << icmp(intcc.ult, self.v0, self.v2)
|
||||
))
|
||||
itype = TypeVar("t", "", ints=True, simd=(1, 1))
|
||||
|
||||
check_typing(x.ti, ({
|
||||
self.v0: itype,
|
||||
self.v2: itype,
|
||||
self.v3: itype,
|
||||
self.v1: itype.as_bool(),
|
||||
}, []), x.symtab)
|
||||
|
||||
def test_iadd_cin(self):
|
||||
# type: () -> None
|
||||
x = XForm(Rtl(self.v0 << iadd_cin(self.v1, self.v2, self.v3)),
|
||||
Rtl(
|
||||
self.v4 << iadd(self.v1, self.v2),
|
||||
self.v5 << bint(self.v3),
|
||||
self.v0 << iadd(self.v4, self.v5)
|
||||
))
|
||||
itype = TypeVar("t", "", ints=True, simd=(1, 1))
|
||||
|
||||
check_typing(x.ti, ({
|
||||
self.v0: itype,
|
||||
self.v1: itype,
|
||||
self.v2: itype,
|
||||
self.v3: self.b1,
|
||||
self.v4: itype,
|
||||
self.v5: itype,
|
||||
}, []), x.symtab)
|
||||
|
||||
def test_enumeration_with_constraints(self):
|
||||
# type: () -> None
|
||||
xform = XForm(
|
||||
Rtl(
|
||||
self.v0 << iconst(self.imm0),
|
||||
self.v1 << icmp(intcc.eq, self.v2, self.v0),
|
||||
self.v5 << vselect(self.v1, self.v3, self.v4)
|
||||
),
|
||||
Rtl(
|
||||
self.v0 << iconst(self.imm0),
|
||||
self.v1 << icmp(intcc.eq, self.v2, self.v0),
|
||||
self.v5 << vselect(self.v1, self.v3, self.v4)
|
||||
))
|
||||
|
||||
# Check all var assigns are correct
|
||||
assert len(xform.ti.constraints) > 0
|
||||
concrete_var_assigns = list(xform.ti.concrete_typings())
|
||||
|
||||
v0 = xform.symtab[str(self.v0)]
|
||||
v1 = xform.symtab[str(self.v1)]
|
||||
v2 = xform.symtab[str(self.v2)]
|
||||
v3 = xform.symtab[str(self.v3)]
|
||||
v4 = xform.symtab[str(self.v4)]
|
||||
v5 = xform.symtab[str(self.v5)]
|
||||
|
||||
for var_m in concrete_var_assigns:
|
||||
assert var_m[v0] == var_m[v2] and \
|
||||
var_m[v3] == var_m[v4] and\
|
||||
var_m[v5] == var_m[v3] and\
|
||||
var_m[v1] == var_m[v2].as_bool() and\
|
||||
var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset()
|
||||
check_concrete_typing_xform(var_m, xform)
|
||||
|
||||
# The number of possible typings here is:
|
||||
# 8 cases for v0 = i8xN times 2 options for v3 - i8, b8 = 16
|
||||
# 8 cases for v0 = i16xN times 2 options for v3 - i16, b16 = 16
|
||||
# 8 cases for v0 = i32xN times 3 options for v3 - i32, b32, f32 = 24
|
||||
# 8 cases for v0 = i64xN times 3 options for v3 - i64, b64, f64 = 24
|
||||
#
|
||||
# (Note we have 8 cases for lanes since vselect prevents scalars)
|
||||
# Total: 2*16 + 2*24 = 80
|
||||
assert len(concrete_var_assigns) == 80
|
||||
|
||||
def test_base_legalizations_enumeration(self):
|
||||
# type: () -> None
|
||||
for xform in narrow.xforms + expand.xforms:
|
||||
# Any legalization patterns we defined should have at least 1
|
||||
# concrete typing
|
||||
concrete_typings_list = list(xform.ti.concrete_typings())
|
||||
assert len(concrete_typings_list) > 0
|
||||
|
||||
# If there are no free_typevars, this is a non-polymorphic pattern.
|
||||
# There should be only one possible concrete typing.
|
||||
if (len(xform.free_typevars) == 0):
|
||||
assert len(concrete_typings_list) == 1
|
||||
continue
|
||||
|
||||
# For any patterns where the type env includes constraints, at
|
||||
# least one of the "theoretically possible" concrete typings must
|
||||
# be prevented by the constraints. (i.e. we are not emitting
|
||||
# unneccessary constraints).
|
||||
# We check that by asserting that the number of concrete typings is
|
||||
# less than the number of all possible free typevar assignments
|
||||
if (len(xform.ti.constraints) > 0):
|
||||
theoretical_num_typings =\
|
||||
reduce(lambda x, y: x*y,
|
||||
[tv.get_typeset().size()
|
||||
for tv in xform.free_typevars], 1)
|
||||
assert len(concrete_typings_list) < theoretical_num_typings
|
||||
|
||||
# Check the validity of each individual concrete typing against the
|
||||
# xform
|
||||
for concrete_typing in concrete_typings_list:
|
||||
check_concrete_typing_xform(concrete_typing, xform)
|
||||
@@ -125,59 +125,56 @@ class TestTypeSet(TestCase):
|
||||
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
|
||||
i32.by(4))
|
||||
|
||||
def test_map_inverse(self):
|
||||
def test_preimage(self):
|
||||
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
|
||||
self.assertEqual(t, t.map_inverse(TypeVar.SAMEAS))
|
||||
self.assertEqual(t, t.preimage(TypeVar.SAMEAS))
|
||||
|
||||
# LANEOF
|
||||
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
|
||||
t.map_inverse(TypeVar.LANEOF))
|
||||
t.preimage(TypeVar.LANEOF))
|
||||
# Inverse of empty set is still empty across LANEOF
|
||||
self.assertEqual(TypeSet(),
|
||||
TypeSet().map_inverse(TypeVar.LANEOF))
|
||||
TypeSet().preimage(TypeVar.LANEOF))
|
||||
|
||||
# ASBOOL
|
||||
t = TypeSet(lanes=(1, 4), bools=(1, 64))
|
||||
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
|
||||
self.assertEqual(t.preimage(TypeVar.ASBOOL),
|
||||
TypeSet(lanes=(1, 4), ints=True, bools=True,
|
||||
floats=True))
|
||||
|
||||
# Inverse image across ASBOOL of TS not involving b1 cannot have
|
||||
# lanes=1
|
||||
t = TypeSet(lanes=(1, 4), bools=(16, 32))
|
||||
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
|
||||
TypeSet(lanes=(2, 4), ints=(16, 32), bools=(16, 32),
|
||||
floats=(32, 32)))
|
||||
|
||||
# Half/Double Vector
|
||||
t = TypeSet(lanes=(1, 1), ints=(8, 8))
|
||||
t1 = TypeSet(lanes=(256, 256), ints=(8, 8))
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR).size(), 0)
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR).size(), 0)
|
||||
self.assertEqual(t.preimage(TypeVar.DOUBLEVECTOR).size(), 0)
|
||||
self.assertEqual(t1.preimage(TypeVar.HALFVECTOR).size(), 0)
|
||||
|
||||
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32))
|
||||
t1 = TypeSet(lanes=(64, 256), bools=(1, 32))
|
||||
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR),
|
||||
self.assertEqual(t.preimage(TypeVar.DOUBLEVECTOR),
|
||||
TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32)))
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR),
|
||||
self.assertEqual(t1.preimage(TypeVar.HALFVECTOR),
|
||||
TypeSet(lanes=(128, 256), bools=(1, 32)))
|
||||
|
||||
# Half/Double Width
|
||||
t = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8))
|
||||
t1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64))
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH).size(), 0)
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH).size(), 0)
|
||||
self.assertEqual(t.preimage(TypeVar.DOUBLEWIDTH).size(), 0)
|
||||
self.assertEqual(t1.preimage(TypeVar.HALFWIDTH).size(), 0)
|
||||
|
||||
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64))
|
||||
t1 = TypeSet(lanes=(64, 256), bools=(1, 64))
|
||||
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH),
|
||||
self.assertEqual(t.preimage(TypeVar.DOUBLEWIDTH),
|
||||
TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32)))
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH),
|
||||
self.assertEqual(t1.preimage(TypeVar.HALFWIDTH),
|
||||
TypeSet(lanes=(64, 256), bools=(16, 64)))
|
||||
|
||||
|
||||
def has_non_bijective_derived_f(iterable):
|
||||
return any(not TypeVar.is_bijection(x) for x in iterable)
|
||||
|
||||
|
||||
class TestTypeVar(TestCase):
|
||||
def test_functions(self):
|
||||
x = TypeVar('x', 'all ints', ints=True)
|
||||
@@ -220,7 +217,7 @@ class TestTypeVar(TestCase):
|
||||
self.assertEqual(len(x.type_set.bools), 0)
|
||||
|
||||
def test_stress_constrain_types(self):
|
||||
# Get all 49 possible derived vars of lentgh 2. Since we have SAMEAS
|
||||
# Get all 49 possible derived vars of length 2. Since we have SAMEAS
|
||||
# this includes singly derived and non-derived vars
|
||||
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
|
||||
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
|
||||
@@ -231,18 +228,18 @@ class TestTypeVar(TestCase):
|
||||
for (i1, i2) in product(v, v):
|
||||
# Compute the derived sets for each starting with a full typeset
|
||||
full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True)
|
||||
ts1 = reduce(lambda ts, func: ts.map(func), i1, full_ts)
|
||||
ts2 = reduce(lambda ts, func: ts.map(func), i2, full_ts)
|
||||
ts1 = reduce(lambda ts, func: ts.image(func), i1, full_ts)
|
||||
ts2 = reduce(lambda ts, func: ts.image(func), i2, full_ts)
|
||||
|
||||
# Compute intersection
|
||||
intersect = ts1.copy()
|
||||
intersect &= ts2
|
||||
|
||||
# Propagate instersections backward
|
||||
ts1_src = reduce(lambda ts, func: ts.map_inverse(func),
|
||||
ts1_src = reduce(lambda ts, func: ts.preimage(func),
|
||||
reversed(i1),
|
||||
intersect)
|
||||
ts2_src = reduce(lambda ts, func: ts.map_inverse(func),
|
||||
ts2_src = reduce(lambda ts, func: ts.preimage(func),
|
||||
reversed(i2),
|
||||
intersect)
|
||||
|
||||
@@ -262,13 +259,10 @@ class TestTypeVar(TestCase):
|
||||
i2,
|
||||
TypeVar.from_typeset(ts2_src))
|
||||
|
||||
# The typesets of the two derived variables should be subsets of
|
||||
# the intersection we computed originally
|
||||
assert tv1.get_typeset().issubset(intersect)
|
||||
assert tv2.get_typeset().issubset(intersect)
|
||||
|
||||
# In the absence of AS_BOOL map(map_inverse(f)) == f so the
|
||||
# In the absence of AS_BOOL image(preimage(f)) == f so the
|
||||
# typesets of tv1 and tv2 should be exactly intersection
|
||||
assert (tv1.get_typeset() == tv2.get_typeset() and
|
||||
tv1.get_typeset() == intersect) or\
|
||||
TypeVar.ASBOOL in set(i1 + i2)
|
||||
assert tv1.get_typeset() == intersect or\
|
||||
has_non_bijective_derived_f(i1)
|
||||
|
||||
assert tv2.get_typeset() == intersect or\
|
||||
has_non_bijective_derived_f(i2)
|
||||
|
||||
556
lib/cretonne/meta/cdsl/ti.py
Normal file
556
lib/cretonne/meta/cdsl/ti.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
Type Inference
|
||||
"""
|
||||
from .typevar import TypeVar
|
||||
from .ast import Def, Var
|
||||
from copy import copy
|
||||
from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable # noqa
|
||||
from typing import cast, List
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
if TYPE_CHECKING:
|
||||
Constraint = Tuple[TypeVar, TypeVar]
|
||||
ConstraintList = List[Constraint]
|
||||
TypeMap = Dict[TypeVar, TypeVar]
|
||||
VarMap = Dict[Var, TypeVar]
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
pass
|
||||
|
||||
|
||||
class TypeEnv(object):
|
||||
"""
|
||||
Class encapsulating the neccessary book keeping for type inference.
|
||||
:attribute type_map: dict holding the equivalence relations between tvs
|
||||
:attribute constraints: a list of accumulated constraints - tuples
|
||||
(tv1, tv2)) where tv1 and tv2 are equal
|
||||
:attribute ranks: dictionary recording the (optional) ranks for tvs.
|
||||
tvs corresponding to real variables have explicitly
|
||||
specified ranks.
|
||||
:attribute vars: a set containing all known Vars
|
||||
:attribute idx: counter used to get fresh ids
|
||||
"""
|
||||
def __init__(self, arg=None):
|
||||
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
|
||||
self.ranks = {} # type: Dict[TypeVar, int]
|
||||
self.vars = set() # type: Set[Var]
|
||||
|
||||
if arg is None:
|
||||
self.type_map = {} # type: TypeMap
|
||||
self.constraints = [] # type: ConstraintList
|
||||
else:
|
||||
self.type_map, self.constraints = arg
|
||||
|
||||
self.idx = 0
|
||||
|
||||
def __getitem__(self, arg):
|
||||
# type: (Union[TypeVar, Var]) -> TypeVar
|
||||
"""
|
||||
Lookup the canonical representative for a Var/TypeVar.
|
||||
"""
|
||||
if (isinstance(arg, Var)):
|
||||
tv = arg.get_typevar()
|
||||
else:
|
||||
assert (isinstance(arg, TypeVar))
|
||||
tv = arg
|
||||
|
||||
while tv in self.type_map:
|
||||
tv = self.type_map[tv]
|
||||
|
||||
if tv.is_derived:
|
||||
tv = TypeVar.derived(self[tv.base], tv.derived_func)
|
||||
return tv
|
||||
|
||||
def equivalent(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
"""
|
||||
Record a that the free tv1 is part of the same equivalence class as
|
||||
tv2. The canonical representative of the merged class is tv2's
|
||||
cannonical representative.
|
||||
"""
|
||||
assert not tv1.is_derived
|
||||
assert self[tv1] == tv1
|
||||
|
||||
# Make sure we don't create cycles
|
||||
if tv2.is_derived:
|
||||
assert self[tv2.base] != tv1
|
||||
|
||||
self.type_map[tv1] = tv2
|
||||
|
||||
def add_constraint(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
"""
|
||||
Add a new equivalence constraint between tv1 and tv2
|
||||
"""
|
||||
self.constraints.append((tv1, tv2))
|
||||
|
||||
def get_uid(self):
|
||||
# type: () -> str
|
||||
r = str(self.idx)
|
||||
self.idx += 1
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return self.dot()
|
||||
|
||||
def rank(self, tv):
|
||||
# type: (TypeVar) -> int
|
||||
"""
|
||||
Get the rank of tv in the partial order. TVs directly associated with a
|
||||
Var get their rank from the Var (see register()).
|
||||
Internally generated non-derived TVs implicitly get the lowest rank (0)
|
||||
Internal derived variables get the highest rank.
|
||||
"""
|
||||
default_rank = 5 if tv.is_derived else 0
|
||||
return self.ranks.get(tv, default_rank)
|
||||
|
||||
def register(self, v):
|
||||
# type: (Var) -> None
|
||||
"""
|
||||
Register a new Var v. This computes a rank for the associated TypeVar
|
||||
for v, which is used to impose a partial order on type variables.
|
||||
"""
|
||||
self.vars.add(v)
|
||||
|
||||
if v.is_input():
|
||||
r = 4
|
||||
elif v.is_intermediate():
|
||||
r = 3
|
||||
elif v.is_output():
|
||||
r = 2
|
||||
else:
|
||||
assert(v.is_temp())
|
||||
r = 1
|
||||
|
||||
self.ranks[v.get_typevar()] = r
|
||||
|
||||
def free_typevars(self):
|
||||
# type: () -> Set[TypeVar]
|
||||
"""
|
||||
Get the free typevars in the current type env.
|
||||
"""
|
||||
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
||||
# Filter out None here due to singleton type vars
|
||||
return set(filter(lambda x: x is not None, tvs))
|
||||
|
||||
def normalize(self):
|
||||
# type: () -> None
|
||||
"""
|
||||
Normalize by:
|
||||
- collapsing any roots that don't correspond to a concrete TV AND
|
||||
have a single TV derived from them or equivalent to them
|
||||
|
||||
E.g. if we have a root of the tree that looks like:
|
||||
|
||||
typeof_a typeof_b
|
||||
\ /
|
||||
typeof_x
|
||||
|
|
||||
half_width(1)
|
||||
|
|
||||
1
|
||||
|
||||
we want to collapse the linear path between 1 and typeof_x. The
|
||||
resulting graph is:
|
||||
|
||||
typeof_a typeof_b
|
||||
\ /
|
||||
typeof_x
|
||||
"""
|
||||
source_tvs = set([v.get_typevar() for v in self.vars])
|
||||
children = {} # type: Dict[TypeVar, Set[TypeVar]]
|
||||
for v in self.type_map.values():
|
||||
if not v.is_derived:
|
||||
continue
|
||||
|
||||
t = v.free_typevar()
|
||||
s = children.get(t, set())
|
||||
s.add(v)
|
||||
children[t] = s
|
||||
|
||||
for (a, b) in self.type_map.items():
|
||||
s = children.get(b, set())
|
||||
s.add(a)
|
||||
children[b] = s
|
||||
|
||||
for r in list(self.free_typevars()):
|
||||
while (r not in source_tvs and r in children and
|
||||
len(children[r]) == 1):
|
||||
child = list(children[r])[0]
|
||||
if child in self.type_map:
|
||||
assert self.type_map[child] == r
|
||||
del self.type_map[child]
|
||||
|
||||
r = child
|
||||
|
||||
def extract(self):
|
||||
# type: () -> TypeEnv
|
||||
"""
|
||||
Extract a clean type environment from self, that only mentions
|
||||
TVs associated with real variables
|
||||
"""
|
||||
vars_tvs = set([v.get_typevar() for v in self.vars])
|
||||
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
|
||||
new_constraints = [(self[tv1], self[tv2])
|
||||
for (tv1, tv2) in self.constraints]
|
||||
|
||||
# Sanity: new constraints and the new type_map should only contain
|
||||
# tvs associated with real vars
|
||||
for (a, b) in new_constraints:
|
||||
assert a.free_typevar() in vars_tvs and\
|
||||
b.free_typevar() in vars_tvs
|
||||
|
||||
for (k, v) in new_type_map.items():
|
||||
assert k in vars_tvs
|
||||
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
|
||||
|
||||
t = TypeEnv()
|
||||
t.type_map = new_type_map
|
||||
t.constraints = new_constraints
|
||||
# ranks and vars contain only TVs associated with real vars
|
||||
t.ranks = copy(self.ranks)
|
||||
t.vars = copy(self.vars)
|
||||
return t
|
||||
|
||||
def concrete_typings(self):
|
||||
# type: () -> Iterable[VarMap]
|
||||
"""
|
||||
Return an iterable over all possible concrete typings permitted by this
|
||||
TypeEnv.
|
||||
"""
|
||||
free_tvs = self.free_typevars()
|
||||
free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs]
|
||||
for concrete_types in product(*free_tv_iters):
|
||||
# Build type substitutions for all free vars
|
||||
m = {tv: TypeVar.singleton(typ)
|
||||
for (tv, typ) in zip(free_tvs, concrete_types)}
|
||||
|
||||
concrete_var_map = {v: subst(self[v.get_typevar()], m)
|
||||
for v in self.vars}
|
||||
|
||||
# Check if constraints are satisfied for this typing
|
||||
failed = None
|
||||
for (tv1, tv2) in self.constraints:
|
||||
tv1 = subst(tv1, m)
|
||||
tv2 = subst(tv2, m)
|
||||
assert tv1.get_typeset().size() == 1 and\
|
||||
tv2.get_typeset().size() == 1
|
||||
if (tv1.get_typeset() != tv2.get_typeset()):
|
||||
failed = (tv1, tv2)
|
||||
break
|
||||
|
||||
if (failed is not None):
|
||||
continue
|
||||
|
||||
yield concrete_var_map
|
||||
|
||||
def dot(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return a representation of self as a graph in dot format.
|
||||
Nodes correspond to TypeVariables.
|
||||
Dotted edges correspond to equivalences between TVS
|
||||
Solid edges correspond to derivation relations between TVs.
|
||||
Dashed edges correspond to equivalence constraints.
|
||||
"""
|
||||
def label(s):
|
||||
# type: (TypeVar) -> str
|
||||
return "\"" + str(s) + "\""
|
||||
|
||||
# Add all registered TVs (as some of them may be singleton nodes not
|
||||
# appearing in the graph
|
||||
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
|
||||
|
||||
for (k, v) in self.type_map.items():
|
||||
# Add all intermediate TVs appearing in edges
|
||||
nodes.add(k)
|
||||
nodes.add(v)
|
||||
edges.add((k, v, "dotted", None))
|
||||
while (v.is_derived):
|
||||
nodes.add(v.base)
|
||||
edges.add((v, v.base, "solid", v.derived_func))
|
||||
v = v.base
|
||||
|
||||
for (a, b) in self.constraints:
|
||||
assert a in nodes and b in nodes
|
||||
edges.add((a, b, "dashed", None))
|
||||
|
||||
root_nodes = set([x for x in nodes
|
||||
if x not in self.type_map and not x.is_derived])
|
||||
|
||||
r = "digraph {\n"
|
||||
for n in nodes:
|
||||
r += label(n)
|
||||
if n in root_nodes:
|
||||
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
|
||||
r += ";\n"
|
||||
|
||||
for (n1, n2, style, elabel) in edges:
|
||||
e = label(n1)
|
||||
if style == "dashed":
|
||||
e += '--'
|
||||
else:
|
||||
e += '->'
|
||||
e += label(n2)
|
||||
e += "[style={}".format(style)
|
||||
|
||||
if elabel is not None:
|
||||
e += ",label={}".format(elabel)
|
||||
e += "];\n"
|
||||
|
||||
r += e
|
||||
r += "}"
|
||||
|
||||
return r
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
TypingError = str
|
||||
TypingOrError = Union[TypeEnv, TypingError]
|
||||
|
||||
|
||||
def get_error(typing_or_err):
|
||||
# type: (TypingOrError) -> Optional[TypingError]
|
||||
"""
|
||||
Helper function to appease mypy when checking the result of typing.
|
||||
"""
|
||||
if isinstance(typing_or_err, str):
|
||||
if (TYPE_CHECKING):
|
||||
return cast(TypingError, typing_or_err)
|
||||
else:
|
||||
return typing_or_err
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_type_env(typing_or_err):
|
||||
# type: (TypingOrError) -> TypeEnv
|
||||
"""
|
||||
Helper function to appease mypy when checking the result of typing.
|
||||
"""
|
||||
assert isinstance(typing_or_err, TypeEnv)
|
||||
if (TYPE_CHECKING):
|
||||
return cast(TypeEnv, typing_or_err)
|
||||
else:
|
||||
return typing_or_err
|
||||
|
||||
|
||||
def subst(tv, tv_map):
|
||||
# type: (TypeVar, TypeMap) -> TypeVar
|
||||
"""
|
||||
Perform substition on the input tv using the TypeMap tv_map.
|
||||
"""
|
||||
if tv in tv_map:
|
||||
return tv_map[tv]
|
||||
|
||||
if tv.is_derived:
|
||||
return TypeVar.derived(subst(tv.base, tv_map), tv.derived_func)
|
||||
|
||||
return tv
|
||||
|
||||
|
||||
def normalize_tv(tv):
|
||||
# type: (TypeVar) -> TypeVar
|
||||
"""
|
||||
Normalize a (potentially derived) TV using the following rules:
|
||||
- collapse SAMEAS
|
||||
SAMEAS(base) -> base
|
||||
|
||||
- vector and width derived functions commute
|
||||
{HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
|
||||
{HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
|
||||
|
||||
- half/double pairs collapse
|
||||
{HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
|
||||
{HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base
|
||||
"""
|
||||
vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR]
|
||||
width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
|
||||
|
||||
if not tv.is_derived:
|
||||
return tv
|
||||
|
||||
df = tv.derived_func
|
||||
|
||||
# Collapse SAMEAS edges
|
||||
if (df == TypeVar.SAMEAS):
|
||||
return normalize_tv(tv.base)
|
||||
|
||||
if (tv.base.is_derived):
|
||||
base_df = tv.base.derived_func
|
||||
|
||||
# Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR,
|
||||
# DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR
|
||||
if df in vector_derives and base_df in width_derives:
|
||||
return normalize_tv(
|
||||
TypeVar.derived(
|
||||
TypeVar.derived(tv.base.base, df), base_df))
|
||||
|
||||
# Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
|
||||
# cancel each other. TODO: Does this cancellation hide type
|
||||
# overflow/underflow?
|
||||
|
||||
if (df, base_df) in \
|
||||
[(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
|
||||
(TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
|
||||
(TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
|
||||
(TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH)]:
|
||||
return normalize_tv(tv.base.base)
|
||||
|
||||
return TypeVar.derived(normalize_tv(tv.base), df)
|
||||
|
||||
|
||||
def constrain_fixpoint(tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
"""
|
||||
Given typevars tv1 and tv2 (which could be derived from one another)
|
||||
constrain their typesets to be the same. When one is derived from the
|
||||
other, repeat the constrain process until fixpoint.
|
||||
"""
|
||||
# Constrain tv2's typeset as long as tv1's typeset is changing.
|
||||
while True:
|
||||
old_tv1_ts = tv1.get_typeset().copy()
|
||||
tv2.constrain_types(tv1)
|
||||
if tv1.get_typeset() == old_tv1_ts:
|
||||
break
|
||||
|
||||
old_tv2_ts = tv2.get_typeset().copy()
|
||||
tv1.constrain_types(tv2)
|
||||
assert old_tv2_ts == tv2.get_typeset()
|
||||
|
||||
|
||||
def unify(tv1, tv2, typ):
|
||||
# type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Unify tv1 and tv2 in the current type environment typ, and return an
|
||||
updated type environment or error.
|
||||
"""
|
||||
tv1 = normalize_tv(typ[tv1])
|
||||
tv2 = normalize_tv(typ[tv2])
|
||||
|
||||
# Already unified
|
||||
if tv1 == tv2:
|
||||
return typ
|
||||
|
||||
if typ.rank(tv2) < typ.rank(tv1):
|
||||
return unify(tv2, tv1, typ)
|
||||
|
||||
constrain_fixpoint(tv1, tv2)
|
||||
|
||||
if (tv1.get_typeset().size() == 0 or tv2.get_typeset().size() == 0):
|
||||
return "Error: empty type created when unifying {} and {}"\
|
||||
.format(tv1, tv2)
|
||||
|
||||
# Free -> Derived(Free)
|
||||
if not tv1.is_derived:
|
||||
typ.equivalent(tv1, tv2)
|
||||
return typ
|
||||
|
||||
assert tv2.is_derived, "Ordering gives us !tv1.is_derived==>tv2.is_derived"
|
||||
|
||||
if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
|
||||
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
||||
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
||||
|
||||
typ.add_constraint(tv1, tv2)
|
||||
return typ
|
||||
|
||||
|
||||
def ti_def(definition, typ):
|
||||
# type: (Def, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Perform type inference on one Def in the current type environment typ and
|
||||
return an updated type environment or error.
|
||||
|
||||
At a high level this works by creating fresh copies of each formal type var
|
||||
in the Def's instruction's signature, and unifying the formal tv with the
|
||||
corresponding actual tv.
|
||||
"""
|
||||
expr = definition.expr
|
||||
inst = expr.inst
|
||||
|
||||
# Create a map m mapping each free typevar in the signature of definition
|
||||
# to a fresh copy of itself
|
||||
all_formal_tvs = \
|
||||
[inst.outs[i].typevar for i in inst.value_results] +\
|
||||
[inst.ins[i].typevar for i in inst.value_opnums]
|
||||
free_formal_tvs = [tv for tv in all_formal_tvs if not tv.is_derived]
|
||||
m = {tv: tv.get_fresh_copy(str(typ.get_uid())) for tv in free_formal_tvs}
|
||||
|
||||
# Get fresh copies for each typevar in the signature (both free and
|
||||
# derived)
|
||||
fresh_formal_tvs = \
|
||||
[subst(inst.outs[i].typevar, m) for i in inst.value_results] +\
|
||||
[subst(inst.ins[i].typevar, m) for i in inst.value_opnums]
|
||||
|
||||
# Get the list of actual Vars
|
||||
actual_vars = [] # type: List[Expr]
|
||||
actual_vars += [definition.defs[i] for i in inst.value_results]
|
||||
actual_vars += [expr.args[i] for i in inst.value_opnums]
|
||||
|
||||
# Get the list of the actual TypeVars
|
||||
actual_tvs = []
|
||||
for v in actual_vars:
|
||||
assert(isinstance(v, Var))
|
||||
# Register with TypeEnv that this typevar corresponds ot variable v,
|
||||
# and thus has a given rank
|
||||
typ.register(v)
|
||||
actual_tvs.append(v.get_typevar())
|
||||
|
||||
# Unify each actual typevar with the correpsonding fresh formal tv
|
||||
for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
|
||||
typ_or_err = unify(actual_tv, formal_tv, typ)
|
||||
err = get_error(typ_or_err)
|
||||
if (err):
|
||||
return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
def ti_rtl(rtl, typ):
|
||||
# type: (Rtl, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Perform type inference on an Rtl in a starting type env typ. Return an
|
||||
updated type environment or error.
|
||||
"""
|
||||
for (i, d) in enumerate(rtl.rtl):
|
||||
assert (isinstance(d, Def))
|
||||
typ_or_err = ti_def(d, typ)
|
||||
err = get_error(typ_or_err) # type: Optional[TypingError]
|
||||
if (err):
|
||||
return "On line {}: ".format(i) + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
def ti_xform(xform, typ):
|
||||
# type: (XForm, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Perform type inference on an Rtl in a starting type env typ. Return an
|
||||
updated type environment or error.
|
||||
"""
|
||||
typ_or_err = ti_rtl(xform.src, typ)
|
||||
err = get_error(typ_or_err) # type: Optional[TypingError]
|
||||
if (err):
|
||||
return "In src pattern: " + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
typ_or_err = ti_rtl(xform.dst, typ)
|
||||
err = get_error(typ_or_err)
|
||||
if (err):
|
||||
return "In dst pattern: " + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
return get_type_env(typ_or_err)
|
||||
@@ -8,7 +8,6 @@ from __future__ import absolute_import
|
||||
import math
|
||||
from . import types, is_power_of_two
|
||||
from copy import deepcopy
|
||||
from .types import IntType, FloatType, BoolType
|
||||
|
||||
try:
|
||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||
@@ -210,6 +209,10 @@ class TypeSet(object):
|
||||
else:
|
||||
return False
|
||||
|
||||
def __ne__(self, other):
|
||||
# type: (object) -> bool
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
s = 'TypeSet(lanes={}'.format(pp_set(self.lanes))
|
||||
@@ -289,7 +292,9 @@ class TypeSet(object):
|
||||
new = self.copy()
|
||||
new.ints = set()
|
||||
new.floats = set()
|
||||
new.bools = self.ints.union(self.floats).union(self.bools)
|
||||
|
||||
if len(self.lanes.difference(set([1]))) > 0:
|
||||
new.bools = self.ints.union(self.floats).union(self.bools)
|
||||
|
||||
if 1 in self.lanes:
|
||||
new.bools.add(1)
|
||||
@@ -340,7 +345,7 @@ class TypeSet(object):
|
||||
|
||||
return new
|
||||
|
||||
def map(self, func):
|
||||
def image(self, func):
|
||||
# type: (str) -> TypeSet
|
||||
"""
|
||||
Return the image of self across the derived function func
|
||||
@@ -362,7 +367,7 @@ class TypeSet(object):
|
||||
else:
|
||||
assert False, "Unknown derived function: " + func
|
||||
|
||||
def map_inverse(self, func):
|
||||
def preimage(self, func):
|
||||
# type: (str) -> TypeSet
|
||||
"""
|
||||
Return the inverse image of self across the derived function func
|
||||
@@ -379,16 +384,14 @@ class TypeSet(object):
|
||||
return new
|
||||
elif (func == TypeVar.ASBOOL):
|
||||
new = self.copy()
|
||||
new.ints = self.bools.difference(set([1]))
|
||||
new.floats = self.bools.intersection(set([32, 64]))
|
||||
|
||||
if 1 not in self.bools:
|
||||
try:
|
||||
# If the range doesn't have b1, then the domain can't
|
||||
# include scalars, as as_bool(scalar)=b1
|
||||
new.lanes.remove(1)
|
||||
except KeyError:
|
||||
pass
|
||||
new.ints = self.bools.difference(set([1]))
|
||||
new.floats = self.bools.intersection(set([32, 64]))
|
||||
else:
|
||||
new.ints = set([2**x for x in range(3, 7)])
|
||||
new.floats = set([32, 64])
|
||||
|
||||
return new
|
||||
elif (func == TypeVar.HALFWIDTH):
|
||||
return self.double_width()
|
||||
@@ -409,27 +412,32 @@ class TypeSet(object):
|
||||
return len(self.lanes) * (len(self.ints) + len(self.floats) +
|
||||
len(self.bools))
|
||||
|
||||
def concrete_types(self):
|
||||
# type: () -> Iterable[types.ValueType]
|
||||
def by(scalar, lanes):
|
||||
# type: (types.ScalarType, int) -> types.ValueType
|
||||
if (lanes == 1):
|
||||
return scalar
|
||||
else:
|
||||
return scalar.by(lanes)
|
||||
|
||||
for nlanes in self.lanes:
|
||||
for bits in self.ints:
|
||||
yield by(types.IntType.with_bits(bits), nlanes)
|
||||
for bits in self.floats:
|
||||
yield by(types.FloatType.with_bits(bits), nlanes)
|
||||
for bits in self.bools:
|
||||
yield by(types.BoolType.with_bits(bits), nlanes)
|
||||
|
||||
def get_singleton(self):
|
||||
# type: () -> types.ValueType
|
||||
"""
|
||||
Return the singleton type represented by self. Can only call on
|
||||
typesets containing 1 type.
|
||||
"""
|
||||
assert self.size() == 1
|
||||
scalar_type = None # type: types.ScalarType
|
||||
if len(self.ints) > 0:
|
||||
scalar_type = IntType.with_bits(tuple(self.ints)[0])
|
||||
elif len(self.floats) > 0:
|
||||
scalar_type = FloatType.with_bits(tuple(self.floats)[0])
|
||||
else:
|
||||
scalar_type = BoolType.with_bits(tuple(self.bools)[0])
|
||||
|
||||
nlanes = tuple(self.lanes)[0]
|
||||
|
||||
if nlanes == 1:
|
||||
return scalar_type
|
||||
else:
|
||||
return scalar_type.by(nlanes)
|
||||
types = list(self.concrete_types())
|
||||
assert len(types) == 1
|
||||
return types[0]
|
||||
|
||||
|
||||
class TypeVar(object):
|
||||
@@ -519,6 +527,13 @@ class TypeVar(object):
|
||||
'TypeVar({}, {})'
|
||||
.format(self.name, self.type_set))
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
if (not self.is_derived):
|
||||
return object.__hash__(self)
|
||||
|
||||
return hash((self.derived_func, self.base))
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if not isinstance(other, TypeVar):
|
||||
@@ -530,6 +545,10 @@ class TypeVar(object):
|
||||
else:
|
||||
return self is other
|
||||
|
||||
def __ne__(self, other):
|
||||
# type: (object) -> bool
|
||||
return not self.__eq__(other)
|
||||
|
||||
# Supported functions for derived type variables.
|
||||
# The names here must match the method names on `ir::types::Type`.
|
||||
# The camel_case of the names must match `enum OperandConstraint` in
|
||||
@@ -542,6 +561,27 @@ class TypeVar(object):
|
||||
HALFVECTOR = 'half_vector'
|
||||
DOUBLEVECTOR = 'double_vector'
|
||||
|
||||
@staticmethod
|
||||
def is_bijection(func):
|
||||
# type: (str) -> bool
|
||||
return func in [
|
||||
TypeVar.SAMEAS,
|
||||
TypeVar.HALFWIDTH,
|
||||
TypeVar.DOUBLEWIDTH,
|
||||
TypeVar.HALFVECTOR,
|
||||
TypeVar.DOUBLEVECTOR]
|
||||
|
||||
@staticmethod
|
||||
def inverse_func(func):
|
||||
# type: (str) -> str
|
||||
return {
|
||||
TypeVar.SAMEAS: TypeVar.SAMEAS,
|
||||
TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
|
||||
TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
|
||||
TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
|
||||
TypeVar.DOUBLEVECTOR: TypeVar.HALFVECTOR
|
||||
}[func]
|
||||
|
||||
@staticmethod
|
||||
def derived(base, derived_func):
|
||||
# type: (TypeVar, str) -> TypeVar
|
||||
@@ -668,7 +708,7 @@ class TypeVar(object):
|
||||
Get the free type variable controlling this one.
|
||||
"""
|
||||
if self.is_derived:
|
||||
return self.base
|
||||
return self.base.free_typevar()
|
||||
elif self.singleton_type() is not None:
|
||||
# A singleton type variable is not a proper free variable.
|
||||
return None
|
||||
@@ -697,7 +737,7 @@ class TypeVar(object):
|
||||
if not self.is_derived:
|
||||
self.type_set &= ts
|
||||
else:
|
||||
self.base.constrain_types_by_ts(ts.map_inverse(self.derived_func))
|
||||
self.base.constrain_types_by_ts(ts.preimage(self.derived_func))
|
||||
|
||||
def constrain_types(self, other):
|
||||
# type: (TypeVar) -> None
|
||||
@@ -723,19 +763,14 @@ class TypeVar(object):
|
||||
if not self.is_derived:
|
||||
return self.type_set
|
||||
else:
|
||||
if (self.derived_func == TypeVar.SAMEAS):
|
||||
return self.base.get_typeset()
|
||||
elif (self.derived_func == TypeVar.LANEOF):
|
||||
return self.base.get_typeset().lane_of()
|
||||
elif (self.derived_func == TypeVar.ASBOOL):
|
||||
return self.base.get_typeset().as_bool()
|
||||
elif (self.derived_func == TypeVar.HALFWIDTH):
|
||||
return self.base.get_typeset().half_width()
|
||||
elif (self.derived_func == TypeVar.DOUBLEWIDTH):
|
||||
return self.base.get_typeset().double_width()
|
||||
elif (self.derived_func == TypeVar.HALFVECTOR):
|
||||
return self.base.get_typeset().half_vector()
|
||||
elif (self.derived_func == TypeVar.DOUBLEVECTOR):
|
||||
return self.base.get_typeset().double_vector()
|
||||
else:
|
||||
assert False, "Unknown derived function: " + self.derived_func
|
||||
return self.base.get_typeset().image(self.derived_func)
|
||||
|
||||
def get_fresh_copy(self, name):
|
||||
# type: (str) -> TypeVar
|
||||
"""
|
||||
Get a fresh copy of self. Can only be called on free typevars.
|
||||
"""
|
||||
assert not self.is_derived
|
||||
tv = TypeVar.from_typeset(self.type_set.copy())
|
||||
tv.name = name
|
||||
return tv
|
||||
|
||||
@@ -3,6 +3,7 @@ Instruction transformations.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from .ast import Def, Var, Apply
|
||||
from .ti import ti_xform, TypeEnv, get_type_env
|
||||
|
||||
try:
|
||||
from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa
|
||||
@@ -83,6 +84,8 @@ class XForm(object):
|
||||
self._rewrite_rtl(src, symtab, Var.SRCCTX)
|
||||
num_src_inputs = len(self.inputs)
|
||||
self._rewrite_rtl(dst, symtab, Var.DSTCTX)
|
||||
# Needed for testing type inference on XForms
|
||||
self.symtab = symtab
|
||||
|
||||
# Check for inconsistently used inputs.
|
||||
for i in self.inputs:
|
||||
@@ -96,9 +99,25 @@ class XForm(object):
|
||||
"extra inputs in dst RTL: {}".format(
|
||||
self.inputs[num_src_inputs:]))
|
||||
|
||||
self._infer_types(self.src)
|
||||
self._infer_types(self.dst)
|
||||
self._collect_typevars()
|
||||
# Perform type inference and cleanup
|
||||
raw_ti = get_type_env(ti_xform(self, TypeEnv()))
|
||||
raw_ti.normalize()
|
||||
self.ti = raw_ti.extract()
|
||||
|
||||
# Sanity: The set of inferred free typevars should be a subset of the
|
||||
# TVs corresponding to Vars appearing in src
|
||||
self.free_typevars = self.ti.free_typevars()
|
||||
src_vars = set(self.inputs).union(
|
||||
[x for x in self.defs if not x.is_temp()])
|
||||
src_tvs = set([v.get_typevar() for v in src_vars])
|
||||
if (not self.free_typevars.issubset(src_tvs)):
|
||||
raise AssertionError(
|
||||
"Some free vars don't appear in src - {}"
|
||||
.format(self.free_typevars.difference(src_tvs)))
|
||||
|
||||
# Update the type vars for each Var to their inferred values
|
||||
for v in self.inputs + self.defs:
|
||||
v.set_typevar(self.ti[v.get_typevar()])
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
@@ -202,63 +221,6 @@ class XForm(object):
|
||||
raise AssertionError(
|
||||
'{} not defined in dest pattern'.format(d))
|
||||
|
||||
def _infer_types(self, rtl):
|
||||
# type: (Rtl) -> None
|
||||
"""Assign type variables to all value variables used in `rtl`."""
|
||||
for d in rtl.rtl:
|
||||
inst = d.expr.inst
|
||||
|
||||
# Get the Var corresponding to the controlling type variable.
|
||||
ctrl_var = None # type: Var
|
||||
if inst.is_polymorphic:
|
||||
if inst.use_typevar_operand:
|
||||
# Should this be an assertion instead?
|
||||
# Should all value operands be required to be Vars?
|
||||
arg = d.expr.args[inst.format.typevar_operand]
|
||||
if isinstance(arg, Var):
|
||||
ctrl_var = arg
|
||||
else:
|
||||
ctrl_var = d.defs[inst.value_results[0]]
|
||||
|
||||
# Reconcile arguments with the requirements of `inst`.
|
||||
for opnum in inst.value_opnums:
|
||||
inst_tv = inst.ins[opnum].typevar
|
||||
v = d.expr.args[opnum]
|
||||
if isinstance(v, Var):
|
||||
v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var)
|
||||
|
||||
# Reconcile results with the requirements of `inst`.
|
||||
for resnum in inst.value_results:
|
||||
inst_tv = inst.outs[resnum].typevar
|
||||
v = d.defs[resnum]
|
||||
v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var)
|
||||
|
||||
def _collect_typevars(self):
|
||||
# type: () -> None
|
||||
"""
|
||||
Collect a list of variables whose type can be used to infer the types
|
||||
of all expressions.
|
||||
|
||||
This should be called after `_infer_types()` above has computed type
|
||||
variables for all the used vars.
|
||||
"""
|
||||
fvars = list(v for v in self.inputs if v.has_free_typevar())
|
||||
fvars += list(v for v in self.defs if v.has_free_typevar())
|
||||
self.free_typevars = fvars
|
||||
|
||||
# When substituting a pattern, we know the types of all variables that
|
||||
# appear on the source side: inut, output, and intermediate values.
|
||||
# However, temporary values which appear only on the destination side
|
||||
# must have their type computed somehow.
|
||||
#
|
||||
# Some variables have a fixed type which appears as a type variable
|
||||
# with a singleton_type field set. That's allowed for temps too.
|
||||
for v in fvars:
|
||||
if v.is_temp() and not v.typevar.singleton_type():
|
||||
raise AssertionError(
|
||||
"Cannot determine type of temp '{}' in xform:\n{}"
|
||||
.format(v, self))
|
||||
|
||||
|
||||
class XFormGroup(object):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user