Add better type inference and encapsulate it in its own file (#110)

* Add more rigorous type inference and encapsulate the type inferece code in its own file (ti.py).

Add constraints accumulation during type inference, to represent constraints that cannot be expressed
using bijective derivation functions between typevars.

Add testing for new type inference code.

* Additional annotations to appease mypy
This commit is contained in:
d1m0
2017-07-05 09:16:44 -07:00
committed by Jakob Stoklund Olesen
parent f867ddbf0c
commit a5c96ef6bf
6 changed files with 1123 additions and 281 deletions

View File

@@ -100,8 +100,8 @@ class Var(Expr):
# TypeVar representing the type of this variable.
self.typevar = None # type: TypeVar
# The original 'typeof(x)' type variable that was created for this Var.
# This one doesn't change. `self.typevar` above may be joined with
# other typevars.
# This one doesn't change. `self.typevar` above may be changed to
# another typevar by type inference.
self.original_typevar = None # type: TypeVar
def __str__(self):
@@ -180,16 +180,9 @@ class Var(Expr):
self.typevar = tv
return self.typevar
def link_typevar(self, base, derived_func):
# type: (TypeVar, str) -> None
"""
Link the type variable on this Var to the type variable `base` using
`derived_func`.
"""
self.original_typevar = None
self.typevar.change_to_derived(base, derived_func)
# Possibly eliminate redundant SAMEAS links.
self.typevar = self.typevar.strip_sameas()
def set_typevar(self, tv):
# type: (TypeVar) -> None
self.typevar = tv
def has_free_typevar(self):
# type: () -> bool
@@ -213,136 +206,6 @@ class Var(Expr):
"""
return self.typevar.rust_expr()
def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var):
# type: (TypeVar, TypeVar, Var) -> None
"""
Constrain the set of allowed types for this variable.
Merge type variables for the involved variables to minimize the set for
free type variables.
Suppose we're looking at an instruction defined like this:
c = Operand('c', TxN.as_bool())
x = Operand('x', TxN)
y = Operand('y', TxN)
a = Operand('a', TxN)
vselect = Instruction('vselect', ins=(c, x, y), outs=a)
And suppose the instruction is used in a pattern like this:
v0 << vselect(v1, v2, v3)
We want to reconcile the types of the variables v0-v3 with the
constraints from the definition of vselect. This means that v0, v2, and
v3 must all have the same type, and v1 must have the type
`typeof(v2).as_bool()`.
The types are reconciled by calling this function once for each
input/output operand on the instruction in the pattern with these
arguments.
:param sym_typevar: Symbolic type variable constraining this variable
in the definition of the instruction.
:param sym_ctrl: Controlling type variable of `sym_typevar` in the
definition of the instruction.
:param ctrl_var: Variable determining the type of `sym_ctrl`.
When processing `v1` as used in the pattern above, we would get:
- self: v1
- sym_typevar: TxN.as_bool()
- sym_ctrl: TxN
- ctrl_var: v2
Here, 'v2' represents the controlling variable because of how the
`Ternary` instruction format is defined with `typevar_operand=1`.
"""
# First check if sym_typevar is tied to the controlling type variable
# in the instruction definition. We also allow free type variables on
# instruction inputs that can't be tied to anything else.
#
# This also covers non-polymorphic instructions and other cases where
# we don't have a Var representing the controlling type variable.
sym_free_var = sym_typevar.free_typevar()
if not sym_free_var or sym_free_var is not sym_ctrl or not ctrl_var:
# Just constrain our type to be compatible with the required
# typeset.
self.get_typevar().constrain_types(sym_typevar)
return
# Now sym_typevar is known to be tied to (or identical to) the
# controlling type variable.
if not self.typevar:
# If this variable is not yet constrained, just infer its type and
# link it to the controlling type variable.
if not sym_typevar.is_derived:
assert sym_typevar is sym_ctrl
# Identity mapping.
# Note that `self == ctrl_var` is both possible and common.
self.typevar = ctrl_var.get_typevar()
else:
assert self is not ctrl_var, (
'Impossible type constraints for {}: {}'
.format(self, sym_typevar))
# Create a derived type variable identical to sym_typevar, but
# with a different base.
self.typevar = TypeVar.derived(
ctrl_var.get_typevar(),
sym_typevar.derived_func)
# Match the type set constraints of the instruction.
self.typevar.constrain_types(sym_typevar)
return
# We already have a self.typevar describing our constraints. We need to
# reconcile with the additional constraints.
# It's likely that ctrl_var and self already share a type
# variable. (Often because `ctrl_var == self`).
if ctrl_var.typevar == self.typevar:
return
if not sym_typevar.is_derived:
assert sym_typevar is sym_ctrl
# sym_typevar is a direct use of sym_ctrl, so we need to reconcile
# self with ctrl_var.
assert not sym_typevar.is_derived
self.typevar.constrain_types(sym_typevar)
# It's possible that ctrl_var has not yet been assigned a type
# variable.
if not ctrl_var.typevar:
ctrl_var.typevar = self.typevar
return
# We can also bind variables with a free type variable to another
# variable. Prefer to do this to temps because they aren't allowed
# to be free,
if self.is_temp() and self.has_free_typevar():
self.link_typevar(ctrl_var.typevar, TypeVar.SAMEAS)
return
if ctrl_var.is_temp() and ctrl_var.has_free_typevar():
ctrl_var.link_typevar(self.typevar, TypeVar.SAMEAS)
return
if self.has_free_typevar():
self.link_typevar(ctrl_var.typevar, TypeVar.SAMEAS)
return
if ctrl_var.has_free_typevar():
ctrl_var.link_typevar(self.typevar, TypeVar.SAMEAS)
return
# TODO: Other cases are harder to handle.
#
# - If either variable is an independent free type variable, it
# should be changed to be linked to the other.
# - If both variable are free, we should pick one to link to the
# other. In particular, if one is a temp, it should be linked.
else:
# sym_typevar is derived from sym_ctrl.
# TODO: Other cases are harder to handle.
pass
class Apply(Expr):
"""

View File

@@ -0,0 +1,432 @@
from __future__ import absolute_import
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
b1, icmp, iadd_cout, iadd_cin
from base.legalize import narrow, expand
from base.immediates import intcc
from .typevar import TypeVar
from .ast import Var, Def
from .xform import Rtl, XForm
from .ti import ti_rtl, subst, TypeEnv, get_type_env
from unittest import TestCase
from functools import reduce
try:
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
from .ti import Constraint
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
except ImportError:
TYPE_CHECKING = False
def sort_constr(c):
# type: (Constraint) -> Constraint
"""
Sort the 2 typevars in a constraint by name for comparison
"""
r = tuple(sorted(c, key=lambda y: y.name))
if TYPE_CHECKING:
return cast(Constraint, r)
else:
return r
def agree(me, other):
# type: (TypeEnv, TypeEnv) -> bool
"""
Given TypeEnvs me and other, check if they agree. As part of that build
a map m from TVs in me to their corresponding TVs in other.
Specifically:
1. Check that all TVs that are keys in me.type_map are also defined
in other.type_map
2. For any tv in me.type_map check that:
me[tv].get_typeset() == other[tv].get_typeset()
3. Set m[me[tv]] = other[tv] in the substitution m
4. If we find another tv1 such that me[tv1] == me[tv], assert that
other[tv1] == m[me[tv1]] == m[me[tv]] = other[tv]
5. Check that me and other have the same constraints under the
substitution m
"""
m = {} # type: TypeMap
# Check that our type map and other's agree and built substitution m
for tv in me.type_map:
if (me[tv] not in m):
m[me[tv]] = other[tv]
if me[tv].get_typeset() != other[tv].get_typeset():
return False
else:
if m[me[tv]] != other[tv]:
return False
# Tranlsate our constraints using m, and sort
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
# Sort other's constraints
other_equiv_constr = sorted([sort_constr(x) for x in other.constraints],
key=lambda y: y[0].name)
return me_equiv_constr == other_equiv_constr
def check_typing(got_or_err, expected, symtab=None):
# type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa
"""
Check that a the typying we received (got_or_err) complies with the
expected typing (expected). If symtab is specified, substitute the Vars in
expected using symtab first (used when checking type inference on XForms)
"""
(m, c) = expected
got = get_type_env(got_or_err)
if (symtab is not None):
# For xforms we first need to re-write our TVs in terms of the tvs
# stored internally in the XForm. Use the symtab passed
subst_m = {k.get_typevar(): symtab[str(k)].get_typevar()
for k in m.keys()}
# Convert m from a Var->TypeVar map to TypeVar->TypeVar map where
# the key TypeVar is re-written to its XForm internal version
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
# Rewrite the TVs in the input constraints to their XForm internal
# versions
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c]
else:
# If no symtab, just convert m from Var->TypeVar map to a
# TypeVar->TypeVar map
tv_m = {k.get_typevar(): v for (k, v) in m.items()}
expected_typ = TypeEnv((tv_m, c))
assert agree(expected_typ, got), \
"typings disagree:\n {} \n {}".format(got.dot(),
expected_typ.dot())
def check_concrete_typing_rtl(var_types, rtl):
# type: (VarMap, Rtl) -> None
"""
Check that a concrete type assignment var_types (Dict[Var, TypeVar]) is
valid for an Rtl rtl. Specifically check that:
1) For each Var v \in rtl, v is defined in var_types
2) For all v, var_types[v] is a singleton type
3) For each v, and each location u, where v is used with expected type
tv_u, var_types[v].get_typeset() is a subset of
subst(tv_u, m).get_typeset() where m is the substitution of
formals->actuals we are building so far.
4) If tv_u is non-derived and not in m, set m[tv_u]= var_types[v]
"""
for d in rtl.rtl:
assert isinstance(d, Def)
inst = d.expr.inst
# Accumulate all actual TVs for value defs/opnums in actual_tvs
actual_tvs = [var_types[d.defs[i]] for i in inst.value_results]
for v in [d.expr.args[i] for i in inst.value_opnums]:
assert isinstance(v, Var)
actual_tvs.append(var_types[v])
# Accumulate all formal TVs for value defs/opnums in actual_tvs
formal_tvs = [inst.outs[i].typevar for i in inst.value_results] +\
[inst.ins[i].typevar for i in inst.value_opnums]
m = {} # type: TypeMap
# For each actual/formal pair check that they agree
for (actual_tv, formal_tv) in zip(actual_tvs, formal_tvs):
# actual should be a singleton
assert actual_tv.singleton_type() is not None
formal_tv = subst(formal_tv, m)
# actual should agree with the concretized formal
assert actual_tv.get_typeset().issubset(formal_tv.get_typeset())
if formal_tv not in m and not formal_tv.is_derived:
m[formal_tv] = actual_tv
def check_concrete_typing_xform(var_types, xform):
# type: (VarMap, XForm) -> None
"""
Check a concrete type assignment var_types for an XForm xform
"""
check_concrete_typing_rtl(var_types, xform.src)
check_concrete_typing_rtl(var_types, xform.dst)
class TypeCheckingBaseTest(TestCase):
def setUp(self):
# type: () -> None
self.v0 = Var("v0")
self.v1 = Var("v1")
self.v2 = Var("v2")
self.v3 = Var("v3")
self.v4 = Var("v4")
self.v5 = Var("v5")
self.v6 = Var("v6")
self.v7 = Var("v7")
self.v8 = Var("v8")
self.v9 = Var("v9")
self.imm0 = Var("imm0")
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True,
scalars=False, simd=True)
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
scalars=False, simd=True)
self.b1 = TypeVar.singleton(b1)
class TestRTL(TypeCheckingBaseTest):
def test_bad_rtl1(self):
# type: () -> None
r = Rtl(
(self.v0, self.v1) << vsplit(self.v2),
self.v3 << vconcat(self.v0, self.v2),
)
ti = TypeEnv()
self.assertEqual(ti_rtl(r, ti),
"On line 1: fail ti on `typeof_v2` <: `2`: " +
"Error: empty type created when unifying " +
"`typeof_v2` and `half_vector(typeof_v2)`")
def test_vselect(self):
# type: () -> None
r = Rtl(
self.v0 << vselect(self.v1, self.v2, self.v3),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
txn = self.TxN.get_fresh_copy("TxN1")
check_typing(typing, ({
self.v0: txn,
self.v1: txn.as_bool(),
self.v2: txn,
self.v3: txn
}, []))
def test_vselect_icmpimm(self):
# type: () -> None
r = Rtl(
self.v0 << iconst(self.imm0),
self.v1 << icmp(intcc.eq, self.v2, self.v0),
self.v5 << vselect(self.v1, self.v3, self.v4),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
ixn = self.IxN_nonscalar.get_fresh_copy("IxN1")
txn = self.TxN.get_fresh_copy("TxN1")
check_typing(typing, ({
self.v0: ixn,
self.v1: ixn.as_bool(),
self.v2: ixn,
self.v3: txn,
self.v4: txn,
self.v5: txn,
}, [(ixn.as_bool(), txn.as_bool())]))
def test_vselect_vsplits(self):
# type: () -> None
r = Rtl(
self.v3 << vselect(self.v0, self.v1, self.v2),
(self.v4, self.v5) << vsplit(self.v3),
(self.v6, self.v7) << vsplit(self.v4),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
t = TypeVar("t", "", ints=True, bools=True, floats=True,
simd=(4, 256))
check_typing(typing, ({
self.v0: t.as_bool(),
self.v1: t,
self.v2: t,
self.v3: t,
self.v4: t.half_vector(),
self.v5: t.half_vector(),
self.v6: t.half_vector().half_vector(),
self.v7: t.half_vector().half_vector(),
}, []))
def test_vselect_vconcats(self):
# type: () -> None
r = Rtl(
self.v3 << vselect(self.v0, self.v1, self.v2),
self.v8 << vconcat(self.v3, self.v3),
self.v9 << vconcat(self.v8, self.v8),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
t = TypeVar("t", "", ints=True, bools=True, floats=True,
simd=(2, 64))
check_typing(typing, ({
self.v0: t.as_bool(),
self.v1: t,
self.v2: t,
self.v3: t,
self.v8: t.double_vector(),
self.v9: t.double_vector().double_vector(),
}, []))
def test_vselect_vsplits_vconcats(self):
# type: () -> None
r = Rtl(
self.v3 << vselect(self.v0, self.v1, self.v2),
(self.v4, self.v5) << vsplit(self.v3),
(self.v6, self.v7) << vsplit(self.v4),
self.v8 << vconcat(self.v3, self.v3),
self.v9 << vconcat(self.v8, self.v8),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
t = TypeVar("t", "", ints=True, bools=True, floats=True,
simd=(4, 64))
check_typing(typing, ({
self.v0: t.as_bool(),
self.v1: t,
self.v2: t,
self.v3: t,
self.v4: t.half_vector(),
self.v5: t.half_vector(),
self.v6: t.half_vector().half_vector(),
self.v7: t.half_vector().half_vector(),
self.v8: t.double_vector(),
self.v9: t.double_vector().double_vector(),
}, []))
def test_bint(self):
# type: () -> None
r = Rtl(
self.v4 << iadd(self.v1, self.v2),
self.v5 << bint(self.v3),
self.v0 << iadd(self.v4, self.v5)
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
itype = TypeVar("t", "", ints=True, simd=(1, 256))
btype = TypeVar("b", "", bools=True, simd=True)
# Check that self.v5 gets the same integer type as
# the rest of them
# TODO: Add constraint nlanes(v3) == nlanes(v1) when we
# add that type constraint to bint
check_typing(typing, ({
self.v1: itype,
self.v2: itype,
self.v4: itype,
self.v5: itype,
self.v3: btype,
self.v0: itype,
}, []))
class TestXForm(TypeCheckingBaseTest):
def test_iadd_cout(self):
# type: () -> None
x = XForm(Rtl((self.v0, self.v1) << iadd_cout(self.v2, self.v3),),
Rtl(
self.v0 << iadd(self.v2, self.v3),
self.v1 << icmp(intcc.ult, self.v0, self.v2)
))
itype = TypeVar("t", "", ints=True, simd=(1, 1))
check_typing(x.ti, ({
self.v0: itype,
self.v2: itype,
self.v3: itype,
self.v1: itype.as_bool(),
}, []), x.symtab)
def test_iadd_cin(self):
# type: () -> None
x = XForm(Rtl(self.v0 << iadd_cin(self.v1, self.v2, self.v3)),
Rtl(
self.v4 << iadd(self.v1, self.v2),
self.v5 << bint(self.v3),
self.v0 << iadd(self.v4, self.v5)
))
itype = TypeVar("t", "", ints=True, simd=(1, 1))
check_typing(x.ti, ({
self.v0: itype,
self.v1: itype,
self.v2: itype,
self.v3: self.b1,
self.v4: itype,
self.v5: itype,
}, []), x.symtab)
def test_enumeration_with_constraints(self):
# type: () -> None
xform = XForm(
Rtl(
self.v0 << iconst(self.imm0),
self.v1 << icmp(intcc.eq, self.v2, self.v0),
self.v5 << vselect(self.v1, self.v3, self.v4)
),
Rtl(
self.v0 << iconst(self.imm0),
self.v1 << icmp(intcc.eq, self.v2, self.v0),
self.v5 << vselect(self.v1, self.v3, self.v4)
))
# Check all var assigns are correct
assert len(xform.ti.constraints) > 0
concrete_var_assigns = list(xform.ti.concrete_typings())
v0 = xform.symtab[str(self.v0)]
v1 = xform.symtab[str(self.v1)]
v2 = xform.symtab[str(self.v2)]
v3 = xform.symtab[str(self.v3)]
v4 = xform.symtab[str(self.v4)]
v5 = xform.symtab[str(self.v5)]
for var_m in concrete_var_assigns:
assert var_m[v0] == var_m[v2] and \
var_m[v3] == var_m[v4] and\
var_m[v5] == var_m[v3] and\
var_m[v1] == var_m[v2].as_bool() and\
var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset()
check_concrete_typing_xform(var_m, xform)
# The number of possible typings here is:
# 8 cases for v0 = i8xN times 2 options for v3 - i8, b8 = 16
# 8 cases for v0 = i16xN times 2 options for v3 - i16, b16 = 16
# 8 cases for v0 = i32xN times 3 options for v3 - i32, b32, f32 = 24
# 8 cases for v0 = i64xN times 3 options for v3 - i64, b64, f64 = 24
#
# (Note we have 8 cases for lanes since vselect prevents scalars)
# Total: 2*16 + 2*24 = 80
assert len(concrete_var_assigns) == 80
def test_base_legalizations_enumeration(self):
# type: () -> None
for xform in narrow.xforms + expand.xforms:
# Any legalization patterns we defined should have at least 1
# concrete typing
concrete_typings_list = list(xform.ti.concrete_typings())
assert len(concrete_typings_list) > 0
# If there are no free_typevars, this is a non-polymorphic pattern.
# There should be only one possible concrete typing.
if (len(xform.free_typevars) == 0):
assert len(concrete_typings_list) == 1
continue
# For any patterns where the type env includes constraints, at
# least one of the "theoretically possible" concrete typings must
# be prevented by the constraints. (i.e. we are not emitting
# unneccessary constraints).
# We check that by asserting that the number of concrete typings is
# less than the number of all possible free typevar assignments
if (len(xform.ti.constraints) > 0):
theoretical_num_typings =\
reduce(lambda x, y: x*y,
[tv.get_typeset().size()
for tv in xform.free_typevars], 1)
assert len(concrete_typings_list) < theoretical_num_typings
# Check the validity of each individual concrete typing against the
# xform
for concrete_typing in concrete_typings_list:
check_concrete_typing_xform(concrete_typing, xform)

View File

@@ -125,59 +125,56 @@ class TestTypeSet(TestCase):
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
i32.by(4))
def test_map_inverse(self):
def test_preimage(self):
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
self.assertEqual(t, t.map_inverse(TypeVar.SAMEAS))
self.assertEqual(t, t.preimage(TypeVar.SAMEAS))
# LANEOF
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
t.map_inverse(TypeVar.LANEOF))
t.preimage(TypeVar.LANEOF))
# Inverse of empty set is still empty across LANEOF
self.assertEqual(TypeSet(),
TypeSet().map_inverse(TypeVar.LANEOF))
TypeSet().preimage(TypeVar.LANEOF))
# ASBOOL
t = TypeSet(lanes=(1, 4), bools=(1, 64))
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
self.assertEqual(t.preimage(TypeVar.ASBOOL),
TypeSet(lanes=(1, 4), ints=True, bools=True,
floats=True))
# Inverse image across ASBOOL of TS not involving b1 cannot have
# lanes=1
t = TypeSet(lanes=(1, 4), bools=(16, 32))
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
TypeSet(lanes=(2, 4), ints=(16, 32), bools=(16, 32),
floats=(32, 32)))
# Half/Double Vector
t = TypeSet(lanes=(1, 1), ints=(8, 8))
t1 = TypeSet(lanes=(256, 256), ints=(8, 8))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR).size(), 0)
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR).size(), 0)
self.assertEqual(t.preimage(TypeVar.DOUBLEVECTOR).size(), 0)
self.assertEqual(t1.preimage(TypeVar.HALFVECTOR).size(), 0)
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32))
t1 = TypeSet(lanes=(64, 256), bools=(1, 32))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR),
self.assertEqual(t.preimage(TypeVar.DOUBLEVECTOR),
TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32)))
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR),
self.assertEqual(t1.preimage(TypeVar.HALFVECTOR),
TypeSet(lanes=(128, 256), bools=(1, 32)))
# Half/Double Width
t = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8))
t1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH).size(), 0)
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH).size(), 0)
self.assertEqual(t.preimage(TypeVar.DOUBLEWIDTH).size(), 0)
self.assertEqual(t1.preimage(TypeVar.HALFWIDTH).size(), 0)
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64))
t1 = TypeSet(lanes=(64, 256), bools=(1, 64))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH),
self.assertEqual(t.preimage(TypeVar.DOUBLEWIDTH),
TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32)))
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH),
self.assertEqual(t1.preimage(TypeVar.HALFWIDTH),
TypeSet(lanes=(64, 256), bools=(16, 64)))
def has_non_bijective_derived_f(iterable):
return any(not TypeVar.is_bijection(x) for x in iterable)
class TestTypeVar(TestCase):
def test_functions(self):
x = TypeVar('x', 'all ints', ints=True)
@@ -220,7 +217,7 @@ class TestTypeVar(TestCase):
self.assertEqual(len(x.type_set.bools), 0)
def test_stress_constrain_types(self):
# Get all 49 possible derived vars of lentgh 2. Since we have SAMEAS
# Get all 49 possible derived vars of length 2. Since we have SAMEAS
# this includes singly derived and non-derived vars
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
@@ -231,18 +228,18 @@ class TestTypeVar(TestCase):
for (i1, i2) in product(v, v):
# Compute the derived sets for each starting with a full typeset
full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True)
ts1 = reduce(lambda ts, func: ts.map(func), i1, full_ts)
ts2 = reduce(lambda ts, func: ts.map(func), i2, full_ts)
ts1 = reduce(lambda ts, func: ts.image(func), i1, full_ts)
ts2 = reduce(lambda ts, func: ts.image(func), i2, full_ts)
# Compute intersection
intersect = ts1.copy()
intersect &= ts2
# Propagate instersections backward
ts1_src = reduce(lambda ts, func: ts.map_inverse(func),
ts1_src = reduce(lambda ts, func: ts.preimage(func),
reversed(i1),
intersect)
ts2_src = reduce(lambda ts, func: ts.map_inverse(func),
ts2_src = reduce(lambda ts, func: ts.preimage(func),
reversed(i2),
intersect)
@@ -262,13 +259,10 @@ class TestTypeVar(TestCase):
i2,
TypeVar.from_typeset(ts2_src))
# The typesets of the two derived variables should be subsets of
# the intersection we computed originally
assert tv1.get_typeset().issubset(intersect)
assert tv2.get_typeset().issubset(intersect)
# In the absence of AS_BOOL map(map_inverse(f)) == f so the
# In the absence of AS_BOOL image(preimage(f)) == f so the
# typesets of tv1 and tv2 should be exactly intersection
assert (tv1.get_typeset() == tv2.get_typeset() and
tv1.get_typeset() == intersect) or\
TypeVar.ASBOOL in set(i1 + i2)
assert tv1.get_typeset() == intersect or\
has_non_bijective_derived_f(i1)
assert tv2.get_typeset() == intersect or\
has_non_bijective_derived_f(i2)

View File

@@ -0,0 +1,556 @@
"""
Type Inference
"""
from .typevar import TypeVar
from .ast import Def, Var
from copy import copy
from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable # noqa
from typing import cast, List
from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa
if TYPE_CHECKING:
Constraint = Tuple[TypeVar, TypeVar]
ConstraintList = List[Constraint]
TypeMap = Dict[TypeVar, TypeVar]
VarMap = Dict[Var, TypeVar]
except ImportError:
TYPE_CHECKING = False
pass
class TypeEnv(object):
"""
Class encapsulating the neccessary book keeping for type inference.
:attribute type_map: dict holding the equivalence relations between tvs
:attribute constraints: a list of accumulated constraints - tuples
(tv1, tv2)) where tv1 and tv2 are equal
:attribute ranks: dictionary recording the (optional) ranks for tvs.
tvs corresponding to real variables have explicitly
specified ranks.
:attribute vars: a set containing all known Vars
:attribute idx: counter used to get fresh ids
"""
def __init__(self, arg=None):
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
self.ranks = {} # type: Dict[TypeVar, int]
self.vars = set() # type: Set[Var]
if arg is None:
self.type_map = {} # type: TypeMap
self.constraints = [] # type: ConstraintList
else:
self.type_map, self.constraints = arg
self.idx = 0
def __getitem__(self, arg):
# type: (Union[TypeVar, Var]) -> TypeVar
"""
Lookup the canonical representative for a Var/TypeVar.
"""
if (isinstance(arg, Var)):
tv = arg.get_typevar()
else:
assert (isinstance(arg, TypeVar))
tv = arg
while tv in self.type_map:
tv = self.type_map[tv]
if tv.is_derived:
tv = TypeVar.derived(self[tv.base], tv.derived_func)
return tv
def equivalent(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
"""
Record a that the free tv1 is part of the same equivalence class as
tv2. The canonical representative of the merged class is tv2's
cannonical representative.
"""
assert not tv1.is_derived
assert self[tv1] == tv1
# Make sure we don't create cycles
if tv2.is_derived:
assert self[tv2.base] != tv1
self.type_map[tv1] = tv2
def add_constraint(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
"""
Add a new equivalence constraint between tv1 and tv2
"""
self.constraints.append((tv1, tv2))
def get_uid(self):
# type: () -> str
r = str(self.idx)
self.idx += 1
return r
def __repr__(self):
# type: () -> str
return self.dot()
def rank(self, tv):
# type: (TypeVar) -> int
"""
Get the rank of tv in the partial order. TVs directly associated with a
Var get their rank from the Var (see register()).
Internally generated non-derived TVs implicitly get the lowest rank (0)
Internal derived variables get the highest rank.
"""
default_rank = 5 if tv.is_derived else 0
return self.ranks.get(tv, default_rank)
def register(self, v):
# type: (Var) -> None
"""
Register a new Var v. This computes a rank for the associated TypeVar
for v, which is used to impose a partial order on type variables.
"""
self.vars.add(v)
if v.is_input():
r = 4
elif v.is_intermediate():
r = 3
elif v.is_output():
r = 2
else:
assert(v.is_temp())
r = 1
self.ranks[v.get_typevar()] = r
def free_typevars(self):
# type: () -> Set[TypeVar]
"""
Get the free typevars in the current type env.
"""
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
# Filter out None here due to singleton type vars
return set(filter(lambda x: x is not None, tvs))
def normalize(self):
# type: () -> None
"""
Normalize by:
- collapsing any roots that don't correspond to a concrete TV AND
have a single TV derived from them or equivalent to them
E.g. if we have a root of the tree that looks like:
typeof_a typeof_b
\ /
typeof_x
|
half_width(1)
|
1
we want to collapse the linear path between 1 and typeof_x. The
resulting graph is:
typeof_a typeof_b
\ /
typeof_x
"""
source_tvs = set([v.get_typevar() for v in self.vars])
children = {} # type: Dict[TypeVar, Set[TypeVar]]
for v in self.type_map.values():
if not v.is_derived:
continue
t = v.free_typevar()
s = children.get(t, set())
s.add(v)
children[t] = s
for (a, b) in self.type_map.items():
s = children.get(b, set())
s.add(a)
children[b] = s
for r in list(self.free_typevars()):
while (r not in source_tvs and r in children and
len(children[r]) == 1):
child = list(children[r])[0]
if child in self.type_map:
assert self.type_map[child] == r
del self.type_map[child]
r = child
def extract(self):
# type: () -> TypeEnv
"""
Extract a clean type environment from self, that only mentions
TVs associated with real variables
"""
vars_tvs = set([v.get_typevar() for v in self.vars])
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
new_constraints = [(self[tv1], self[tv2])
for (tv1, tv2) in self.constraints]
# Sanity: new constraints and the new type_map should only contain
# tvs associated with real vars
for (a, b) in new_constraints:
assert a.free_typevar() in vars_tvs and\
b.free_typevar() in vars_tvs
for (k, v) in new_type_map.items():
assert k in vars_tvs
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
t = TypeEnv()
t.type_map = new_type_map
t.constraints = new_constraints
# ranks and vars contain only TVs associated with real vars
t.ranks = copy(self.ranks)
t.vars = copy(self.vars)
return t
def concrete_typings(self):
# type: () -> Iterable[VarMap]
"""
Return an iterable over all possible concrete typings permitted by this
TypeEnv.
"""
free_tvs = self.free_typevars()
free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs]
for concrete_types in product(*free_tv_iters):
# Build type substitutions for all free vars
m = {tv: TypeVar.singleton(typ)
for (tv, typ) in zip(free_tvs, concrete_types)}
concrete_var_map = {v: subst(self[v.get_typevar()], m)
for v in self.vars}
# Check if constraints are satisfied for this typing
failed = None
for (tv1, tv2) in self.constraints:
tv1 = subst(tv1, m)
tv2 = subst(tv2, m)
assert tv1.get_typeset().size() == 1 and\
tv2.get_typeset().size() == 1
if (tv1.get_typeset() != tv2.get_typeset()):
failed = (tv1, tv2)
break
if (failed is not None):
continue
yield concrete_var_map
def dot(self):
# type: () -> str
"""
Return a representation of self as a graph in dot format.
Nodes correspond to TypeVariables.
Dotted edges correspond to equivalences between TVS
Solid edges correspond to derivation relations between TVs.
Dashed edges correspond to equivalence constraints.
"""
def label(s):
# type: (TypeVar) -> str
return "\"" + str(s) + "\""
# Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
for (k, v) in self.type_map.items():
# Add all intermediate TVs appearing in edges
nodes.add(k)
nodes.add(v)
edges.add((k, v, "dotted", None))
while (v.is_derived):
nodes.add(v.base)
edges.add((v, v.base, "solid", v.derived_func))
v = v.base
for (a, b) in self.constraints:
assert a in nodes and b in nodes
edges.add((a, b, "dashed", None))
root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived])
r = "digraph {\n"
for n in nodes:
r += label(n)
if n in root_nodes:
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
r += ";\n"
for (n1, n2, style, elabel) in edges:
e = label(n1)
if style == "dashed":
e += '--'
else:
e += '->'
e += label(n2)
e += "[style={}".format(style)
if elabel is not None:
e += ",label={}".format(elabel)
e += "];\n"
r += e
r += "}"
return r
if TYPE_CHECKING:
TypingError = str
TypingOrError = Union[TypeEnv, TypingError]
def get_error(typing_or_err):
# type: (TypingOrError) -> Optional[TypingError]
"""
Helper function to appease mypy when checking the result of typing.
"""
if isinstance(typing_or_err, str):
if (TYPE_CHECKING):
return cast(TypingError, typing_or_err)
else:
return typing_or_err
else:
return None
def get_type_env(typing_or_err):
# type: (TypingOrError) -> TypeEnv
"""
Helper function to appease mypy when checking the result of typing.
"""
assert isinstance(typing_or_err, TypeEnv)
if (TYPE_CHECKING):
return cast(TypeEnv, typing_or_err)
else:
return typing_or_err
def subst(tv, tv_map):
# type: (TypeVar, TypeMap) -> TypeVar
"""
Perform substition on the input tv using the TypeMap tv_map.
"""
if tv in tv_map:
return tv_map[tv]
if tv.is_derived:
return TypeVar.derived(subst(tv.base, tv_map), tv.derived_func)
return tv
def normalize_tv(tv):
# type: (TypeVar) -> TypeVar
"""
Normalize a (potentially derived) TV using the following rules:
- collapse SAMEAS
SAMEAS(base) -> base
- vector and width derived functions commute
{HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
{HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
- half/double pairs collapse
{HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
{HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base
"""
vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR]
width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
if not tv.is_derived:
return tv
df = tv.derived_func
# Collapse SAMEAS edges
if (df == TypeVar.SAMEAS):
return normalize_tv(tv.base)
if (tv.base.is_derived):
base_df = tv.base.derived_func
# Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR,
# DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR
if df in vector_derives and base_df in width_derives:
return normalize_tv(
TypeVar.derived(
TypeVar.derived(tv.base.base, df), base_df))
# Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
# cancel each other. TODO: Does this cancellation hide type
# overflow/underflow?
if (df, base_df) in \
[(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
(TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
(TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
(TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH)]:
return normalize_tv(tv.base.base)
return TypeVar.derived(normalize_tv(tv.base), df)
def constrain_fixpoint(tv1, tv2):
# type: (TypeVar, TypeVar) -> None
"""
Given typevars tv1 and tv2 (which could be derived from one another)
constrain their typesets to be the same. When one is derived from the
other, repeat the constrain process until fixpoint.
"""
# Constrain tv2's typeset as long as tv1's typeset is changing.
while True:
old_tv1_ts = tv1.get_typeset().copy()
tv2.constrain_types(tv1)
if tv1.get_typeset() == old_tv1_ts:
break
old_tv2_ts = tv2.get_typeset().copy()
tv1.constrain_types(tv2)
assert old_tv2_ts == tv2.get_typeset()
def unify(tv1, tv2, typ):
# type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError
"""
Unify tv1 and tv2 in the current type environment typ, and return an
updated type environment or error.
"""
tv1 = normalize_tv(typ[tv1])
tv2 = normalize_tv(typ[tv2])
# Already unified
if tv1 == tv2:
return typ
if typ.rank(tv2) < typ.rank(tv1):
return unify(tv2, tv1, typ)
constrain_fixpoint(tv1, tv2)
if (tv1.get_typeset().size() == 0 or tv2.get_typeset().size() == 0):
return "Error: empty type created when unifying {} and {}"\
.format(tv1, tv2)
# Free -> Derived(Free)
if not tv1.is_derived:
typ.equivalent(tv1, tv2)
return typ
assert tv2.is_derived, "Ordering gives us !tv1.is_derived==>tv2.is_derived"
if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
typ.add_constraint(tv1, tv2)
return typ
def ti_def(definition, typ):
# type: (Def, TypeEnv) -> TypingOrError
"""
Perform type inference on one Def in the current type environment typ and
return an updated type environment or error.
At a high level this works by creating fresh copies of each formal type var
in the Def's instruction's signature, and unifying the formal tv with the
corresponding actual tv.
"""
expr = definition.expr
inst = expr.inst
# Create a map m mapping each free typevar in the signature of definition
# to a fresh copy of itself
all_formal_tvs = \
[inst.outs[i].typevar for i in inst.value_results] +\
[inst.ins[i].typevar for i in inst.value_opnums]
free_formal_tvs = [tv for tv in all_formal_tvs if not tv.is_derived]
m = {tv: tv.get_fresh_copy(str(typ.get_uid())) for tv in free_formal_tvs}
# Get fresh copies for each typevar in the signature (both free and
# derived)
fresh_formal_tvs = \
[subst(inst.outs[i].typevar, m) for i in inst.value_results] +\
[subst(inst.ins[i].typevar, m) for i in inst.value_opnums]
# Get the list of actual Vars
actual_vars = [] # type: List[Expr]
actual_vars += [definition.defs[i] for i in inst.value_results]
actual_vars += [expr.args[i] for i in inst.value_opnums]
# Get the list of the actual TypeVars
actual_tvs = []
for v in actual_vars:
assert(isinstance(v, Var))
# Register with TypeEnv that this typevar corresponds ot variable v,
# and thus has a given rank
typ.register(v)
actual_tvs.append(v.get_typevar())
# Unify each actual typevar with the correpsonding fresh formal tv
for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
typ_or_err = unify(actual_tv, formal_tv, typ)
err = get_error(typ_or_err)
if (err):
return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err
typ = get_type_env(typ_or_err)
return typ
def ti_rtl(rtl, typ):
# type: (Rtl, TypeEnv) -> TypingOrError
"""
Perform type inference on an Rtl in a starting type env typ. Return an
updated type environment or error.
"""
for (i, d) in enumerate(rtl.rtl):
assert (isinstance(d, Def))
typ_or_err = ti_def(d, typ)
err = get_error(typ_or_err) # type: Optional[TypingError]
if (err):
return "On line {}: ".format(i) + err
typ = get_type_env(typ_or_err)
return typ
def ti_xform(xform, typ):
# type: (XForm, TypeEnv) -> TypingOrError
"""
Perform type inference on an Rtl in a starting type env typ. Return an
updated type environment or error.
"""
typ_or_err = ti_rtl(xform.src, typ)
err = get_error(typ_or_err) # type: Optional[TypingError]
if (err):
return "In src pattern: " + err
typ = get_type_env(typ_or_err)
typ_or_err = ti_rtl(xform.dst, typ)
err = get_error(typ_or_err)
if (err):
return "In dst pattern: " + err
typ = get_type_env(typ_or_err)
return get_type_env(typ_or_err)

View File

@@ -8,7 +8,6 @@ from __future__ import absolute_import
import math
from . import types, is_power_of_two
from copy import deepcopy
from .types import IntType, FloatType, BoolType
try:
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
@@ -210,6 +209,10 @@ class TypeSet(object):
else:
return False
def __ne__(self, other):
# type: (object) -> bool
return not self.__eq__(other)
def __repr__(self):
# type: () -> str
s = 'TypeSet(lanes={}'.format(pp_set(self.lanes))
@@ -289,7 +292,9 @@ class TypeSet(object):
new = self.copy()
new.ints = set()
new.floats = set()
new.bools = self.ints.union(self.floats).union(self.bools)
if len(self.lanes.difference(set([1]))) > 0:
new.bools = self.ints.union(self.floats).union(self.bools)
if 1 in self.lanes:
new.bools.add(1)
@@ -340,7 +345,7 @@ class TypeSet(object):
return new
def map(self, func):
def image(self, func):
# type: (str) -> TypeSet
"""
Return the image of self across the derived function func
@@ -362,7 +367,7 @@ class TypeSet(object):
else:
assert False, "Unknown derived function: " + func
def map_inverse(self, func):
def preimage(self, func):
# type: (str) -> TypeSet
"""
Return the inverse image of self across the derived function func
@@ -379,16 +384,14 @@ class TypeSet(object):
return new
elif (func == TypeVar.ASBOOL):
new = self.copy()
new.ints = self.bools.difference(set([1]))
new.floats = self.bools.intersection(set([32, 64]))
if 1 not in self.bools:
try:
# If the range doesn't have b1, then the domain can't
# include scalars, as as_bool(scalar)=b1
new.lanes.remove(1)
except KeyError:
pass
new.ints = self.bools.difference(set([1]))
new.floats = self.bools.intersection(set([32, 64]))
else:
new.ints = set([2**x for x in range(3, 7)])
new.floats = set([32, 64])
return new
elif (func == TypeVar.HALFWIDTH):
return self.double_width()
@@ -409,27 +412,32 @@ class TypeSet(object):
return len(self.lanes) * (len(self.ints) + len(self.floats) +
len(self.bools))
def concrete_types(self):
# type: () -> Iterable[types.ValueType]
def by(scalar, lanes):
# type: (types.ScalarType, int) -> types.ValueType
if (lanes == 1):
return scalar
else:
return scalar.by(lanes)
for nlanes in self.lanes:
for bits in self.ints:
yield by(types.IntType.with_bits(bits), nlanes)
for bits in self.floats:
yield by(types.FloatType.with_bits(bits), nlanes)
for bits in self.bools:
yield by(types.BoolType.with_bits(bits), nlanes)
def get_singleton(self):
# type: () -> types.ValueType
"""
Return the singleton type represented by self. Can only call on
typesets containing 1 type.
"""
assert self.size() == 1
scalar_type = None # type: types.ScalarType
if len(self.ints) > 0:
scalar_type = IntType.with_bits(tuple(self.ints)[0])
elif len(self.floats) > 0:
scalar_type = FloatType.with_bits(tuple(self.floats)[0])
else:
scalar_type = BoolType.with_bits(tuple(self.bools)[0])
nlanes = tuple(self.lanes)[0]
if nlanes == 1:
return scalar_type
else:
return scalar_type.by(nlanes)
types = list(self.concrete_types())
assert len(types) == 1
return types[0]
class TypeVar(object):
@@ -519,6 +527,13 @@ class TypeVar(object):
'TypeVar({}, {})'
.format(self.name, self.type_set))
def __hash__(self):
# type: () -> int
if (not self.is_derived):
return object.__hash__(self)
return hash((self.derived_func, self.base))
def __eq__(self, other):
# type: (object) -> bool
if not isinstance(other, TypeVar):
@@ -530,6 +545,10 @@ class TypeVar(object):
else:
return self is other
def __ne__(self, other):
# type: (object) -> bool
return not self.__eq__(other)
# Supported functions for derived type variables.
# The names here must match the method names on `ir::types::Type`.
# The camel_case of the names must match `enum OperandConstraint` in
@@ -542,6 +561,27 @@ class TypeVar(object):
HALFVECTOR = 'half_vector'
DOUBLEVECTOR = 'double_vector'
@staticmethod
def is_bijection(func):
# type: (str) -> bool
return func in [
TypeVar.SAMEAS,
TypeVar.HALFWIDTH,
TypeVar.DOUBLEWIDTH,
TypeVar.HALFVECTOR,
TypeVar.DOUBLEVECTOR]
@staticmethod
def inverse_func(func):
# type: (str) -> str
return {
TypeVar.SAMEAS: TypeVar.SAMEAS,
TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
TypeVar.DOUBLEVECTOR: TypeVar.HALFVECTOR
}[func]
@staticmethod
def derived(base, derived_func):
# type: (TypeVar, str) -> TypeVar
@@ -668,7 +708,7 @@ class TypeVar(object):
Get the free type variable controlling this one.
"""
if self.is_derived:
return self.base
return self.base.free_typevar()
elif self.singleton_type() is not None:
# A singleton type variable is not a proper free variable.
return None
@@ -697,7 +737,7 @@ class TypeVar(object):
if not self.is_derived:
self.type_set &= ts
else:
self.base.constrain_types_by_ts(ts.map_inverse(self.derived_func))
self.base.constrain_types_by_ts(ts.preimage(self.derived_func))
def constrain_types(self, other):
# type: (TypeVar) -> None
@@ -723,19 +763,14 @@ class TypeVar(object):
if not self.is_derived:
return self.type_set
else:
if (self.derived_func == TypeVar.SAMEAS):
return self.base.get_typeset()
elif (self.derived_func == TypeVar.LANEOF):
return self.base.get_typeset().lane_of()
elif (self.derived_func == TypeVar.ASBOOL):
return self.base.get_typeset().as_bool()
elif (self.derived_func == TypeVar.HALFWIDTH):
return self.base.get_typeset().half_width()
elif (self.derived_func == TypeVar.DOUBLEWIDTH):
return self.base.get_typeset().double_width()
elif (self.derived_func == TypeVar.HALFVECTOR):
return self.base.get_typeset().half_vector()
elif (self.derived_func == TypeVar.DOUBLEVECTOR):
return self.base.get_typeset().double_vector()
else:
assert False, "Unknown derived function: " + self.derived_func
return self.base.get_typeset().image(self.derived_func)
def get_fresh_copy(self, name):
# type: (str) -> TypeVar
"""
Get a fresh copy of self. Can only be called on free typevars.
"""
assert not self.is_derived
tv = TypeVar.from_typeset(self.type_set.copy())
tv.name = name
return tv

View File

@@ -3,6 +3,7 @@ Instruction transformations.
"""
from __future__ import absolute_import
from .ast import Def, Var, Apply
from .ti import ti_xform, TypeEnv, get_type_env
try:
from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa
@@ -83,6 +84,8 @@ class XForm(object):
self._rewrite_rtl(src, symtab, Var.SRCCTX)
num_src_inputs = len(self.inputs)
self._rewrite_rtl(dst, symtab, Var.DSTCTX)
# Needed for testing type inference on XForms
self.symtab = symtab
# Check for inconsistently used inputs.
for i in self.inputs:
@@ -96,9 +99,25 @@ class XForm(object):
"extra inputs in dst RTL: {}".format(
self.inputs[num_src_inputs:]))
self._infer_types(self.src)
self._infer_types(self.dst)
self._collect_typevars()
# Perform type inference and cleanup
raw_ti = get_type_env(ti_xform(self, TypeEnv()))
raw_ti.normalize()
self.ti = raw_ti.extract()
# Sanity: The set of inferred free typevars should be a subset of the
# TVs corresponding to Vars appearing in src
self.free_typevars = self.ti.free_typevars()
src_vars = set(self.inputs).union(
[x for x in self.defs if not x.is_temp()])
src_tvs = set([v.get_typevar() for v in src_vars])
if (not self.free_typevars.issubset(src_tvs)):
raise AssertionError(
"Some free vars don't appear in src - {}"
.format(self.free_typevars.difference(src_tvs)))
# Update the type vars for each Var to their inferred values
for v in self.inputs + self.defs:
v.set_typevar(self.ti[v.get_typevar()])
def __repr__(self):
# type: () -> str
@@ -202,63 +221,6 @@ class XForm(object):
raise AssertionError(
'{} not defined in dest pattern'.format(d))
def _infer_types(self, rtl):
# type: (Rtl) -> None
"""Assign type variables to all value variables used in `rtl`."""
for d in rtl.rtl:
inst = d.expr.inst
# Get the Var corresponding to the controlling type variable.
ctrl_var = None # type: Var
if inst.is_polymorphic:
if inst.use_typevar_operand:
# Should this be an assertion instead?
# Should all value operands be required to be Vars?
arg = d.expr.args[inst.format.typevar_operand]
if isinstance(arg, Var):
ctrl_var = arg
else:
ctrl_var = d.defs[inst.value_results[0]]
# Reconcile arguments with the requirements of `inst`.
for opnum in inst.value_opnums:
inst_tv = inst.ins[opnum].typevar
v = d.expr.args[opnum]
if isinstance(v, Var):
v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var)
# Reconcile results with the requirements of `inst`.
for resnum in inst.value_results:
inst_tv = inst.outs[resnum].typevar
v = d.defs[resnum]
v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var)
def _collect_typevars(self):
# type: () -> None
"""
Collect a list of variables whose type can be used to infer the types
of all expressions.
This should be called after `_infer_types()` above has computed type
variables for all the used vars.
"""
fvars = list(v for v in self.inputs if v.has_free_typevar())
fvars += list(v for v in self.defs if v.has_free_typevar())
self.free_typevars = fvars
# When substituting a pattern, we know the types of all variables that
# appear on the source side: inut, output, and intermediate values.
# However, temporary values which appear only on the destination side
# must have their type computed somehow.
#
# Some variables have a fixed type which appears as a type variable
# with a singleton_type field set. That's allowed for temps too.
for v in fvars:
if v.is_temp() and not v.typevar.singleton_type():
raise AssertionError(
"Cannot determine type of temp '{}' in xform:\n{}"
.format(v, self))
class XFormGroup(object):
"""