diff --git a/lib/cretonne/meta/cdsl/ast.py b/lib/cretonne/meta/cdsl/ast.py index 2b671fc46e..6efc492cf5 100644 --- a/lib/cretonne/meta/cdsl/ast.py +++ b/lib/cretonne/meta/cdsl/ast.py @@ -100,8 +100,8 @@ class Var(Expr): # TypeVar representing the type of this variable. self.typevar = None # type: TypeVar # The original 'typeof(x)' type variable that was created for this Var. - # This one doesn't change. `self.typevar` above may be joined with - # other typevars. + # This one doesn't change. `self.typevar` above may be changed to + # another typevar by type inference. self.original_typevar = None # type: TypeVar def __str__(self): @@ -180,16 +180,9 @@ class Var(Expr): self.typevar = tv return self.typevar - def link_typevar(self, base, derived_func): - # type: (TypeVar, str) -> None - """ - Link the type variable on this Var to the type variable `base` using - `derived_func`. - """ - self.original_typevar = None - self.typevar.change_to_derived(base, derived_func) - # Possibly eliminate redundant SAMEAS links. - self.typevar = self.typevar.strip_sameas() + def set_typevar(self, tv): + # type: (TypeVar) -> None + self.typevar = tv def has_free_typevar(self): # type: () -> bool @@ -213,136 +206,6 @@ class Var(Expr): """ return self.typevar.rust_expr() - def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var): - # type: (TypeVar, TypeVar, Var) -> None - """ - Constrain the set of allowed types for this variable. - - Merge type variables for the involved variables to minimize the set for - free type variables. - - Suppose we're looking at an instruction defined like this: - - c = Operand('c', TxN.as_bool()) - x = Operand('x', TxN) - y = Operand('y', TxN) - a = Operand('a', TxN) - vselect = Instruction('vselect', ins=(c, x, y), outs=a) - - And suppose the instruction is used in a pattern like this: - - v0 << vselect(v1, v2, v3) - - We want to reconcile the types of the variables v0-v3 with the - constraints from the definition of vselect. This means that v0, v2, and - v3 must all have the same type, and v1 must have the type - `typeof(v2).as_bool()`. - - The types are reconciled by calling this function once for each - input/output operand on the instruction in the pattern with these - arguments. - - :param sym_typevar: Symbolic type variable constraining this variable - in the definition of the instruction. - :param sym_ctrl: Controlling type variable of `sym_typevar` in the - definition of the instruction. - :param ctrl_var: Variable determining the type of `sym_ctrl`. - - When processing `v1` as used in the pattern above, we would get: - - - self: v1 - - sym_typevar: TxN.as_bool() - - sym_ctrl: TxN - - ctrl_var: v2 - - Here, 'v2' represents the controlling variable because of how the - `Ternary` instruction format is defined with `typevar_operand=1`. - """ - # First check if sym_typevar is tied to the controlling type variable - # in the instruction definition. We also allow free type variables on - # instruction inputs that can't be tied to anything else. - # - # This also covers non-polymorphic instructions and other cases where - # we don't have a Var representing the controlling type variable. - sym_free_var = sym_typevar.free_typevar() - if not sym_free_var or sym_free_var is not sym_ctrl or not ctrl_var: - # Just constrain our type to be compatible with the required - # typeset. - self.get_typevar().constrain_types(sym_typevar) - return - - # Now sym_typevar is known to be tied to (or identical to) the - # controlling type variable. - - if not self.typevar: - # If this variable is not yet constrained, just infer its type and - # link it to the controlling type variable. - if not sym_typevar.is_derived: - assert sym_typevar is sym_ctrl - # Identity mapping. - # Note that `self == ctrl_var` is both possible and common. - self.typevar = ctrl_var.get_typevar() - else: - assert self is not ctrl_var, ( - 'Impossible type constraints for {}: {}' - .format(self, sym_typevar)) - # Create a derived type variable identical to sym_typevar, but - # with a different base. - self.typevar = TypeVar.derived( - ctrl_var.get_typevar(), - sym_typevar.derived_func) - # Match the type set constraints of the instruction. - self.typevar.constrain_types(sym_typevar) - return - - # We already have a self.typevar describing our constraints. We need to - # reconcile with the additional constraints. - - # It's likely that ctrl_var and self already share a type - # variable. (Often because `ctrl_var == self`). - if ctrl_var.typevar == self.typevar: - return - - if not sym_typevar.is_derived: - assert sym_typevar is sym_ctrl - # sym_typevar is a direct use of sym_ctrl, so we need to reconcile - # self with ctrl_var. - assert not sym_typevar.is_derived - self.typevar.constrain_types(sym_typevar) - - # It's possible that ctrl_var has not yet been assigned a type - # variable. - if not ctrl_var.typevar: - ctrl_var.typevar = self.typevar - return - - # We can also bind variables with a free type variable to another - # variable. Prefer to do this to temps because they aren't allowed - # to be free, - if self.is_temp() and self.has_free_typevar(): - self.link_typevar(ctrl_var.typevar, TypeVar.SAMEAS) - return - if ctrl_var.is_temp() and ctrl_var.has_free_typevar(): - ctrl_var.link_typevar(self.typevar, TypeVar.SAMEAS) - return - if self.has_free_typevar(): - self.link_typevar(ctrl_var.typevar, TypeVar.SAMEAS) - return - if ctrl_var.has_free_typevar(): - ctrl_var.link_typevar(self.typevar, TypeVar.SAMEAS) - return - - # TODO: Other cases are harder to handle. - # - # - If either variable is an independent free type variable, it - # should be changed to be linked to the other. - # - If both variable are free, we should pick one to link to the - # other. In particular, if one is a temp, it should be linked. - else: - # sym_typevar is derived from sym_ctrl. - # TODO: Other cases are harder to handle. - pass - class Apply(Expr): """ diff --git a/lib/cretonne/meta/cdsl/test_ti.py b/lib/cretonne/meta/cdsl/test_ti.py new file mode 100644 index 0000000000..d97902568a --- /dev/null +++ b/lib/cretonne/meta/cdsl/test_ti.py @@ -0,0 +1,432 @@ +from __future__ import absolute_import +from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\ + b1, icmp, iadd_cout, iadd_cin +from base.legalize import narrow, expand +from base.immediates import intcc +from .typevar import TypeVar +from .ast import Var, Def +from .xform import Rtl, XForm +from .ti import ti_rtl, subst, TypeEnv, get_type_env +from unittest import TestCase +from functools import reduce + +try: + from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa + from .ti import Constraint + from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa +except ImportError: + TYPE_CHECKING = False + + +def sort_constr(c): + # type: (Constraint) -> Constraint + """ + Sort the 2 typevars in a constraint by name for comparison + """ + r = tuple(sorted(c, key=lambda y: y.name)) + if TYPE_CHECKING: + return cast(Constraint, r) + else: + return r + + +def agree(me, other): + # type: (TypeEnv, TypeEnv) -> bool + """ + Given TypeEnvs me and other, check if they agree. As part of that build + a map m from TVs in me to their corresponding TVs in other. + Specifically: + + 1. Check that all TVs that are keys in me.type_map are also defined + in other.type_map + + 2. For any tv in me.type_map check that: + me[tv].get_typeset() == other[tv].get_typeset() + + 3. Set m[me[tv]] = other[tv] in the substitution m + + 4. If we find another tv1 such that me[tv1] == me[tv], assert that + other[tv1] == m[me[tv1]] == m[me[tv]] = other[tv] + + 5. Check that me and other have the same constraints under the + substitution m + """ + m = {} # type: TypeMap + # Check that our type map and other's agree and built substitution m + for tv in me.type_map: + if (me[tv] not in m): + m[me[tv]] = other[tv] + if me[tv].get_typeset() != other[tv].get_typeset(): + return False + else: + if m[me[tv]] != other[tv]: + return False + + # Tranlsate our constraints using m, and sort + me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints] + me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr]) + + # Sort other's constraints + other_equiv_constr = sorted([sort_constr(x) for x in other.constraints], + key=lambda y: y[0].name) + + return me_equiv_constr == other_equiv_constr + + +def check_typing(got_or_err, expected, symtab=None): + # type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa + """ + Check that a the typying we received (got_or_err) complies with the + expected typing (expected). If symtab is specified, substitute the Vars in + expected using symtab first (used when checking type inference on XForms) + """ + (m, c) = expected + got = get_type_env(got_or_err) + + if (symtab is not None): + # For xforms we first need to re-write our TVs in terms of the tvs + # stored internally in the XForm. Use the symtab passed + subst_m = {k.get_typevar(): symtab[str(k)].get_typevar() + for k in m.keys()} + # Convert m from a Var->TypeVar map to TypeVar->TypeVar map where + # the key TypeVar is re-written to its XForm internal version + tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()} + # Rewrite the TVs in the input constraints to their XForm internal + # versions + c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c] + else: + # If no symtab, just convert m from Var->TypeVar map to a + # TypeVar->TypeVar map + tv_m = {k.get_typevar(): v for (k, v) in m.items()} + + expected_typ = TypeEnv((tv_m, c)) + assert agree(expected_typ, got), \ + "typings disagree:\n {} \n {}".format(got.dot(), + expected_typ.dot()) + + +def check_concrete_typing_rtl(var_types, rtl): + # type: (VarMap, Rtl) -> None + """ + Check that a concrete type assignment var_types (Dict[Var, TypeVar]) is + valid for an Rtl rtl. Specifically check that: + + 1) For each Var v \in rtl, v is defined in var_types + + 2) For all v, var_types[v] is a singleton type + + 3) For each v, and each location u, where v is used with expected type + tv_u, var_types[v].get_typeset() is a subset of + subst(tv_u, m).get_typeset() where m is the substitution of + formals->actuals we are building so far. + + 4) If tv_u is non-derived and not in m, set m[tv_u]= var_types[v] + """ + for d in rtl.rtl: + assert isinstance(d, Def) + inst = d.expr.inst + # Accumulate all actual TVs for value defs/opnums in actual_tvs + actual_tvs = [var_types[d.defs[i]] for i in inst.value_results] + for v in [d.expr.args[i] for i in inst.value_opnums]: + assert isinstance(v, Var) + actual_tvs.append(var_types[v]) + + # Accumulate all formal TVs for value defs/opnums in actual_tvs + formal_tvs = [inst.outs[i].typevar for i in inst.value_results] +\ + [inst.ins[i].typevar for i in inst.value_opnums] + m = {} # type: TypeMap + + # For each actual/formal pair check that they agree + for (actual_tv, formal_tv) in zip(actual_tvs, formal_tvs): + # actual should be a singleton + assert actual_tv.singleton_type() is not None + formal_tv = subst(formal_tv, m) + # actual should agree with the concretized formal + assert actual_tv.get_typeset().issubset(formal_tv.get_typeset()) + + if formal_tv not in m and not formal_tv.is_derived: + m[formal_tv] = actual_tv + + +def check_concrete_typing_xform(var_types, xform): + # type: (VarMap, XForm) -> None + """ + Check a concrete type assignment var_types for an XForm xform + """ + check_concrete_typing_rtl(var_types, xform.src) + check_concrete_typing_rtl(var_types, xform.dst) + + +class TypeCheckingBaseTest(TestCase): + def setUp(self): + # type: () -> None + self.v0 = Var("v0") + self.v1 = Var("v1") + self.v2 = Var("v2") + self.v3 = Var("v3") + self.v4 = Var("v4") + self.v5 = Var("v5") + self.v6 = Var("v6") + self.v7 = Var("v7") + self.v8 = Var("v8") + self.v9 = Var("v9") + self.imm0 = Var("imm0") + self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True, + scalars=False, simd=True) + self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True, + scalars=False, simd=True) + self.b1 = TypeVar.singleton(b1) + + +class TestRTL(TypeCheckingBaseTest): + def test_bad_rtl1(self): + # type: () -> None + r = Rtl( + (self.v0, self.v1) << vsplit(self.v2), + self.v3 << vconcat(self.v0, self.v2), + ) + ti = TypeEnv() + self.assertEqual(ti_rtl(r, ti), + "On line 1: fail ti on `typeof_v2` <: `2`: " + + "Error: empty type created when unifying " + + "`typeof_v2` and `half_vector(typeof_v2)`") + + def test_vselect(self): + # type: () -> None + r = Rtl( + self.v0 << vselect(self.v1, self.v2, self.v3), + ) + ti = TypeEnv() + typing = ti_rtl(r, ti) + txn = self.TxN.get_fresh_copy("TxN1") + check_typing(typing, ({ + self.v0: txn, + self.v1: txn.as_bool(), + self.v2: txn, + self.v3: txn + }, [])) + + def test_vselect_icmpimm(self): + # type: () -> None + r = Rtl( + self.v0 << iconst(self.imm0), + self.v1 << icmp(intcc.eq, self.v2, self.v0), + self.v5 << vselect(self.v1, self.v3, self.v4), + ) + ti = TypeEnv() + typing = ti_rtl(r, ti) + ixn = self.IxN_nonscalar.get_fresh_copy("IxN1") + txn = self.TxN.get_fresh_copy("TxN1") + check_typing(typing, ({ + self.v0: ixn, + self.v1: ixn.as_bool(), + self.v2: ixn, + self.v3: txn, + self.v4: txn, + self.v5: txn, + }, [(ixn.as_bool(), txn.as_bool())])) + + def test_vselect_vsplits(self): + # type: () -> None + r = Rtl( + self.v3 << vselect(self.v0, self.v1, self.v2), + (self.v4, self.v5) << vsplit(self.v3), + (self.v6, self.v7) << vsplit(self.v4), + ) + ti = TypeEnv() + typing = ti_rtl(r, ti) + t = TypeVar("t", "", ints=True, bools=True, floats=True, + simd=(4, 256)) + check_typing(typing, ({ + self.v0: t.as_bool(), + self.v1: t, + self.v2: t, + self.v3: t, + self.v4: t.half_vector(), + self.v5: t.half_vector(), + self.v6: t.half_vector().half_vector(), + self.v7: t.half_vector().half_vector(), + }, [])) + + def test_vselect_vconcats(self): + # type: () -> None + r = Rtl( + self.v3 << vselect(self.v0, self.v1, self.v2), + self.v8 << vconcat(self.v3, self.v3), + self.v9 << vconcat(self.v8, self.v8), + ) + ti = TypeEnv() + typing = ti_rtl(r, ti) + t = TypeVar("t", "", ints=True, bools=True, floats=True, + simd=(2, 64)) + check_typing(typing, ({ + self.v0: t.as_bool(), + self.v1: t, + self.v2: t, + self.v3: t, + self.v8: t.double_vector(), + self.v9: t.double_vector().double_vector(), + }, [])) + + def test_vselect_vsplits_vconcats(self): + # type: () -> None + r = Rtl( + self.v3 << vselect(self.v0, self.v1, self.v2), + (self.v4, self.v5) << vsplit(self.v3), + (self.v6, self.v7) << vsplit(self.v4), + self.v8 << vconcat(self.v3, self.v3), + self.v9 << vconcat(self.v8, self.v8), + ) + ti = TypeEnv() + typing = ti_rtl(r, ti) + t = TypeVar("t", "", ints=True, bools=True, floats=True, + simd=(4, 64)) + check_typing(typing, ({ + self.v0: t.as_bool(), + self.v1: t, + self.v2: t, + self.v3: t, + self.v4: t.half_vector(), + self.v5: t.half_vector(), + self.v6: t.half_vector().half_vector(), + self.v7: t.half_vector().half_vector(), + self.v8: t.double_vector(), + self.v9: t.double_vector().double_vector(), + }, [])) + + def test_bint(self): + # type: () -> None + r = Rtl( + self.v4 << iadd(self.v1, self.v2), + self.v5 << bint(self.v3), + self.v0 << iadd(self.v4, self.v5) + ) + ti = TypeEnv() + typing = ti_rtl(r, ti) + itype = TypeVar("t", "", ints=True, simd=(1, 256)) + btype = TypeVar("b", "", bools=True, simd=True) + + # Check that self.v5 gets the same integer type as + # the rest of them + # TODO: Add constraint nlanes(v3) == nlanes(v1) when we + # add that type constraint to bint + check_typing(typing, ({ + self.v1: itype, + self.v2: itype, + self.v4: itype, + self.v5: itype, + self.v3: btype, + self.v0: itype, + }, [])) + + +class TestXForm(TypeCheckingBaseTest): + def test_iadd_cout(self): + # type: () -> None + x = XForm(Rtl((self.v0, self.v1) << iadd_cout(self.v2, self.v3),), + Rtl( + self.v0 << iadd(self.v2, self.v3), + self.v1 << icmp(intcc.ult, self.v0, self.v2) + )) + itype = TypeVar("t", "", ints=True, simd=(1, 1)) + + check_typing(x.ti, ({ + self.v0: itype, + self.v2: itype, + self.v3: itype, + self.v1: itype.as_bool(), + }, []), x.symtab) + + def test_iadd_cin(self): + # type: () -> None + x = XForm(Rtl(self.v0 << iadd_cin(self.v1, self.v2, self.v3)), + Rtl( + self.v4 << iadd(self.v1, self.v2), + self.v5 << bint(self.v3), + self.v0 << iadd(self.v4, self.v5) + )) + itype = TypeVar("t", "", ints=True, simd=(1, 1)) + + check_typing(x.ti, ({ + self.v0: itype, + self.v1: itype, + self.v2: itype, + self.v3: self.b1, + self.v4: itype, + self.v5: itype, + }, []), x.symtab) + + def test_enumeration_with_constraints(self): + # type: () -> None + xform = XForm( + Rtl( + self.v0 << iconst(self.imm0), + self.v1 << icmp(intcc.eq, self.v2, self.v0), + self.v5 << vselect(self.v1, self.v3, self.v4) + ), + Rtl( + self.v0 << iconst(self.imm0), + self.v1 << icmp(intcc.eq, self.v2, self.v0), + self.v5 << vselect(self.v1, self.v3, self.v4) + )) + + # Check all var assigns are correct + assert len(xform.ti.constraints) > 0 + concrete_var_assigns = list(xform.ti.concrete_typings()) + + v0 = xform.symtab[str(self.v0)] + v1 = xform.symtab[str(self.v1)] + v2 = xform.symtab[str(self.v2)] + v3 = xform.symtab[str(self.v3)] + v4 = xform.symtab[str(self.v4)] + v5 = xform.symtab[str(self.v5)] + + for var_m in concrete_var_assigns: + assert var_m[v0] == var_m[v2] and \ + var_m[v3] == var_m[v4] and\ + var_m[v5] == var_m[v3] and\ + var_m[v1] == var_m[v2].as_bool() and\ + var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset() + check_concrete_typing_xform(var_m, xform) + + # The number of possible typings here is: + # 8 cases for v0 = i8xN times 2 options for v3 - i8, b8 = 16 + # 8 cases for v0 = i16xN times 2 options for v3 - i16, b16 = 16 + # 8 cases for v0 = i32xN times 3 options for v3 - i32, b32, f32 = 24 + # 8 cases for v0 = i64xN times 3 options for v3 - i64, b64, f64 = 24 + # + # (Note we have 8 cases for lanes since vselect prevents scalars) + # Total: 2*16 + 2*24 = 80 + assert len(concrete_var_assigns) == 80 + + def test_base_legalizations_enumeration(self): + # type: () -> None + for xform in narrow.xforms + expand.xforms: + # Any legalization patterns we defined should have at least 1 + # concrete typing + concrete_typings_list = list(xform.ti.concrete_typings()) + assert len(concrete_typings_list) > 0 + + # If there are no free_typevars, this is a non-polymorphic pattern. + # There should be only one possible concrete typing. + if (len(xform.free_typevars) == 0): + assert len(concrete_typings_list) == 1 + continue + + # For any patterns where the type env includes constraints, at + # least one of the "theoretically possible" concrete typings must + # be prevented by the constraints. (i.e. we are not emitting + # unneccessary constraints). + # We check that by asserting that the number of concrete typings is + # less than the number of all possible free typevar assignments + if (len(xform.ti.constraints) > 0): + theoretical_num_typings =\ + reduce(lambda x, y: x*y, + [tv.get_typeset().size() + for tv in xform.free_typevars], 1) + assert len(concrete_typings_list) < theoretical_num_typings + + # Check the validity of each individual concrete typing against the + # xform + for concrete_typing in concrete_typings_list: + check_concrete_typing_xform(concrete_typing, xform) diff --git a/lib/cretonne/meta/cdsl/test_typevar.py b/lib/cretonne/meta/cdsl/test_typevar.py index f655c8966d..d990d7804f 100644 --- a/lib/cretonne/meta/cdsl/test_typevar.py +++ b/lib/cretonne/meta/cdsl/test_typevar.py @@ -125,59 +125,56 @@ class TestTypeSet(TestCase): self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(), i32.by(4)) - def test_map_inverse(self): + def test_preimage(self): t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32)) - self.assertEqual(t, t.map_inverse(TypeVar.SAMEAS)) + self.assertEqual(t, t.preimage(TypeVar.SAMEAS)) # LANEOF self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)), - t.map_inverse(TypeVar.LANEOF)) + t.preimage(TypeVar.LANEOF)) # Inverse of empty set is still empty across LANEOF self.assertEqual(TypeSet(), - TypeSet().map_inverse(TypeVar.LANEOF)) + TypeSet().preimage(TypeVar.LANEOF)) # ASBOOL t = TypeSet(lanes=(1, 4), bools=(1, 64)) - self.assertEqual(t.map_inverse(TypeVar.ASBOOL), + self.assertEqual(t.preimage(TypeVar.ASBOOL), TypeSet(lanes=(1, 4), ints=True, bools=True, floats=True)) - # Inverse image across ASBOOL of TS not involving b1 cannot have - # lanes=1 - t = TypeSet(lanes=(1, 4), bools=(16, 32)) - self.assertEqual(t.map_inverse(TypeVar.ASBOOL), - TypeSet(lanes=(2, 4), ints=(16, 32), bools=(16, 32), - floats=(32, 32))) - # Half/Double Vector t = TypeSet(lanes=(1, 1), ints=(8, 8)) t1 = TypeSet(lanes=(256, 256), ints=(8, 8)) - self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR).size(), 0) - self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR).size(), 0) + self.assertEqual(t.preimage(TypeVar.DOUBLEVECTOR).size(), 0) + self.assertEqual(t1.preimage(TypeVar.HALFVECTOR).size(), 0) t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32)) t1 = TypeSet(lanes=(64, 256), bools=(1, 32)) - self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR), + self.assertEqual(t.preimage(TypeVar.DOUBLEVECTOR), TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32))) - self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR), + self.assertEqual(t1.preimage(TypeVar.HALFVECTOR), TypeSet(lanes=(128, 256), bools=(1, 32))) # Half/Double Width t = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8)) t1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64)) - self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH).size(), 0) - self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH).size(), 0) + self.assertEqual(t.preimage(TypeVar.DOUBLEWIDTH).size(), 0) + self.assertEqual(t1.preimage(TypeVar.HALFWIDTH).size(), 0) t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64)) t1 = TypeSet(lanes=(64, 256), bools=(1, 64)) - self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH), + self.assertEqual(t.preimage(TypeVar.DOUBLEWIDTH), TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32))) - self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH), + self.assertEqual(t1.preimage(TypeVar.HALFWIDTH), TypeSet(lanes=(64, 256), bools=(16, 64))) +def has_non_bijective_derived_f(iterable): + return any(not TypeVar.is_bijection(x) for x in iterable) + + class TestTypeVar(TestCase): def test_functions(self): x = TypeVar('x', 'all ints', ints=True) @@ -220,7 +217,7 @@ class TestTypeVar(TestCase): self.assertEqual(len(x.type_set.bools), 0) def test_stress_constrain_types(self): - # Get all 49 possible derived vars of lentgh 2. Since we have SAMEAS + # Get all 49 possible derived vars of length 2. Since we have SAMEAS # this includes singly derived and non-derived vars funcs = [TypeVar.SAMEAS, TypeVar.LANEOF, TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR, @@ -231,18 +228,18 @@ class TestTypeVar(TestCase): for (i1, i2) in product(v, v): # Compute the derived sets for each starting with a full typeset full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True) - ts1 = reduce(lambda ts, func: ts.map(func), i1, full_ts) - ts2 = reduce(lambda ts, func: ts.map(func), i2, full_ts) + ts1 = reduce(lambda ts, func: ts.image(func), i1, full_ts) + ts2 = reduce(lambda ts, func: ts.image(func), i2, full_ts) # Compute intersection intersect = ts1.copy() intersect &= ts2 # Propagate instersections backward - ts1_src = reduce(lambda ts, func: ts.map_inverse(func), + ts1_src = reduce(lambda ts, func: ts.preimage(func), reversed(i1), intersect) - ts2_src = reduce(lambda ts, func: ts.map_inverse(func), + ts2_src = reduce(lambda ts, func: ts.preimage(func), reversed(i2), intersect) @@ -262,13 +259,10 @@ class TestTypeVar(TestCase): i2, TypeVar.from_typeset(ts2_src)) - # The typesets of the two derived variables should be subsets of - # the intersection we computed originally - assert tv1.get_typeset().issubset(intersect) - assert tv2.get_typeset().issubset(intersect) - - # In the absence of AS_BOOL map(map_inverse(f)) == f so the + # In the absence of AS_BOOL image(preimage(f)) == f so the # typesets of tv1 and tv2 should be exactly intersection - assert (tv1.get_typeset() == tv2.get_typeset() and - tv1.get_typeset() == intersect) or\ - TypeVar.ASBOOL in set(i1 + i2) + assert tv1.get_typeset() == intersect or\ + has_non_bijective_derived_f(i1) + + assert tv2.get_typeset() == intersect or\ + has_non_bijective_derived_f(i2) diff --git a/lib/cretonne/meta/cdsl/ti.py b/lib/cretonne/meta/cdsl/ti.py new file mode 100644 index 0000000000..5b376f1e5d --- /dev/null +++ b/lib/cretonne/meta/cdsl/ti.py @@ -0,0 +1,556 @@ +""" +Type Inference +""" +from .typevar import TypeVar +from .ast import Def, Var +from copy import copy +from itertools import product + +try: + from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa + from typing import Iterable # noqa + from typing import cast, List + from .xform import Rtl, XForm # noqa + from .ast import Expr # noqa + if TYPE_CHECKING: + Constraint = Tuple[TypeVar, TypeVar] + ConstraintList = List[Constraint] + TypeMap = Dict[TypeVar, TypeVar] + VarMap = Dict[Var, TypeVar] +except ImportError: + TYPE_CHECKING = False + pass + + +class TypeEnv(object): + """ + Class encapsulating the neccessary book keeping for type inference. + :attribute type_map: dict holding the equivalence relations between tvs + :attribute constraints: a list of accumulated constraints - tuples + (tv1, tv2)) where tv1 and tv2 are equal + :attribute ranks: dictionary recording the (optional) ranks for tvs. + tvs corresponding to real variables have explicitly + specified ranks. + :attribute vars: a set containing all known Vars + :attribute idx: counter used to get fresh ids + """ + def __init__(self, arg=None): + # type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None + self.ranks = {} # type: Dict[TypeVar, int] + self.vars = set() # type: Set[Var] + + if arg is None: + self.type_map = {} # type: TypeMap + self.constraints = [] # type: ConstraintList + else: + self.type_map, self.constraints = arg + + self.idx = 0 + + def __getitem__(self, arg): + # type: (Union[TypeVar, Var]) -> TypeVar + """ + Lookup the canonical representative for a Var/TypeVar. + """ + if (isinstance(arg, Var)): + tv = arg.get_typevar() + else: + assert (isinstance(arg, TypeVar)) + tv = arg + + while tv in self.type_map: + tv = self.type_map[tv] + + if tv.is_derived: + tv = TypeVar.derived(self[tv.base], tv.derived_func) + return tv + + def equivalent(self, tv1, tv2): + # type: (TypeVar, TypeVar) -> None + """ + Record a that the free tv1 is part of the same equivalence class as + tv2. The canonical representative of the merged class is tv2's + cannonical representative. + """ + assert not tv1.is_derived + assert self[tv1] == tv1 + + # Make sure we don't create cycles + if tv2.is_derived: + assert self[tv2.base] != tv1 + + self.type_map[tv1] = tv2 + + def add_constraint(self, tv1, tv2): + # type: (TypeVar, TypeVar) -> None + """ + Add a new equivalence constraint between tv1 and tv2 + """ + self.constraints.append((tv1, tv2)) + + def get_uid(self): + # type: () -> str + r = str(self.idx) + self.idx += 1 + return r + + def __repr__(self): + # type: () -> str + return self.dot() + + def rank(self, tv): + # type: (TypeVar) -> int + """ + Get the rank of tv in the partial order. TVs directly associated with a + Var get their rank from the Var (see register()). + Internally generated non-derived TVs implicitly get the lowest rank (0) + Internal derived variables get the highest rank. + """ + default_rank = 5 if tv.is_derived else 0 + return self.ranks.get(tv, default_rank) + + def register(self, v): + # type: (Var) -> None + """ + Register a new Var v. This computes a rank for the associated TypeVar + for v, which is used to impose a partial order on type variables. + """ + self.vars.add(v) + + if v.is_input(): + r = 4 + elif v.is_intermediate(): + r = 3 + elif v.is_output(): + r = 2 + else: + assert(v.is_temp()) + r = 1 + + self.ranks[v.get_typevar()] = r + + def free_typevars(self): + # type: () -> Set[TypeVar] + """ + Get the free typevars in the current type env. + """ + tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()]) + # Filter out None here due to singleton type vars + return set(filter(lambda x: x is not None, tvs)) + + def normalize(self): + # type: () -> None + """ + Normalize by: + - collapsing any roots that don't correspond to a concrete TV AND + have a single TV derived from them or equivalent to them + + E.g. if we have a root of the tree that looks like: + + typeof_a typeof_b + \ / + typeof_x + | + half_width(1) + | + 1 + + we want to collapse the linear path between 1 and typeof_x. The + resulting graph is: + + typeof_a typeof_b + \ / + typeof_x + """ + source_tvs = set([v.get_typevar() for v in self.vars]) + children = {} # type: Dict[TypeVar, Set[TypeVar]] + for v in self.type_map.values(): + if not v.is_derived: + continue + + t = v.free_typevar() + s = children.get(t, set()) + s.add(v) + children[t] = s + + for (a, b) in self.type_map.items(): + s = children.get(b, set()) + s.add(a) + children[b] = s + + for r in list(self.free_typevars()): + while (r not in source_tvs and r in children and + len(children[r]) == 1): + child = list(children[r])[0] + if child in self.type_map: + assert self.type_map[child] == r + del self.type_map[child] + + r = child + + def extract(self): + # type: () -> TypeEnv + """ + Extract a clean type environment from self, that only mentions + TVs associated with real variables + """ + vars_tvs = set([v.get_typevar() for v in self.vars]) + new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]} + new_constraints = [(self[tv1], self[tv2]) + for (tv1, tv2) in self.constraints] + + # Sanity: new constraints and the new type_map should only contain + # tvs associated with real vars + for (a, b) in new_constraints: + assert a.free_typevar() in vars_tvs and\ + b.free_typevar() in vars_tvs + + for (k, v) in new_type_map.items(): + assert k in vars_tvs + assert v.free_typevar() is None or v.free_typevar() in vars_tvs + + t = TypeEnv() + t.type_map = new_type_map + t.constraints = new_constraints + # ranks and vars contain only TVs associated with real vars + t.ranks = copy(self.ranks) + t.vars = copy(self.vars) + return t + + def concrete_typings(self): + # type: () -> Iterable[VarMap] + """ + Return an iterable over all possible concrete typings permitted by this + TypeEnv. + """ + free_tvs = self.free_typevars() + free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs] + for concrete_types in product(*free_tv_iters): + # Build type substitutions for all free vars + m = {tv: TypeVar.singleton(typ) + for (tv, typ) in zip(free_tvs, concrete_types)} + + concrete_var_map = {v: subst(self[v.get_typevar()], m) + for v in self.vars} + + # Check if constraints are satisfied for this typing + failed = None + for (tv1, tv2) in self.constraints: + tv1 = subst(tv1, m) + tv2 = subst(tv2, m) + assert tv1.get_typeset().size() == 1 and\ + tv2.get_typeset().size() == 1 + if (tv1.get_typeset() != tv2.get_typeset()): + failed = (tv1, tv2) + break + + if (failed is not None): + continue + + yield concrete_var_map + + def dot(self): + # type: () -> str + """ + Return a representation of self as a graph in dot format. + Nodes correspond to TypeVariables. + Dotted edges correspond to equivalences between TVS + Solid edges correspond to derivation relations between TVs. + Dashed edges correspond to equivalence constraints. + """ + def label(s): + # type: (TypeVar) -> str + return "\"" + str(s) + "\"" + + # Add all registered TVs (as some of them may be singleton nodes not + # appearing in the graph + nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar] + edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]] + + for (k, v) in self.type_map.items(): + # Add all intermediate TVs appearing in edges + nodes.add(k) + nodes.add(v) + edges.add((k, v, "dotted", None)) + while (v.is_derived): + nodes.add(v.base) + edges.add((v, v.base, "solid", v.derived_func)) + v = v.base + + for (a, b) in self.constraints: + assert a in nodes and b in nodes + edges.add((a, b, "dashed", None)) + + root_nodes = set([x for x in nodes + if x not in self.type_map and not x.is_derived]) + + r = "digraph {\n" + for n in nodes: + r += label(n) + if n in root_nodes: + r += "[xlabel=\"{}\"]".format(self[n].get_typeset()) + r += ";\n" + + for (n1, n2, style, elabel) in edges: + e = label(n1) + if style == "dashed": + e += '--' + else: + e += '->' + e += label(n2) + e += "[style={}".format(style) + + if elabel is not None: + e += ",label={}".format(elabel) + e += "];\n" + + r += e + r += "}" + + return r + + +if TYPE_CHECKING: + TypingError = str + TypingOrError = Union[TypeEnv, TypingError] + + +def get_error(typing_or_err): + # type: (TypingOrError) -> Optional[TypingError] + """ + Helper function to appease mypy when checking the result of typing. + """ + if isinstance(typing_or_err, str): + if (TYPE_CHECKING): + return cast(TypingError, typing_or_err) + else: + return typing_or_err + else: + return None + + +def get_type_env(typing_or_err): + # type: (TypingOrError) -> TypeEnv + """ + Helper function to appease mypy when checking the result of typing. + """ + assert isinstance(typing_or_err, TypeEnv) + if (TYPE_CHECKING): + return cast(TypeEnv, typing_or_err) + else: + return typing_or_err + + +def subst(tv, tv_map): + # type: (TypeVar, TypeMap) -> TypeVar + """ + Perform substition on the input tv using the TypeMap tv_map. + """ + if tv in tv_map: + return tv_map[tv] + + if tv.is_derived: + return TypeVar.derived(subst(tv.base, tv_map), tv.derived_func) + + return tv + + +def normalize_tv(tv): + # type: (TypeVar) -> TypeVar + """ + Normalize a (potentially derived) TV using the following rules: + - collapse SAMEAS + SAMEAS(base) -> base + + - vector and width derived functions commute + {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) -> + {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base)) + + - half/double pairs collapse + {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base + {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base + """ + vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR] + width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH] + + if not tv.is_derived: + return tv + + df = tv.derived_func + + # Collapse SAMEAS edges + if (df == TypeVar.SAMEAS): + return normalize_tv(tv.base) + + if (tv.base.is_derived): + base_df = tv.base.derived_func + + # Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR, + # DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR + if df in vector_derives and base_df in width_derives: + return normalize_tv( + TypeVar.derived( + TypeVar.derived(tv.base.base, df), base_df)) + + # Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR + # cancel each other. TODO: Does this cancellation hide type + # overflow/underflow? + + if (df, base_df) in \ + [(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR), + (TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR), + (TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH), + (TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH)]: + return normalize_tv(tv.base.base) + + return TypeVar.derived(normalize_tv(tv.base), df) + + +def constrain_fixpoint(tv1, tv2): + # type: (TypeVar, TypeVar) -> None + """ + Given typevars tv1 and tv2 (which could be derived from one another) + constrain their typesets to be the same. When one is derived from the + other, repeat the constrain process until fixpoint. + """ + # Constrain tv2's typeset as long as tv1's typeset is changing. + while True: + old_tv1_ts = tv1.get_typeset().copy() + tv2.constrain_types(tv1) + if tv1.get_typeset() == old_tv1_ts: + break + + old_tv2_ts = tv2.get_typeset().copy() + tv1.constrain_types(tv2) + assert old_tv2_ts == tv2.get_typeset() + + +def unify(tv1, tv2, typ): + # type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError + """ + Unify tv1 and tv2 in the current type environment typ, and return an + updated type environment or error. + """ + tv1 = normalize_tv(typ[tv1]) + tv2 = normalize_tv(typ[tv2]) + + # Already unified + if tv1 == tv2: + return typ + + if typ.rank(tv2) < typ.rank(tv1): + return unify(tv2, tv1, typ) + + constrain_fixpoint(tv1, tv2) + + if (tv1.get_typeset().size() == 0 or tv2.get_typeset().size() == 0): + return "Error: empty type created when unifying {} and {}"\ + .format(tv1, tv2) + + # Free -> Derived(Free) + if not tv1.is_derived: + typ.equivalent(tv1, tv2) + return typ + + assert tv2.is_derived, "Ordering gives us !tv1.is_derived==>tv2.is_derived" + + if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)): + inv_f = TypeVar.inverse_func(tv1.derived_func) + return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ) + + typ.add_constraint(tv1, tv2) + return typ + + +def ti_def(definition, typ): + # type: (Def, TypeEnv) -> TypingOrError + """ + Perform type inference on one Def in the current type environment typ and + return an updated type environment or error. + + At a high level this works by creating fresh copies of each formal type var + in the Def's instruction's signature, and unifying the formal tv with the + corresponding actual tv. + """ + expr = definition.expr + inst = expr.inst + + # Create a map m mapping each free typevar in the signature of definition + # to a fresh copy of itself + all_formal_tvs = \ + [inst.outs[i].typevar for i in inst.value_results] +\ + [inst.ins[i].typevar for i in inst.value_opnums] + free_formal_tvs = [tv for tv in all_formal_tvs if not tv.is_derived] + m = {tv: tv.get_fresh_copy(str(typ.get_uid())) for tv in free_formal_tvs} + + # Get fresh copies for each typevar in the signature (both free and + # derived) + fresh_formal_tvs = \ + [subst(inst.outs[i].typevar, m) for i in inst.value_results] +\ + [subst(inst.ins[i].typevar, m) for i in inst.value_opnums] + + # Get the list of actual Vars + actual_vars = [] # type: List[Expr] + actual_vars += [definition.defs[i] for i in inst.value_results] + actual_vars += [expr.args[i] for i in inst.value_opnums] + + # Get the list of the actual TypeVars + actual_tvs = [] + for v in actual_vars: + assert(isinstance(v, Var)) + # Register with TypeEnv that this typevar corresponds ot variable v, + # and thus has a given rank + typ.register(v) + actual_tvs.append(v.get_typevar()) + + # Unify each actual typevar with the correpsonding fresh formal tv + for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs): + typ_or_err = unify(actual_tv, formal_tv, typ) + err = get_error(typ_or_err) + if (err): + return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err + + typ = get_type_env(typ_or_err) + + return typ + + +def ti_rtl(rtl, typ): + # type: (Rtl, TypeEnv) -> TypingOrError + """ + Perform type inference on an Rtl in a starting type env typ. Return an + updated type environment or error. + """ + for (i, d) in enumerate(rtl.rtl): + assert (isinstance(d, Def)) + typ_or_err = ti_def(d, typ) + err = get_error(typ_or_err) # type: Optional[TypingError] + if (err): + return "On line {}: ".format(i) + err + + typ = get_type_env(typ_or_err) + + return typ + + +def ti_xform(xform, typ): + # type: (XForm, TypeEnv) -> TypingOrError + """ + Perform type inference on an Rtl in a starting type env typ. Return an + updated type environment or error. + """ + typ_or_err = ti_rtl(xform.src, typ) + err = get_error(typ_or_err) # type: Optional[TypingError] + if (err): + return "In src pattern: " + err + + typ = get_type_env(typ_or_err) + + typ_or_err = ti_rtl(xform.dst, typ) + err = get_error(typ_or_err) + if (err): + return "In dst pattern: " + err + + typ = get_type_env(typ_or_err) + + return get_type_env(typ_or_err) diff --git a/lib/cretonne/meta/cdsl/typevar.py b/lib/cretonne/meta/cdsl/typevar.py index 69b419bfa3..1ea9831b91 100644 --- a/lib/cretonne/meta/cdsl/typevar.py +++ b/lib/cretonne/meta/cdsl/typevar.py @@ -8,7 +8,6 @@ from __future__ import absolute_import import math from . import types, is_power_of_two from copy import deepcopy -from .types import IntType, FloatType, BoolType try: from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa @@ -210,6 +209,10 @@ class TypeSet(object): else: return False + def __ne__(self, other): + # type: (object) -> bool + return not self.__eq__(other) + def __repr__(self): # type: () -> str s = 'TypeSet(lanes={}'.format(pp_set(self.lanes)) @@ -289,7 +292,9 @@ class TypeSet(object): new = self.copy() new.ints = set() new.floats = set() - new.bools = self.ints.union(self.floats).union(self.bools) + + if len(self.lanes.difference(set([1]))) > 0: + new.bools = self.ints.union(self.floats).union(self.bools) if 1 in self.lanes: new.bools.add(1) @@ -340,7 +345,7 @@ class TypeSet(object): return new - def map(self, func): + def image(self, func): # type: (str) -> TypeSet """ Return the image of self across the derived function func @@ -362,7 +367,7 @@ class TypeSet(object): else: assert False, "Unknown derived function: " + func - def map_inverse(self, func): + def preimage(self, func): # type: (str) -> TypeSet """ Return the inverse image of self across the derived function func @@ -379,16 +384,14 @@ class TypeSet(object): return new elif (func == TypeVar.ASBOOL): new = self.copy() - new.ints = self.bools.difference(set([1])) - new.floats = self.bools.intersection(set([32, 64])) if 1 not in self.bools: - try: - # If the range doesn't have b1, then the domain can't - # include scalars, as as_bool(scalar)=b1 - new.lanes.remove(1) - except KeyError: - pass + new.ints = self.bools.difference(set([1])) + new.floats = self.bools.intersection(set([32, 64])) + else: + new.ints = set([2**x for x in range(3, 7)]) + new.floats = set([32, 64]) + return new elif (func == TypeVar.HALFWIDTH): return self.double_width() @@ -409,27 +412,32 @@ class TypeSet(object): return len(self.lanes) * (len(self.ints) + len(self.floats) + len(self.bools)) + def concrete_types(self): + # type: () -> Iterable[types.ValueType] + def by(scalar, lanes): + # type: (types.ScalarType, int) -> types.ValueType + if (lanes == 1): + return scalar + else: + return scalar.by(lanes) + + for nlanes in self.lanes: + for bits in self.ints: + yield by(types.IntType.with_bits(bits), nlanes) + for bits in self.floats: + yield by(types.FloatType.with_bits(bits), nlanes) + for bits in self.bools: + yield by(types.BoolType.with_bits(bits), nlanes) + def get_singleton(self): # type: () -> types.ValueType """ Return the singleton type represented by self. Can only call on typesets containing 1 type. """ - assert self.size() == 1 - scalar_type = None # type: types.ScalarType - if len(self.ints) > 0: - scalar_type = IntType.with_bits(tuple(self.ints)[0]) - elif len(self.floats) > 0: - scalar_type = FloatType.with_bits(tuple(self.floats)[0]) - else: - scalar_type = BoolType.with_bits(tuple(self.bools)[0]) - - nlanes = tuple(self.lanes)[0] - - if nlanes == 1: - return scalar_type - else: - return scalar_type.by(nlanes) + types = list(self.concrete_types()) + assert len(types) == 1 + return types[0] class TypeVar(object): @@ -519,6 +527,13 @@ class TypeVar(object): 'TypeVar({}, {})' .format(self.name, self.type_set)) + def __hash__(self): + # type: () -> int + if (not self.is_derived): + return object.__hash__(self) + + return hash((self.derived_func, self.base)) + def __eq__(self, other): # type: (object) -> bool if not isinstance(other, TypeVar): @@ -530,6 +545,10 @@ class TypeVar(object): else: return self is other + def __ne__(self, other): + # type: (object) -> bool + return not self.__eq__(other) + # Supported functions for derived type variables. # The names here must match the method names on `ir::types::Type`. # The camel_case of the names must match `enum OperandConstraint` in @@ -542,6 +561,27 @@ class TypeVar(object): HALFVECTOR = 'half_vector' DOUBLEVECTOR = 'double_vector' + @staticmethod + def is_bijection(func): + # type: (str) -> bool + return func in [ + TypeVar.SAMEAS, + TypeVar.HALFWIDTH, + TypeVar.DOUBLEWIDTH, + TypeVar.HALFVECTOR, + TypeVar.DOUBLEVECTOR] + + @staticmethod + def inverse_func(func): + # type: (str) -> str + return { + TypeVar.SAMEAS: TypeVar.SAMEAS, + TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH, + TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH, + TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR, + TypeVar.DOUBLEVECTOR: TypeVar.HALFVECTOR + }[func] + @staticmethod def derived(base, derived_func): # type: (TypeVar, str) -> TypeVar @@ -668,7 +708,7 @@ class TypeVar(object): Get the free type variable controlling this one. """ if self.is_derived: - return self.base + return self.base.free_typevar() elif self.singleton_type() is not None: # A singleton type variable is not a proper free variable. return None @@ -697,7 +737,7 @@ class TypeVar(object): if not self.is_derived: self.type_set &= ts else: - self.base.constrain_types_by_ts(ts.map_inverse(self.derived_func)) + self.base.constrain_types_by_ts(ts.preimage(self.derived_func)) def constrain_types(self, other): # type: (TypeVar) -> None @@ -723,19 +763,14 @@ class TypeVar(object): if not self.is_derived: return self.type_set else: - if (self.derived_func == TypeVar.SAMEAS): - return self.base.get_typeset() - elif (self.derived_func == TypeVar.LANEOF): - return self.base.get_typeset().lane_of() - elif (self.derived_func == TypeVar.ASBOOL): - return self.base.get_typeset().as_bool() - elif (self.derived_func == TypeVar.HALFWIDTH): - return self.base.get_typeset().half_width() - elif (self.derived_func == TypeVar.DOUBLEWIDTH): - return self.base.get_typeset().double_width() - elif (self.derived_func == TypeVar.HALFVECTOR): - return self.base.get_typeset().half_vector() - elif (self.derived_func == TypeVar.DOUBLEVECTOR): - return self.base.get_typeset().double_vector() - else: - assert False, "Unknown derived function: " + self.derived_func + return self.base.get_typeset().image(self.derived_func) + + def get_fresh_copy(self, name): + # type: (str) -> TypeVar + """ + Get a fresh copy of self. Can only be called on free typevars. + """ + assert not self.is_derived + tv = TypeVar.from_typeset(self.type_set.copy()) + tv.name = name + return tv diff --git a/lib/cretonne/meta/cdsl/xform.py b/lib/cretonne/meta/cdsl/xform.py index 9ce93c9ed9..a8fbb7f66e 100644 --- a/lib/cretonne/meta/cdsl/xform.py +++ b/lib/cretonne/meta/cdsl/xform.py @@ -3,6 +3,7 @@ Instruction transformations. """ from __future__ import absolute_import from .ast import Def, Var, Apply +from .ti import ti_xform, TypeEnv, get_type_env try: from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa @@ -83,6 +84,8 @@ class XForm(object): self._rewrite_rtl(src, symtab, Var.SRCCTX) num_src_inputs = len(self.inputs) self._rewrite_rtl(dst, symtab, Var.DSTCTX) + # Needed for testing type inference on XForms + self.symtab = symtab # Check for inconsistently used inputs. for i in self.inputs: @@ -96,9 +99,25 @@ class XForm(object): "extra inputs in dst RTL: {}".format( self.inputs[num_src_inputs:])) - self._infer_types(self.src) - self._infer_types(self.dst) - self._collect_typevars() + # Perform type inference and cleanup + raw_ti = get_type_env(ti_xform(self, TypeEnv())) + raw_ti.normalize() + self.ti = raw_ti.extract() + + # Sanity: The set of inferred free typevars should be a subset of the + # TVs corresponding to Vars appearing in src + self.free_typevars = self.ti.free_typevars() + src_vars = set(self.inputs).union( + [x for x in self.defs if not x.is_temp()]) + src_tvs = set([v.get_typevar() for v in src_vars]) + if (not self.free_typevars.issubset(src_tvs)): + raise AssertionError( + "Some free vars don't appear in src - {}" + .format(self.free_typevars.difference(src_tvs))) + + # Update the type vars for each Var to their inferred values + for v in self.inputs + self.defs: + v.set_typevar(self.ti[v.get_typevar()]) def __repr__(self): # type: () -> str @@ -202,63 +221,6 @@ class XForm(object): raise AssertionError( '{} not defined in dest pattern'.format(d)) - def _infer_types(self, rtl): - # type: (Rtl) -> None - """Assign type variables to all value variables used in `rtl`.""" - for d in rtl.rtl: - inst = d.expr.inst - - # Get the Var corresponding to the controlling type variable. - ctrl_var = None # type: Var - if inst.is_polymorphic: - if inst.use_typevar_operand: - # Should this be an assertion instead? - # Should all value operands be required to be Vars? - arg = d.expr.args[inst.format.typevar_operand] - if isinstance(arg, Var): - ctrl_var = arg - else: - ctrl_var = d.defs[inst.value_results[0]] - - # Reconcile arguments with the requirements of `inst`. - for opnum in inst.value_opnums: - inst_tv = inst.ins[opnum].typevar - v = d.expr.args[opnum] - if isinstance(v, Var): - v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var) - - # Reconcile results with the requirements of `inst`. - for resnum in inst.value_results: - inst_tv = inst.outs[resnum].typevar - v = d.defs[resnum] - v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var) - - def _collect_typevars(self): - # type: () -> None - """ - Collect a list of variables whose type can be used to infer the types - of all expressions. - - This should be called after `_infer_types()` above has computed type - variables for all the used vars. - """ - fvars = list(v for v in self.inputs if v.has_free_typevar()) - fvars += list(v for v in self.defs if v.has_free_typevar()) - self.free_typevars = fvars - - # When substituting a pattern, we know the types of all variables that - # appear on the source side: inut, output, and intermediate values. - # However, temporary values which appear only on the destination side - # must have their type computed somehow. - # - # Some variables have a fixed type which appears as a type variable - # with a singleton_type field set. That's allowed for temps too. - for v in fvars: - if v.is_temp() and not v.typevar.singleton_type(): - raise AssertionError( - "Cannot determine type of temp '{}' in xform:\n{}" - .format(v, self)) - class XFormGroup(object): """