* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow
* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
962c945a3c
commit
7c438f866c
@@ -8,7 +8,7 @@ from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable, List # noqa
|
||||
from typing import Iterable, List, Any # noqa
|
||||
from typing import cast
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
@@ -25,9 +25,72 @@ class TypeConstraint(object):
|
||||
"""
|
||||
Base class for all runtime-emittable type constraints.
|
||||
"""
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map or
|
||||
TypeEnv m
|
||||
"""
|
||||
def translate_one(a):
|
||||
# type: (Any) -> Any
|
||||
if (isinstance(a, TypeVar)):
|
||||
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
|
||||
return a
|
||||
|
||||
res = None # type: TypeConstraint
|
||||
res = self.__class__(*tuple(map(translate_one, self._args())))
|
||||
return res
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, self.__class__)):
|
||||
return False
|
||||
|
||||
assert isinstance(other, TypeConstraint) # help MyPy figure out other
|
||||
return self._args() == other._args()
|
||||
|
||||
def is_concrete(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true iff all typevars in the constraint are singletons.
|
||||
"""
|
||||
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
|
||||
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash(self._args())
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
"""
|
||||
Return a tuple with the exact arguments passed to __init__ to create
|
||||
this object.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return (self.__class__.__name__ + '(' +
|
||||
', '.join(map(str, self._args())) + ')')
|
||||
|
||||
|
||||
class ConstrainTVsEqual(TypeConstraint):
|
||||
class TypesEqual(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that two derived type vars must have the same runtime
|
||||
type.
|
||||
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
|
||||
assert tv1.is_derived and tv2.is_derived
|
||||
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
return self.tv1 == self.tv2 or \
|
||||
(self.tv1.singleton_type() is not None and
|
||||
self.tv2.singleton_type() is not None)
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map m
|
||||
"""
|
||||
if isinstance(m, TypeEnv):
|
||||
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
|
||||
else:
|
||||
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, ConstrainTVsEqual)):
|
||||
return False
|
||||
|
||||
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash((self.tv1, self.tv2))
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
return self.tv1 == self.tv2 or self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert self.tv1.singleton_type() is not None and \
|
||||
self.tv2.singleton_type() is not None
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||
|
||||
|
||||
class ConstrainTVInTypeset(TypeConstraint):
|
||||
class InTypeset(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var must belong to some typeset.
|
||||
"""
|
||||
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
|
||||
self.tv = tv
|
||||
self.ts = ts
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv, self.ts)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
tv_ts = self.tv.get_typeset().copy()
|
||||
|
||||
# Trivially True
|
||||
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
|
||||
if (tv_ts.size() == 0):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map m
|
||||
"""
|
||||
if isinstance(m, TypeEnv):
|
||||
return ConstrainTVInTypeset(m[self.tv], self.ts)
|
||||
else:
|
||||
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, ConstrainTVInTypeset)):
|
||||
return False
|
||||
|
||||
return (self.tv, self.ts) == (other.tv, other.ts)
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash((self.tv, self.ts))
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert self.tv.singleton_type() is not None
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
return self.tv.get_typeset().issubset(self.ts)
|
||||
|
||||
|
||||
class WiderOrEq(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var tv1 must be wider than or equal to
|
||||
type var tv2 at runtime. This requires that:
|
||||
1) They have the same number of lanes
|
||||
2) In a lane tv1 has at least as many bits as tv2.
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
self.tv1 = tv1
|
||||
self.tv2 = tv2
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
# Trivially true
|
||||
if (self.tv1 == self.tv2):
|
||||
return True
|
||||
|
||||
ts1 = self.tv1.get_typeset()
|
||||
ts2 = self.tv2.get_typeset()
|
||||
|
||||
def set_wider_or_equal(s1, s2):
|
||||
# type: (Set[int], Set[int]) -> bool
|
||||
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
|
||||
|
||||
# Trivially True
|
||||
if set_wider_or_equal(ts1.ints, ts2.ints) and\
|
||||
set_wider_or_equal(ts1.floats, ts2.floats) and\
|
||||
set_wider_or_equal(ts1.bools, ts2.bools):
|
||||
return True
|
||||
|
||||
def set_narrower(s1, s2):
|
||||
# type: (Set[int], Set[int]) -> bool
|
||||
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
|
||||
|
||||
# Trivially False
|
||||
if set_narrower(ts1.ints, ts2.ints) and\
|
||||
set_narrower(ts1.floats, ts2.floats) and\
|
||||
set_narrower(ts1.bools, ts2.bools):
|
||||
return True
|
||||
|
||||
# Trivially False
|
||||
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
|
||||
return True
|
||||
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
typ1 = self.tv1.singleton_type()
|
||||
typ2 = self.tv2.singleton_type()
|
||||
|
||||
return typ1.wider_or_equal(typ2)
|
||||
|
||||
|
||||
class TypeEnv(object):
|
||||
"""
|
||||
Class encapsulating the neccessary book keeping for type inference.
|
||||
@@ -204,12 +285,11 @@ class TypeEnv(object):
|
||||
|
||||
self.type_map[tv1] = tv2
|
||||
|
||||
def add_constraint(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
def add_constraint(self, constr):
|
||||
# type: (TypeConstraint) -> None
|
||||
"""
|
||||
Add a new equivalence constraint between tv1 and tv2
|
||||
Add a new constraint
|
||||
"""
|
||||
constr = ConstrainTVsEqual(tv1, tv2)
|
||||
if (constr not in self.constraints):
|
||||
self.constraints.append(constr)
|
||||
|
||||
@@ -261,6 +341,7 @@ class TypeEnv(object):
|
||||
Get the free typevars in the current type env.
|
||||
"""
|
||||
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
||||
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
|
||||
# Filter out None here due to singleton type vars
|
||||
return sorted(filter(lambda x: x is not None, tvs),
|
||||
key=lambda x: x.name)
|
||||
@@ -326,17 +407,18 @@ class TypeEnv(object):
|
||||
|
||||
new_constraints = [] # type: List[TypeConstraint]
|
||||
for constr in self.constraints:
|
||||
# Currently typeinference only generates ConstrainTVsEqual
|
||||
# constraints
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
constr = constr.translate(self)
|
||||
|
||||
if constr.is_trivial() or constr in new_constraints:
|
||||
continue
|
||||
|
||||
# Sanity: translated constraints should refer to only real vars
|
||||
assert constr.tv1.free_typevar() in vars_tvs and\
|
||||
constr.tv2.free_typevar() in vars_tvs
|
||||
for arg in constr._args():
|
||||
if (not isinstance(arg, TypeVar)):
|
||||
continue
|
||||
|
||||
arg_free_tv = arg.free_typevar()
|
||||
assert arg_free_tv is None or arg_free_tv in vars_tvs
|
||||
|
||||
new_constraints.append(constr)
|
||||
|
||||
@@ -372,9 +454,6 @@ class TypeEnv(object):
|
||||
# Check if constraints are satisfied for this typing
|
||||
failed = None
|
||||
for constr in self.constraints:
|
||||
# Currently typeinference only generates ConstrainTVsEqual
|
||||
# constraints
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
concrete_constr = constr.translate(m)
|
||||
if not concrete_constr.eval():
|
||||
failed = concrete_constr
|
||||
@@ -401,22 +480,27 @@ class TypeEnv(object):
|
||||
# Add all registered TVs (as some of them may be singleton nodes not
|
||||
# appearing in the graph
|
||||
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
|
||||
|
||||
for (k, v) in self.type_map.items():
|
||||
# Add all intermediate TVs appearing in edges
|
||||
nodes.add(k)
|
||||
nodes.add(v)
|
||||
edges.add((k, v, "dotted", None))
|
||||
edges.add((k, v, "dotted", "forward", None))
|
||||
while (v.is_derived):
|
||||
nodes.add(v.base)
|
||||
edges.add((v, v.base, "solid", v.derived_func))
|
||||
edges.add((v, v.base, "solid", "forward", v.derived_func))
|
||||
v = v.base
|
||||
|
||||
for constr in self.constraints:
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", None))
|
||||
if isinstance(constr, TypesEqual):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
|
||||
elif isinstance(constr, WiderOrEq):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
|
||||
else:
|
||||
assert False, "Can't display constraint {}".format(constr)
|
||||
|
||||
root_nodes = set([x for x in nodes
|
||||
if x not in self.type_map and not x.is_derived])
|
||||
@@ -428,17 +512,12 @@ class TypeEnv(object):
|
||||
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
|
||||
r += ";\n"
|
||||
|
||||
for (n1, n2, style, elabel) in edges:
|
||||
e = label(n1)
|
||||
if style == "dashed":
|
||||
e += '--'
|
||||
else:
|
||||
e += '->'
|
||||
e += label(n2)
|
||||
e += "[style={}".format(style)
|
||||
for (n1, n2, style, direction, elabel) in edges:
|
||||
e = label(n1) + "->" + label(n2)
|
||||
e += "[style={},dir={}".format(style, direction)
|
||||
|
||||
if elabel is not None:
|
||||
e += ",label={}".format(elabel)
|
||||
e += ",label=\"{}\"".format(elabel)
|
||||
e += "];\n"
|
||||
|
||||
r += e
|
||||
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
|
||||
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
||||
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
||||
|
||||
typ.add_constraint(tv1, tv2)
|
||||
typ.add_constraint(TypesEqual(tv1, tv2))
|
||||
return typ
|
||||
|
||||
|
||||
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
# Add any instruction specific constraints
|
||||
for constr in inst.constraints:
|
||||
typ.add_constraint(constr.translate(m))
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user