Add fix for #114 (#115)

* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow

* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
d1m0
2017-07-12 08:51:55 -07:00
committed by Jakob Stoklund Olesen
parent 962c945a3c
commit 7c438f866c
8 changed files with 471 additions and 132 deletions

View File

@@ -8,7 +8,7 @@ from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable, List # noqa
from typing import Iterable, List, Any # noqa
from typing import cast
from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa
@@ -25,9 +25,72 @@ class TypeConstraint(object):
"""
Base class for all runtime-emittable type constraints.
"""
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
"""
Translate any TypeVars in the constraint according to the map or
TypeEnv m
"""
def translate_one(a):
# type: (Any) -> Any
if (isinstance(a, TypeVar)):
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
return a
res = None # type: TypeConstraint
res = self.__class__(*tuple(map(translate_one, self._args())))
return res
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, self.__class__)):
return False
assert isinstance(other, TypeConstraint) # help MyPy figure out other
return self._args() == other._args()
def is_concrete(self):
# type: () -> bool
"""
Return true iff all typevars in the constraint are singletons.
"""
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
def __hash__(self):
# type: () -> int
return hash(self._args())
def _args(self):
# type: () -> Tuple[Any,...]
"""
Return a tuple with the exact arguments passed to __init__ to create
this object.
"""
assert False, "Abstract"
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
assert False, "Abstract"
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert False, "Abstract"
def __repr__(self):
# type: () -> str
return (self.__class__.__name__ + '(' +
', '.join(map(str, self._args())) + ')')
class ConstrainTVsEqual(TypeConstraint):
class TypesEqual(TypeConstraint):
"""
Constraint specifying that two derived type vars must have the same runtime
type.
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
return self.tv1 == self.tv2 or \
(self.tv1.singleton_type() is not None and
self.tv2.singleton_type() is not None)
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
else:
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVsEqual)):
return False
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
def __hash__(self):
# type: () -> int
return hash((self.tv1, self.tv2))
""" See TypeConstraint.is_trivial() """
return self.tv1 == self.tv2 or self.is_concrete()
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv1.singleton_type() is not None and \
self.tv2.singleton_type() is not None
""" See TypeConstraint.eval() """
assert self.is_concrete()
return self.tv1.singleton_type() == self.tv2.singleton_type()
class ConstrainTVInTypeset(TypeConstraint):
class InTypeset(TypeConstraint):
"""
Constraint specifying that a type var must belong to some typeset.
"""
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
self.tv = tv
self.ts = ts
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv, self.ts)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
""" See TypeConstraint.is_trivial() """
tv_ts = self.tv.get_typeset().copy()
# Trivially True
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
if (tv_ts.size() == 0):
return True
return False
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVInTypeset(m[self.tv], self.ts)
else:
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVInTypeset)):
return False
return (self.tv, self.ts) == (other.tv, other.ts)
def __hash__(self):
# type: () -> int
return hash((self.tv, self.ts))
return self.is_concrete()
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv.singleton_type() is not None
""" See TypeConstraint.eval() """
assert self.is_concrete()
return self.tv.get_typeset().issubset(self.ts)
class WiderOrEq(TypeConstraint):
"""
Constraint specifying that a type var tv1 must be wider than or equal to
type var tv2 at runtime. This requires that:
1) They have the same number of lanes
2) In a lane tv1 has at least as many bits as tv2.
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
self.tv1 = tv1
self.tv2 = tv2
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
""" See TypeConstraint.is_trivial() """
# Trivially true
if (self.tv1 == self.tv2):
return True
ts1 = self.tv1.get_typeset()
ts2 = self.tv2.get_typeset()
def set_wider_or_equal(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
# Trivially True
if set_wider_or_equal(ts1.ints, ts2.ints) and\
set_wider_or_equal(ts1.floats, ts2.floats) and\
set_wider_or_equal(ts1.bools, ts2.bools):
return True
def set_narrower(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
# Trivially False
if set_narrower(ts1.ints, ts2.ints) and\
set_narrower(ts1.floats, ts2.floats) and\
set_narrower(ts1.bools, ts2.bools):
return True
# Trivially False
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
return True
return self.is_concrete()
def eval(self):
# type: () -> bool
""" See TypeConstraint.eval() """
assert self.is_concrete()
typ1 = self.tv1.singleton_type()
typ2 = self.tv2.singleton_type()
return typ1.wider_or_equal(typ2)
class TypeEnv(object):
"""
Class encapsulating the neccessary book keeping for type inference.
@@ -204,12 +285,11 @@ class TypeEnv(object):
self.type_map[tv1] = tv2
def add_constraint(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
def add_constraint(self, constr):
# type: (TypeConstraint) -> None
"""
Add a new equivalence constraint between tv1 and tv2
Add a new constraint
"""
constr = ConstrainTVsEqual(tv1, tv2)
if (constr not in self.constraints):
self.constraints.append(constr)
@@ -261,6 +341,7 @@ class TypeEnv(object):
Get the free typevars in the current type env.
"""
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
# Filter out None here due to singleton type vars
return sorted(filter(lambda x: x is not None, tvs),
key=lambda x: x.name)
@@ -326,17 +407,18 @@ class TypeEnv(object):
new_constraints = [] # type: List[TypeConstraint]
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
constr = constr.translate(self)
if constr.is_trivial() or constr in new_constraints:
continue
# Sanity: translated constraints should refer to only real vars
assert constr.tv1.free_typevar() in vars_tvs and\
constr.tv2.free_typevar() in vars_tvs
for arg in constr._args():
if (not isinstance(arg, TypeVar)):
continue
arg_free_tv = arg.free_typevar()
assert arg_free_tv is None or arg_free_tv in vars_tvs
new_constraints.append(constr)
@@ -372,9 +454,6 @@ class TypeEnv(object):
# Check if constraints are satisfied for this typing
failed = None
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
concrete_constr = constr.translate(m)
if not concrete_constr.eval():
failed = concrete_constr
@@ -401,22 +480,27 @@ class TypeEnv(object):
# Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
for (k, v) in self.type_map.items():
# Add all intermediate TVs appearing in edges
nodes.add(k)
nodes.add(v)
edges.add((k, v, "dotted", None))
edges.add((k, v, "dotted", "forward", None))
while (v.is_derived):
nodes.add(v.base)
edges.add((v, v.base, "solid", v.derived_func))
edges.add((v, v.base, "solid", "forward", v.derived_func))
v = v.base
for constr in self.constraints:
assert isinstance(constr, ConstrainTVsEqual)
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", None))
if isinstance(constr, TypesEqual):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
elif isinstance(constr, WiderOrEq):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
else:
assert False, "Can't display constraint {}".format(constr)
root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived])
@@ -428,17 +512,12 @@ class TypeEnv(object):
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
r += ";\n"
for (n1, n2, style, elabel) in edges:
e = label(n1)
if style == "dashed":
e += '--'
else:
e += '->'
e += label(n2)
e += "[style={}".format(style)
for (n1, n2, style, direction, elabel) in edges:
e = label(n1) + "->" + label(n2)
e += "[style={},dir={}".format(style, direction)
if elabel is not None:
e += ",label={}".format(elabel)
e += ",label=\"{}\"".format(elabel)
e += "];\n"
r += e
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
typ.add_constraint(tv1, tv2)
typ.add_constraint(TypesEqual(tv1, tv2))
return typ
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
typ = get_type_env(typ_or_err)
# Add any instruction specific constraints
for constr in inst.constraints:
typ.add_constraint(constr.translate(m))
return typ