* Add more rigorous type inference and encapsulate the type inference code in its own file (ti.py). Add constraints accumulation during type inference, to represent constraints that cannot be expressed using bijective derivation functions between typevars. Add testing for new type inference code. * Additional annotations to appease mypy
433 lines
16 KiB
Python
433 lines
16 KiB
Python
from __future__ import absolute_import
|
|
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
|
|
b1, icmp, iadd_cout, iadd_cin
|
|
from base.legalize import narrow, expand
|
|
from base.immediates import intcc
|
|
from .typevar import TypeVar
|
|
from .ast import Var, Def
|
|
from .xform import Rtl, XForm
|
|
from .ti import ti_rtl, subst, TypeEnv, get_type_env
|
|
from unittest import TestCase
|
|
from functools import reduce
|
|
|
|
try:
|
|
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
|
|
from .ti import Constraint
|
|
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
|
|
except ImportError:
|
|
TYPE_CHECKING = False
|
|
|
|
|
|
def sort_constr(c):
    # type: (Constraint) -> Constraint
    """
    Return the typevars of constraint c as a tuple ordered by their name, so
    that two constraints can be compared regardless of operand order.
    """
    ordered = tuple(sorted(c, key=lambda tv: tv.name))
    # cast() is only reached when TYPE_CHECKING is set; otherwise the tuple
    # is handed back as is.
    return cast(Constraint, ordered) if TYPE_CHECKING else ordered
|
|
|
|
|
|
def agree(me, other):
    # type: (TypeEnv, TypeEnv) -> bool
    """
    Given TypeEnvs me and other, check if they agree. As part of that build
    a map m from TVs in me to their corresponding TVs in other.
    Specifically:

    1. Look up every TV that is a key in me.type_map in other as well

    2. For any tv in me.type_map check that:
       me[tv].get_typeset() == other[tv].get_typeset()

    3. Set m[me[tv]] = other[tv] in the substitution m

    4. If we find another tv1 such that me[tv1] == me[tv], check that
       other[tv1] == m[me[tv1]] == m[me[tv]] == other[tv]

    5. Check that me and other have the same constraints under the
       substitution m

    Returns True when all checks pass and False otherwise.
    """
    m = {}  # type: TypeMap
    # Check that our type map and other's agree and build substitution m
    for tv in me.type_map:
        mine = me[tv]
        theirs = other[tv]
        if mine not in m:
            # First time we see this TV of me: record its counterpart
            m[mine] = theirs
            if mine.get_typeset() != theirs.get_typeset():
                return False
        elif m[mine] != theirs:
            # Already mapped: other must map it to the same TV again
            return False

    def _by_names(constr):
        # type: (Constraint) -> Tuple[str, str]
        # One shared sort key for both sides. Ordering by both typevar
        # names keeps the two lists comparable element by element and does
        # not require TypeVars themselves to be orderable.
        return (constr[0].name, constr[1].name)

    # Translate our constraints using m, and sort
    me_equiv_constr = sorted(
        (sort_constr((subst(a, m), subst(b, m)))
         for (a, b) in me.constraints),
        key=_by_names)

    # Sort other's constraints with the very same key
    other_equiv_constr = sorted(
        (sort_constr(x) for x in other.constraints),
        key=_by_names)

    return me_equiv_constr == other_equiv_constr
|
|
|
|
|
|
def check_typing(got_or_err, expected, symtab=None):
    # type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa
    """
    Check that the typing we received (got_or_err) complies with the
    expected typing (expected), given as a pair of a Var->TypeVar map and a
    list of constraints. If symtab is specified, the Vars in expected are
    first substituted using symtab (used when checking type inference on
    XForms). Fails an assertion when the two typings disagree.
    """
    (var_map, constraints) = expected
    got = get_type_env(got_or_err)

    if symtab is None:
        # No symtab: just convert the Var->TypeVar map into a
        # TypeVar->TypeVar map
        tv_map = {var.get_typevar(): tv for (var, tv) in var_map.items()}
    else:
        # For xforms we first need to re-write our TVs in terms of the TVs
        # stored internally in the XForm, found through the symtab passed
        to_internal = {var.get_typevar(): symtab[str(var)].get_typevar()
                       for var in var_map}
        # Convert the Var->TypeVar map into a TypeVar->TypeVar map whose
        # key TypeVar is re-written to its XForm internal version
        tv_map = {subst(var.get_typevar(), to_internal): tv
                  for (var, tv) in var_map.items()}
        # Rewrite the TVs in the input constraints to their XForm internal
        # versions
        constraints = [(subst(a, to_internal), subst(b, to_internal))
                       for (a, b) in constraints]

    expected_typ = TypeEnv((tv_map, constraints))
    assert agree(expected_typ, got), \
        "typings disagree:\n {} \n {}".format(got.dot(),
                                              expected_typ.dot())
|
|
|
|
|
|
def check_concrete_typing_rtl(var_types, rtl):
    # type: (VarMap, Rtl) -> None
    """
    Check that a concrete type assignment var_types (Dict[Var, TypeVar]) is
    valid for an Rtl rtl. Specifically check that:

    1) For each Var v in rtl, v is defined in var_types

    2) For all v, var_types[v] is a singleton type

    3) For each v, and each location u, where v is used with expected type
       tv_u, var_types[v].get_typeset() is a subset of
       subst(tv_u, m).get_typeset() where m is the substitution of
       formals->actuals we are building so far.

    4) If tv_u is non-derived and not in m, set m[tv_u] = var_types[v]

    Checks 2 and 3 are enforced with assertions; check 1 is implicit in the
    var_types[v] lookups. Nothing is returned.
    """
    for d in rtl.rtl:
        assert isinstance(d, Def)
        inst = d.expr.inst
        # Accumulate all actual TVs for value defs/opnums in actual_tvs
        actual_tvs = [var_types[d.defs[i]] for i in inst.value_results]
        for v in [d.expr.args[i] for i in inst.value_opnums]:
            assert isinstance(v, Var)
            actual_tvs.append(var_types[v])

        # Accumulate all formal TVs for value defs/opnums in formal_tvs,
        # in the same results-then-operands order as actual_tvs
        formal_tvs = [inst.outs[i].typevar for i in inst.value_results] +\
                     [inst.ins[i].typevar for i in inst.value_opnums]
        # Substitution of formals->actuals, started afresh for every def
        m = {}  # type: TypeMap

        # For each actual/formal pair check that they agree
        for (actual_tv, formal_tv) in zip(actual_tvs, formal_tvs):
            # actual should be a singleton
            assert actual_tv.singleton_type() is not None
            # From here on formal_tv is the formal as rewritten by m
            formal_tv = subst(formal_tv, m)
            # actual should agree with the concretized formal
            assert actual_tv.get_typeset().issubset(formal_tv.get_typeset())

            if formal_tv not in m and not formal_tv.is_derived:
                m[formal_tv] = actual_tv
|
|
|
|
|
|
def check_concrete_typing_xform(var_types, xform):
    # type: (VarMap, XForm) -> None
    """
    Check a concrete type assignment var_types for an XForm xform by
    checking it against the xform's source Rtl first and its destination
    Rtl second.
    """
    src_rtl = xform.src
    check_concrete_typing_rtl(var_types, src_rtl)
    dst_rtl = xform.dst
    check_concrete_typing_rtl(var_types, dst_rtl)
|
|
|
|
|
|
class TypeCheckingBaseTest(TestCase):
    """
    Shared fixture for the type inference tests: a handful of Vars plus the
    TypeVars the expected typings are written in terms of.
    """
    def setUp(self):
        # type: () -> None
        # Value variables v0 .. v9, created in ascending order
        (self.v0, self.v1, self.v2, self.v3, self.v4,
         self.v5, self.v6, self.v7, self.v8, self.v9) = \
            [Var("v{}".format(i)) for i in range(10)]
        # Variable used as the immediate operand of iconst
        self.imm0 = Var("imm0")

        # Integer, vector-only typevar
        self.IxN_nonscalar = TypeVar(
            "IxN_nonscalar", "",
            ints=True, scalars=False, simd=True)
        # Int/bool/float, vector-only typevar
        self.TxN = TypeVar(
            "TxN", "",
            ints=True, bools=True, floats=True, scalars=False, simd=True)
        # Singleton typevar for the b1 type
        self.b1 = TypeVar.singleton(b1)
|
|
|
|
|
|
class TestRTL(TypeCheckingBaseTest):
    """
    Type inference tests that run ti_rtl on stand-alone Rtl snippets.
    """
    def test_bad_rtl1(self):
        # type: () -> None
        # Inference is expected to fail here: ti_rtl hands back an error
        # message instead of a typing.
        rtl = Rtl(
            (self.v0, self.v1) << vsplit(self.v2),
            self.v3 << vconcat(self.v0, self.v2),
        )
        expected_err = ("On line 1: fail ti on `typeof_v2` <: `2`: "
                        "Error: empty type created when unifying "
                        "`typeof_v2` and `half_vector(typeof_v2)`")
        self.assertEqual(ti_rtl(rtl, TypeEnv()), expected_err)

    def test_vselect(self):
        # type: () -> None
        # A lone vselect: all values share one typevar, the controlling
        # operand is its as_bool() version.
        rtl = Rtl(
            self.v0 << vselect(self.v1, self.v2, self.v3),
        )
        inferred = ti_rtl(rtl, TypeEnv())
        txn = self.TxN.get_fresh_copy("TxN1")
        expected_vars = {
            self.v0: txn,
            self.v1: txn.as_bool(),
            self.v2: txn,
            self.v3: txn
        }
        check_typing(inferred, (expected_vars, []))

    def test_vselect_icmpimm(self):
        # type: () -> None
        # The icmp result controls the vselect; the expected typing carries
        # one constraint tying the two as_bool() typevars together.
        rtl = Rtl(
            self.v0 << iconst(self.imm0),
            self.v1 << icmp(intcc.eq, self.v2, self.v0),
            self.v5 << vselect(self.v1, self.v3, self.v4),
        )
        inferred = ti_rtl(rtl, TypeEnv())
        ixn = self.IxN_nonscalar.get_fresh_copy("IxN1")
        txn = self.TxN.get_fresh_copy("TxN1")
        expected_vars = {
            self.v0: ixn,
            self.v1: ixn.as_bool(),
            self.v2: ixn,
            self.v3: txn,
            self.v4: txn,
            self.v5: txn,
        }
        expected_constraints = [(ixn.as_bool(), txn.as_bool())]
        check_typing(inferred, (expected_vars, expected_constraints))

    def test_vselect_vsplits(self):
        # type: () -> None
        # The vselect result is split twice.
        rtl = Rtl(
            self.v3 << vselect(self.v0, self.v1, self.v2),
            (self.v4, self.v5) << vsplit(self.v3),
            (self.v6, self.v7) << vsplit(self.v4),
        )
        inferred = ti_rtl(rtl, TypeEnv())
        base = TypeVar("t", "", ints=True, bools=True, floats=True,
                       simd=(4, 256))
        half = base.half_vector()
        expected_vars = {
            self.v0: base.as_bool(),
            self.v1: base,
            self.v2: base,
            self.v3: base,
            self.v4: half,
            self.v5: base.half_vector(),
            self.v6: base.half_vector().half_vector(),
            self.v7: base.half_vector().half_vector(),
        }
        check_typing(inferred, (expected_vars, []))

    def test_vselect_vconcats(self):
        # type: () -> None
        # The vselect result is concatenated with itself twice.
        rtl = Rtl(
            self.v3 << vselect(self.v0, self.v1, self.v2),
            self.v8 << vconcat(self.v3, self.v3),
            self.v9 << vconcat(self.v8, self.v8),
        )
        inferred = ti_rtl(rtl, TypeEnv())
        base = TypeVar("t", "", ints=True, bools=True, floats=True,
                       simd=(2, 64))
        expected_vars = {
            self.v0: base.as_bool(),
            self.v1: base,
            self.v2: base,
            self.v3: base,
            self.v8: base.double_vector(),
            self.v9: base.double_vector().double_vector(),
        }
        check_typing(inferred, (expected_vars, []))

    def test_vselect_vsplits_vconcats(self):
        # type: () -> None
        # Both of the above: two splits and two concats of the vselect
        # result in the same Rtl.
        rtl = Rtl(
            self.v3 << vselect(self.v0, self.v1, self.v2),
            (self.v4, self.v5) << vsplit(self.v3),
            (self.v6, self.v7) << vsplit(self.v4),
            self.v8 << vconcat(self.v3, self.v3),
            self.v9 << vconcat(self.v8, self.v8),
        )
        inferred = ti_rtl(rtl, TypeEnv())
        base = TypeVar("t", "", ints=True, bools=True, floats=True,
                       simd=(4, 64))
        expected_vars = {
            self.v0: base.as_bool(),
            self.v1: base,
            self.v2: base,
            self.v3: base,
            self.v4: base.half_vector(),
            self.v5: base.half_vector(),
            self.v6: base.half_vector().half_vector(),
            self.v7: base.half_vector().half_vector(),
            self.v8: base.double_vector(),
            self.v9: base.double_vector().double_vector(),
        }
        check_typing(inferred, (expected_vars, []))

    def test_bint(self):
        # type: () -> None
        rtl = Rtl(
            self.v4 << iadd(self.v1, self.v2),
            self.v5 << bint(self.v3),
            self.v0 << iadd(self.v4, self.v5)
        )
        inferred = ti_rtl(rtl, TypeEnv())
        int_tv = TypeVar("t", "", ints=True, simd=(1, 256))
        bool_tv = TypeVar("b", "", bools=True, simd=True)

        # Check that self.v5 gets the same integer type as
        # the rest of them
        # TODO: Add constraint nlanes(v3) == nlanes(v1) when we
        # add that type constraint to bint
        expected_vars = {
            self.v1: int_tv,
            self.v2: int_tv,
            self.v4: int_tv,
            self.v5: int_tv,
            self.v3: bool_tv,
            self.v0: int_tv,
        }
        check_typing(inferred, (expected_vars, []))
|
|
|
|
|
|
class TestXForm(TypeCheckingBaseTest):
    """
    Type inference tests on XForms: checks of the inferred typing (x.ti)
    and of the concrete typings enumerated from it.

    Checks made directly in the test methods use the unittest assertion
    methods rather than bare assert statements, so they are not stripped
    when running with python -O and report the offending values on failure.
    """
    def test_iadd_cout(self):
        # type: () -> None
        x = XForm(Rtl((self.v0, self.v1) << iadd_cout(self.v2, self.v3),),
                  Rtl(
                      self.v0 << iadd(self.v2, self.v3),
                      self.v1 << icmp(intcc.ult, self.v0, self.v2)
                  ))
        itype = TypeVar("t", "", ints=True, simd=(1, 1))

        check_typing(x.ti, ({
            self.v0: itype,
            self.v2: itype,
            self.v3: itype,
            self.v1: itype.as_bool(),
        }, []), x.symtab)

    def test_iadd_cin(self):
        # type: () -> None
        x = XForm(Rtl(self.v0 << iadd_cin(self.v1, self.v2, self.v3)),
                  Rtl(
                      self.v4 << iadd(self.v1, self.v2),
                      self.v5 << bint(self.v3),
                      self.v0 << iadd(self.v4, self.v5)
                  ))
        itype = TypeVar("t", "", ints=True, simd=(1, 1))

        check_typing(x.ti, ({
            self.v0: itype,
            self.v1: itype,
            self.v2: itype,
            self.v3: self.b1,
            self.v4: itype,
            self.v5: itype,
        }, []), x.symtab)

    def test_enumeration_with_constraints(self):
        # type: () -> None
        xform = XForm(
            Rtl(
                self.v0 << iconst(self.imm0),
                self.v1 << icmp(intcc.eq, self.v2, self.v0),
                self.v5 << vselect(self.v1, self.v3, self.v4)
            ),
            Rtl(
                self.v0 << iconst(self.imm0),
                self.v1 << icmp(intcc.eq, self.v2, self.v0),
                self.v5 << vselect(self.v1, self.v3, self.v4)
            ))

        # Check all var assigns are correct
        self.assertGreater(len(xform.ti.constraints), 0)
        concrete_var_assigns = list(xform.ti.concrete_typings())

        # The typings are keyed by the XForm's own Vars, so look those up
        # by name in the symtab
        v0 = xform.symtab[str(self.v0)]
        v1 = xform.symtab[str(self.v1)]
        v2 = xform.symtab[str(self.v2)]
        v3 = xform.symtab[str(self.v3)]
        v4 = xform.symtab[str(self.v4)]
        v5 = xform.symtab[str(self.v5)]

        for var_m in concrete_var_assigns:
            # One assertion per relation, so that a failure names the
            # relation that broke
            self.assertEqual(var_m[v0], var_m[v2])
            self.assertEqual(var_m[v3], var_m[v4])
            self.assertEqual(var_m[v5], var_m[v3])
            self.assertEqual(var_m[v1], var_m[v2].as_bool())
            self.assertEqual(var_m[v1].get_typeset(),
                             var_m[v3].as_bool().get_typeset())
            check_concrete_typing_xform(var_m, xform)

        # The number of possible typings here is:
        # 8 cases for v0 = i8xN times 2 options for v3 - i8, b8 = 16
        # 8 cases for v0 = i16xN times 2 options for v3 - i16, b16 = 16
        # 8 cases for v0 = i32xN times 3 options for v3 - i32, b32, f32 = 24
        # 8 cases for v0 = i64xN times 3 options for v3 - i64, b64, f64 = 24
        #
        # (Note we have 8 cases for lanes since vselect prevents scalars)
        # Total: 2*16 + 2*24 = 80
        self.assertEqual(len(concrete_var_assigns), 80)

    def test_base_legalizations_enumeration(self):
        # type: () -> None
        for xform in narrow.xforms + expand.xforms:
            # Any legalization patterns we defined should have at least 1
            # concrete typing
            concrete_typings_list = list(xform.ti.concrete_typings())
            self.assertGreater(len(concrete_typings_list), 0)

            # If there are no free_typevars, this is a non-polymorphic
            # pattern. There should be only one possible concrete typing.
            if len(xform.free_typevars) == 0:
                self.assertEqual(len(concrete_typings_list), 1)
                continue

            # For any patterns where the type env includes constraints, at
            # least one of the "theoretically possible" concrete typings must
            # be prevented by the constraints. (i.e. we are not emitting
            # unnecessary constraints).
            # We check that by asserting that the number of concrete typings
            # is less than the number of all possible free typevar
            # assignments
            if len(xform.ti.constraints) > 0:
                theoretical_num_typings =\
                    reduce(lambda x, y: x*y,
                           [tv.get_typeset().size()
                            for tv in xform.free_typevars], 1)
                self.assertLess(len(concrete_typings_list),
                                theoretical_num_typings)

            # Check the validity of each individual concrete typing against
            # the xform
            for concrete_typing in concrete_typings_list:
                check_concrete_typing_xform(concrete_typing, xform)
|