Add fix for #114 (#115)

* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow

* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
d1m0
2017-07-12 08:51:55 -07:00
committed by Jakob Stoklund Olesen
parent de5501bc47
commit a9147ebd30
8 changed files with 471 additions and 132 deletions

View File

@@ -10,8 +10,10 @@ try:
if TYPE_CHECKING:
from .ast import Expr, Apply # noqa
from .typevar import TypeVar # noqa
from .ti import TypeConstraint # noqa
# List of operands for ins/outs:
OpList = Union[Sequence[Operand], Operand]
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
except ImportError:
pass
@@ -80,6 +82,7 @@ class Instruction(object):
operands and other operand kinds.
:param outs: Tuple of output operands. The output operands must be SSA
values or `variable_args`.
:param constraints: Tuple of instruction-specific TypeConstraints.
:param is_terminator: This is a terminator instruction.
:param is_branch: This is a branch instruction.
:param is_call: This is a call instruction.
@@ -102,13 +105,14 @@ class Instruction(object):
'can_trap': 'Can this instruction cause a trap?',
}
def __init__(self, name, doc, ins=(), outs=(), **kwargs):
# type: (str, str, OpList, OpList, **Any) -> None # noqa
def __init__(self, name, doc, ins=(), outs=(), constraints=(), **kwargs):
# type: (str, str, OpList, OpList, ConstrList, **Any) -> None
self.name = name
self.camel_name = camel_case(name)
self.__doc__ = doc
self.ins = self._to_operand_tuple(ins)
self.outs = self._to_operand_tuple(outs)
self.constraints = self._to_constraint_tuple(constraints)
self.format = InstructionFormat.lookup(self.ins, self.outs)
# Opcode number, assigned by gen_instr.py.
@@ -268,6 +272,23 @@ class Instruction(object):
assert isinstance(op, Operand)
return x
@staticmethod
def _to_constraint_tuple(x):
# type: (ConstrList) -> Tuple[TypeConstraint, ...]
"""
Allow a single TypeConstraint instance instead of the awkward singleton
tuple syntax.
"""
# import placed here to avoid circular dependency
from .ti import TypeConstraint # noqa
if isinstance(x, TypeConstraint):
x = (x,)
else:
x = tuple(x)
for op in x:
assert isinstance(op, TypeConstraint)
return x
def bind(self, *args):
# type: (*ValueType) -> BoundInstruction
"""

View File

@@ -1,13 +1,14 @@
from __future__ import absolute_import
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
b1, icmp, iadd_cout, iadd_cin, uextend, ireduce
b1, icmp, iadd_cout, iadd_cin, uextend, sextend, ireduce, fpromote, \
fdemote
from base.legalize import narrow, expand
from base.immediates import intcc
from base.types import i32, i8
from .typevar import TypeVar
from .ast import Var, Def
from .xform import Rtl, XForm
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
from .ti import ti_rtl, subst, TypeEnv, get_type_env, TypesEqual, WiderOrEq
from unittest import TestCase
from functools import reduce
@@ -52,9 +53,10 @@ def agree(me, other):
# Translate our constraints using m, and sort
me_equiv_constr = sorted([constr.translate(m)
for constr in me.constraints])
for constr in me.constraints], key=repr)
# Sort other's constraints
other_equiv_constr = sorted(other.constraints)
other_equiv_constr = sorted([constr.translate(other)
for constr in other.constraints], key=repr)
return me_equiv_constr == other_equiv_constr
@@ -78,7 +80,7 @@ def check_typing(got_or_err, expected, symtab=None):
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
# Rewrite the TVs in the input constraints to their XForm internal
# versions
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c]
c = [constr.translate(subst_m) for constr in c]
else:
# If no symtab, just convert m from Var->TypeVar map to a
# TypeVar->TypeVar map
@@ -209,7 +211,7 @@ class TestRTL(TypeCheckingBaseTest):
self.v3: txn,
self.v4: txn,
self.v5: txn,
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
}, [TypesEqual(ixn.as_bool(), txn.as_bool())]))
def test_vselect_vsplits(self):
# type: () -> None
@@ -319,6 +321,90 @@ class TestRTL(TypeCheckingBaseTest):
"Error: empty type created when unifying " +
"`typeof_v4` and `typeof_v5`")
def test_extend_reduce(self):
# type: () -> None
r = Rtl(
self.v1 << uextend(self.v0),
self.v2 << ireduce(self.v1),
self.v3 << sextend(self.v2),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
typing = typing.extract()
itype0 = TypeVar("t", "", ints=True, simd=(1, 256))
itype1 = TypeVar("t1", "", ints=True, simd=(1, 256))
itype2 = TypeVar("t2", "", ints=True, simd=(1, 256))
itype3 = TypeVar("t3", "", ints=True, simd=(1, 256))
check_typing(typing, ({
self.v0: itype0,
self.v1: itype1,
self.v2: itype2,
self.v3: itype3,
}, [WiderOrEq(itype1, itype0),
WiderOrEq(itype1, itype2),
WiderOrEq(itype3, itype2)]))
def test_extend_reduce_enumeration(self):
# type: () -> None
for op in (uextend, sextend, ireduce):
r = Rtl(
self.v1 << op(self.v0),
)
ti = TypeEnv()
typing = ti_rtl(r, ti).extract()
# The number of possible typings is 9 * (3+ 2*2 + 3) = 90
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
assert (len(l) == len(set(l)) and len(l) == 90)
for (tv0, tv1) in l:
typ0, typ1 = (tv0.singleton_type(), tv1.singleton_type())
if (op == ireduce):
assert typ0.wider_or_equal(typ1)
else:
assert typ1.wider_or_equal(typ0)
def test_fpromote_fdemote(self):
# type: () -> None
r = Rtl(
self.v1 << fpromote(self.v0),
self.v2 << fdemote(self.v1),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
typing = typing.extract()
ftype0 = TypeVar("t", "", floats=True, simd=(1, 256))
ftype1 = TypeVar("t1", "", floats=True, simd=(1, 256))
ftype2 = TypeVar("t2", "", floats=True, simd=(1, 256))
check_typing(typing, ({
self.v0: ftype0,
self.v1: ftype1,
self.v2: ftype2,
}, [WiderOrEq(ftype1, ftype0),
WiderOrEq(ftype1, ftype2)]))
def test_fpromote_fdemote_enumeration(self):
# type: () -> None
for op in (fpromote, fdemote):
r = Rtl(
self.v1 << op(self.v0),
)
ti = TypeEnv()
typing = ti_rtl(r, ti).extract()
# The number of possible typings is 9*(2 + 1) = 27
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
assert (len(l) == len(set(l)) and len(l) == 27)
for (tv0, tv1) in l:
(typ0, typ1) = (tv0.singleton_type(), tv1.singleton_type())
if (op == fdemote):
assert typ0.wider_or_equal(typ1)
else:
assert typ1.wider_or_equal(typ0)
class TestXForm(TypeCheckingBaseTest):
def test_iadd_cout(self):
@@ -453,7 +539,7 @@ class TestXForm(TypeCheckingBaseTest):
self.v3: i32t,
self.v4: i32t,
self.v5: i32t,
}, []), x.symtab)
}, [WiderOrEq(i32t, itype)]), x.symtab)
def test_bound_inst_inference1(self):
# Second example taken from issue #26
@@ -477,7 +563,7 @@ class TestXForm(TypeCheckingBaseTest):
self.v3: i32t,
self.v4: i32t,
self.v5: i32t,
}, []), x.symtab)
}, [WiderOrEq(i32t, itype)]), x.symtab)
def test_fully_bound_inst_inference(self):
# Second example taken from issue #26 with complete bounds
@@ -494,6 +580,7 @@ class TestXForm(TypeCheckingBaseTest):
i8t = TypeVar.singleton(i8)
i32t = TypeVar.singleton(i32)
# Note no constraints here since they are all trivial
check_typing(x.ti, ({
self.v0: i8t,
self.v1: i8t,

View File

@@ -8,7 +8,7 @@ from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable, List # noqa
from typing import Iterable, List, Any # noqa
from typing import cast
from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa
@@ -25,9 +25,72 @@ class TypeConstraint(object):
"""
Base class for all runtime-emittable type constraints.
"""
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
"""
Translate any TypeVars in the constraint according to the map or
TypeEnv m
"""
def translate_one(a):
# type: (Any) -> Any
if (isinstance(a, TypeVar)):
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
return a
res = None # type: TypeConstraint
res = self.__class__(*tuple(map(translate_one, self._args())))
return res
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, self.__class__)):
return False
assert isinstance(other, TypeConstraint) # help MyPy figure out other
return self._args() == other._args()
def is_concrete(self):
# type: () -> bool
"""
Return true iff all typevars in the constraint are singletons.
"""
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
def __hash__(self):
# type: () -> int
return hash(self._args())
def _args(self):
# type: () -> Tuple[Any,...]
"""
Return a tuple with the exact arguments passed to __init__ to create
this object.
"""
assert False, "Abstract"
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
assert False, "Abstract"
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert False, "Abstract"
def __repr__(self):
# type: () -> str
return (self.__class__.__name__ + '(' +
', '.join(map(str, self._args())) + ')')
class ConstrainTVsEqual(TypeConstraint):
class TypesEqual(TypeConstraint):
"""
Constraint specifying that two derived type vars must have the same runtime
type.
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
return self.tv1 == self.tv2 or \
(self.tv1.singleton_type() is not None and
self.tv2.singleton_type() is not None)
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
else:
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVsEqual)):
return False
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
def __hash__(self):
# type: () -> int
return hash((self.tv1, self.tv2))
""" See TypeConstraint.is_trivial() """
return self.tv1 == self.tv2 or self.is_concrete()
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv1.singleton_type() is not None and \
self.tv2.singleton_type() is not None
""" See TypeConstraint.eval() """
assert self.is_concrete()
return self.tv1.singleton_type() == self.tv2.singleton_type()
class ConstrainTVInTypeset(TypeConstraint):
class InTypeset(TypeConstraint):
"""
Constraint specifying that a type var must belong to some typeset.
"""
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
self.tv = tv
self.ts = ts
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv, self.ts)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
""" See TypeConstraint.is_trivial() """
tv_ts = self.tv.get_typeset().copy()
# Trivially True
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
if (tv_ts.size() == 0):
return True
return False
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVInTypeset(m[self.tv], self.ts)
else:
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVInTypeset)):
return False
return (self.tv, self.ts) == (other.tv, other.ts)
def __hash__(self):
# type: () -> int
return hash((self.tv, self.ts))
return self.is_concrete()
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv.singleton_type() is not None
""" See TypeConstraint.eval() """
assert self.is_concrete()
return self.tv.get_typeset().issubset(self.ts)
class WiderOrEq(TypeConstraint):
"""
Constraint specifying that a type var tv1 must be wider than or equal to
type var tv2 at runtime. This requires that:
1) They have the same number of lanes
2) In a lane tv1 has at least as many bits as tv2.
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
self.tv1 = tv1
self.tv2 = tv2
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
""" See TypeConstraint.is_trivial() """
# Trivially true
if (self.tv1 == self.tv2):
return True
ts1 = self.tv1.get_typeset()
ts2 = self.tv2.get_typeset()
def set_wider_or_equal(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
# Trivially True
if set_wider_or_equal(ts1.ints, ts2.ints) and\
set_wider_or_equal(ts1.floats, ts2.floats) and\
set_wider_or_equal(ts1.bools, ts2.bools):
return True
def set_narrower(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
# Trivially False
if set_narrower(ts1.ints, ts2.ints) and\
set_narrower(ts1.floats, ts2.floats) and\
set_narrower(ts1.bools, ts2.bools):
return True
# Trivially False
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
return True
return self.is_concrete()
def eval(self):
# type: () -> bool
""" See TypeConstraint.eval() """
assert self.is_concrete()
typ1 = self.tv1.singleton_type()
typ2 = self.tv2.singleton_type()
return typ1.wider_or_equal(typ2)
class TypeEnv(object):
"""
Class encapsulating the neccessary book keeping for type inference.
@@ -204,12 +285,11 @@ class TypeEnv(object):
self.type_map[tv1] = tv2
def add_constraint(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
def add_constraint(self, constr):
# type: (TypeConstraint) -> None
"""
Add a new equivalence constraint between tv1 and tv2
Add a new constraint
"""
constr = ConstrainTVsEqual(tv1, tv2)
if (constr not in self.constraints):
self.constraints.append(constr)
@@ -261,6 +341,7 @@ class TypeEnv(object):
Get the free typevars in the current type env.
"""
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
# Filter out None here due to singleton type vars
return sorted(filter(lambda x: x is not None, tvs),
key=lambda x: x.name)
@@ -326,17 +407,18 @@ class TypeEnv(object):
new_constraints = [] # type: List[TypeConstraint]
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
constr = constr.translate(self)
if constr.is_trivial() or constr in new_constraints:
continue
# Sanity: translated constraints should refer to only real vars
assert constr.tv1.free_typevar() in vars_tvs and\
constr.tv2.free_typevar() in vars_tvs
for arg in constr._args():
if (not isinstance(arg, TypeVar)):
continue
arg_free_tv = arg.free_typevar()
assert arg_free_tv is None or arg_free_tv in vars_tvs
new_constraints.append(constr)
@@ -372,9 +454,6 @@ class TypeEnv(object):
# Check if constraints are satisfied for this typing
failed = None
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
concrete_constr = constr.translate(m)
if not concrete_constr.eval():
failed = concrete_constr
@@ -401,22 +480,27 @@ class TypeEnv(object):
# Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
for (k, v) in self.type_map.items():
# Add all intermediate TVs appearing in edges
nodes.add(k)
nodes.add(v)
edges.add((k, v, "dotted", None))
edges.add((k, v, "dotted", "forward", None))
while (v.is_derived):
nodes.add(v.base)
edges.add((v, v.base, "solid", v.derived_func))
edges.add((v, v.base, "solid", "forward", v.derived_func))
v = v.base
for constr in self.constraints:
assert isinstance(constr, ConstrainTVsEqual)
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", None))
if isinstance(constr, TypesEqual):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
elif isinstance(constr, WiderOrEq):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
else:
assert False, "Can't display constraint {}".format(constr)
root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived])
@@ -428,17 +512,12 @@ class TypeEnv(object):
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
r += ";\n"
for (n1, n2, style, elabel) in edges:
e = label(n1)
if style == "dashed":
e += '--'
else:
e += '->'
e += label(n2)
e += "[style={}".format(style)
for (n1, n2, style, direction, elabel) in edges:
e = label(n1) + "->" + label(n2)
e += "[style={},dir={}".format(style, direction)
if elabel is not None:
e += ",label={}".format(elabel)
e += ",label=\"{}\"".format(elabel)
e += "];\n"
r += e
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
typ.add_constraint(tv1, tv2)
typ.add_constraint(TypesEqual(tv1, tv2))
return typ
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
typ = get_type_env(typ_or_err)
# Add any instruction specific constraints
for constr in inst.constraints:
typ.add_constraint(constr.translate(m))
return typ

View File

@@ -49,6 +49,26 @@ class ValueType(object):
else:
raise AttributeError("No type named '{}'".format(name))
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
assert False, "Abstract"
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
assert False, "Abstract"
def wider_or_equal(self, other):
# type: (ValueType) -> bool
"""
Return true iff:
1. self and other have equal number of lanes
2. each lane in self has at least as many bits as a lane in other
"""
return (self.lane_count() == other.lane_count() and
self.lane_bits() >= other.lane_bits())
class ScalarType(ValueType):
"""
@@ -85,6 +105,11 @@ class ScalarType(ValueType):
self._vectors[lanes] = v
return v
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
return 1
class VectorType(ValueType):
"""
@@ -112,6 +137,16 @@ class VectorType(ValueType):
return ('VectorType(base={}, lanes={})'
.format(self.base.name, self.lanes))
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
return self.lanes
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.base.lane_bits()
class IntType(ScalarType):
"""A concrete scalar integer type."""
@@ -138,6 +173,11 @@ class IntType(ScalarType):
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits
class FloatType(ScalarType):
"""A concrete scalar floating point type."""
@@ -164,6 +204,11 @@ class FloatType(ScalarType):
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits
class BoolType(ScalarType):
"""A concrete scalar boolean type."""
@@ -189,3 +234,8 @@ class BoolType(ScalarType):
return cast(BoolType, typ)
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits