Cleanup typos; Remove SAMEAS; More descriptive rank comments; Introduce explicit sorting in free_typevars() (#111)
As per the comment in TypeEnv.normalize_tv about cancellation, whenever we create a TypeVar we must assert that there is no under/overflow. To make sure this always happen move the safety checks to TypeVar.derived() from the other helper methods
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
b7917fe404
commit
01abbcbebe
@@ -62,7 +62,7 @@ def agree(me, other):
|
||||
if m[me[tv]] != other[tv]:
|
||||
return False
|
||||
|
||||
# Tranlsate our constraints using m, and sort
|
||||
# Translate our constraints using m, and sort
|
||||
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
|
||||
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
|
||||
|
||||
@@ -76,7 +76,7 @@ def agree(me, other):
|
||||
def check_typing(got_or_err, expected, symtab=None):
|
||||
# type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa
|
||||
"""
|
||||
Check that a the typying we received (got_or_err) complies with the
|
||||
Check that a the typing we received (got_or_err) complies with the
|
||||
expected typing (expected). If symtab is specified, substitute the Vars in
|
||||
expected using symtab first (used when checking type inference on XForms)
|
||||
"""
|
||||
@@ -409,7 +409,7 @@ class TestXForm(TypeCheckingBaseTest):
|
||||
|
||||
# If there are no free_typevars, this is a non-polymorphic pattern.
|
||||
# There should be only one possible concrete typing.
|
||||
if (len(xform.free_typevars) == 0):
|
||||
if (len(xform.ti.free_typevars()) == 0):
|
||||
assert len(concrete_typings_list) == 1
|
||||
continue
|
||||
|
||||
@@ -423,7 +423,7 @@ class TestXForm(TypeCheckingBaseTest):
|
||||
theoretical_num_typings =\
|
||||
reduce(lambda x, y: x*y,
|
||||
[tv.get_typeset().size()
|
||||
for tv in xform.free_typevars], 1)
|
||||
for tv in xform.ti.free_typevars()], 1)
|
||||
assert len(concrete_typings_list) < theoretical_num_typings
|
||||
|
||||
# Check the validity of each individual concrete typing against the
|
||||
|
||||
@@ -127,7 +127,6 @@ class TestTypeSet(TestCase):
|
||||
|
||||
def test_preimage(self):
|
||||
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
|
||||
self.assertEqual(t, t.preimage(TypeVar.SAMEAS))
|
||||
|
||||
# LANEOF
|
||||
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
|
||||
@@ -217,12 +216,11 @@ class TestTypeVar(TestCase):
|
||||
self.assertEqual(len(x.type_set.bools), 0)
|
||||
|
||||
def test_stress_constrain_types(self):
|
||||
# Get all 49 possible derived vars of length 2. Since we have SAMEAS
|
||||
# this includes singly derived and non-derived vars
|
||||
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
|
||||
# Get all 43 possible derived vars of length up to 2
|
||||
funcs = [TypeVar.LANEOF,
|
||||
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
|
||||
TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
|
||||
v = list(product(*[funcs, funcs]))
|
||||
v = [()] + [(x,) for x in funcs] + list(product(*[funcs, funcs]))
|
||||
|
||||
# For each pair of derived variables
|
||||
for (i1, i2) in product(v, v):
|
||||
|
||||
@@ -29,11 +29,19 @@ class TypeEnv(object):
|
||||
:attribute constraints: a list of accumulated constraints - tuples
|
||||
(tv1, tv2)) where tv1 and tv2 are equal
|
||||
:attribute ranks: dictionary recording the (optional) ranks for tvs.
|
||||
tvs corresponding to real variables have explicitly
|
||||
specified ranks.
|
||||
'rank' is a partial ordering on TVs based on their
|
||||
origin. See comments in rank() and register().
|
||||
:attribute vars: a set containing all known Vars
|
||||
:attribute idx: counter used to get fresh ids
|
||||
"""
|
||||
|
||||
RANK_DERIVED = 5
|
||||
RANK_INPUT = 4
|
||||
RANK_INTERMEDIATE = 3
|
||||
RANK_OUTPUT = 2
|
||||
RANK_TEMP = 1
|
||||
RANK_INTERNAL = 0
|
||||
|
||||
def __init__(self, arg=None):
|
||||
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
|
||||
self.ranks = {} # type: Dict[TypeVar, int]
|
||||
@@ -104,9 +112,10 @@ class TypeEnv(object):
|
||||
Get the rank of tv in the partial order. TVs directly associated with a
|
||||
Var get their rank from the Var (see register()).
|
||||
Internally generated non-derived TVs implicitly get the lowest rank (0)
|
||||
Internal derived variables get the highest rank.
|
||||
Derived variables get the highest rank.
|
||||
"""
|
||||
default_rank = 5 if tv.is_derived else 0
|
||||
default_rank = TypeEnv.RANK_DERIVED if tv.is_derived else\
|
||||
TypeEnv.RANK_INTERNAL
|
||||
return self.ranks.get(tv, default_rank)
|
||||
|
||||
def register(self, v):
|
||||
@@ -118,25 +127,26 @@ class TypeEnv(object):
|
||||
self.vars.add(v)
|
||||
|
||||
if v.is_input():
|
||||
r = 4
|
||||
r = TypeEnv.RANK_INPUT
|
||||
elif v.is_intermediate():
|
||||
r = 3
|
||||
r = TypeEnv.RANK_INTERMEDIATE
|
||||
elif v.is_output():
|
||||
r = 2
|
||||
r = TypeEnv.RANK_OUTPUT
|
||||
else:
|
||||
assert(v.is_temp())
|
||||
r = 1
|
||||
r = TypeEnv.RANK_TEMP
|
||||
|
||||
self.ranks[v.get_typevar()] = r
|
||||
|
||||
def free_typevars(self):
|
||||
# type: () -> Set[TypeVar]
|
||||
# type: () -> List[TypeVar]
|
||||
"""
|
||||
Get the free typevars in the current type env.
|
||||
"""
|
||||
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
||||
# Filter out None here due to singleton type vars
|
||||
return set(filter(lambda x: x is not None, tvs))
|
||||
return sorted(filter(lambda x: x is not None, tvs),
|
||||
key=lambda x: x.name)
|
||||
|
||||
def normalize(self):
|
||||
# type: () -> None
|
||||
@@ -178,7 +188,7 @@ class TypeEnv(object):
|
||||
s.add(a)
|
||||
children[b] = s
|
||||
|
||||
for r in list(self.free_typevars()):
|
||||
for r in self.free_typevars():
|
||||
while (r not in source_tvs and r in children and
|
||||
len(children[r]) == 1):
|
||||
child = list(children[r])[0]
|
||||
@@ -359,9 +369,6 @@ def normalize_tv(tv):
|
||||
# type: (TypeVar) -> TypeVar
|
||||
"""
|
||||
Normalize a (potentially derived) TV using the following rules:
|
||||
- collapse SAMEAS
|
||||
SAMEAS(base) -> base
|
||||
|
||||
- vector and width derived functions commute
|
||||
{HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
|
||||
{HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
|
||||
@@ -378,10 +385,6 @@ def normalize_tv(tv):
|
||||
|
||||
df = tv.derived_func
|
||||
|
||||
# Collapse SAMEAS edges
|
||||
if (df == TypeVar.SAMEAS):
|
||||
return normalize_tv(tv.base)
|
||||
|
||||
if (tv.base.is_derived):
|
||||
base_df = tv.base.derived_func
|
||||
|
||||
@@ -393,8 +396,9 @@ def normalize_tv(tv):
|
||||
TypeVar.derived(tv.base.base, df), base_df))
|
||||
|
||||
# Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
|
||||
# cancel each other. TODO: Does this cancellation hide type
|
||||
# overflow/underflow?
|
||||
# cancel each other. Note: This doesn't hide any over/underflows,
|
||||
# since we 1) assert the safety of each TV in the chain upon its
|
||||
# creation, and 2) the base typeset is only allowed to shrink.
|
||||
|
||||
if (df, base_df) in \
|
||||
[(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
|
||||
|
||||
@@ -350,9 +350,7 @@ class TypeSet(object):
|
||||
"""
|
||||
Return the image of self across the derived function func
|
||||
"""
|
||||
if (func == TypeVar.SAMEAS):
|
||||
return self
|
||||
elif (func == TypeVar.LANEOF):
|
||||
if (func == TypeVar.LANEOF):
|
||||
return self.lane_of()
|
||||
elif (func == TypeVar.ASBOOL):
|
||||
return self.as_bool()
|
||||
@@ -376,9 +374,7 @@ class TypeSet(object):
|
||||
if (self.size() == 0):
|
||||
return self
|
||||
|
||||
if (func == TypeVar.SAMEAS):
|
||||
return self
|
||||
elif (func == TypeVar.LANEOF):
|
||||
if (func == TypeVar.LANEOF):
|
||||
new = self.copy()
|
||||
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
|
||||
return new
|
||||
@@ -388,6 +384,9 @@ class TypeSet(object):
|
||||
if 1 not in self.bools:
|
||||
new.ints = self.bools.difference(set([1]))
|
||||
new.floats = self.bools.intersection(set([32, 64]))
|
||||
# If b1 is not in our typeset, than lanes=1 cannot be in the
|
||||
# pre-image, as as_bool() of scalars is always b1.
|
||||
new.lanes = self.lanes.difference(set([1]))
|
||||
else:
|
||||
new.ints = set([2**x for x in range(3, 7)])
|
||||
new.floats = set([32, 64])
|
||||
@@ -553,7 +552,6 @@ class TypeVar(object):
|
||||
# The names here must match the method names on `ir::types::Type`.
|
||||
# The camel_case of the names must match `enum OperandConstraint` in
|
||||
# `instructions.rs`.
|
||||
SAMEAS = 'same_as'
|
||||
LANEOF = 'lane_of'
|
||||
ASBOOL = 'as_bool'
|
||||
HALFWIDTH = 'half_width'
|
||||
@@ -565,7 +563,6 @@ class TypeVar(object):
|
||||
def is_bijection(func):
|
||||
# type: (str) -> bool
|
||||
return func in [
|
||||
TypeVar.SAMEAS,
|
||||
TypeVar.HALFWIDTH,
|
||||
TypeVar.DOUBLEWIDTH,
|
||||
TypeVar.HALFVECTOR,
|
||||
@@ -575,7 +572,6 @@ class TypeVar(object):
|
||||
def inverse_func(func):
|
||||
# type: (str) -> str
|
||||
return {
|
||||
TypeVar.SAMEAS: TypeVar.SAMEAS,
|
||||
TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
|
||||
TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
|
||||
TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
|
||||
@@ -586,6 +582,31 @@ class TypeVar(object):
|
||||
def derived(base, derived_func):
|
||||
# type: (TypeVar, str) -> TypeVar
|
||||
"""Create a type variable that is a function of another."""
|
||||
|
||||
# Safety checks to avoid over/underflows.
|
||||
ts = base.get_typeset()
|
||||
|
||||
if derived_func == TypeVar.HALFWIDTH:
|
||||
if len(ts.ints) > 0:
|
||||
assert min(ts.ints) > 8, "Can't halve all integer types"
|
||||
if len(ts.floats) > 0:
|
||||
assert min(ts.floats) > 32, "Can't halve all float types"
|
||||
if len(ts.bools) > 0:
|
||||
assert min(ts.bools) > 8, "Can't halve all boolean types"
|
||||
elif derived_func == TypeVar.DOUBLEWIDTH:
|
||||
if len(ts.ints) > 0:
|
||||
assert max(ts.ints) < MAX_BITS,\
|
||||
"Can't double all integer types."
|
||||
if len(ts.floats) > 0:
|
||||
assert max(ts.floats) < MAX_BITS,\
|
||||
"Can't double all float types."
|
||||
if len(ts.bools) > 0:
|
||||
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
|
||||
elif derived_func == TypeVar.HALFVECTOR:
|
||||
assert min(ts.lanes) > 1, "Can't halve a scalar type"
|
||||
elif derived_func == TypeVar.DOUBLEVECTOR:
|
||||
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
|
||||
|
||||
return TypeVar(None, None, base=base, derived_func=derived_func)
|
||||
|
||||
@staticmethod
|
||||
@@ -596,27 +617,6 @@ class TypeVar(object):
|
||||
tv.type_set = ts
|
||||
return tv
|
||||
|
||||
def change_to_derived(self, base, derived_func):
|
||||
# type: (TypeVar, str) -> None
|
||||
"""Change this type variable into a derived one."""
|
||||
self.type_set = None
|
||||
self.is_derived = True
|
||||
self.base = base
|
||||
self.derived_func = derived_func
|
||||
|
||||
def strip_sameas(self):
|
||||
# type: () -> TypeVar
|
||||
"""
|
||||
Strip any `SAMEAS` functions from this typevar.
|
||||
|
||||
Also rewrite any `SAMEAS` functions nested under this typevar.
|
||||
"""
|
||||
if self.is_derived:
|
||||
self.base = self.base.strip_sameas()
|
||||
if self.derived_func == self.SAMEAS:
|
||||
return self.base
|
||||
return self
|
||||
|
||||
def lane_of(self):
|
||||
# type: () -> TypeVar
|
||||
"""
|
||||
@@ -642,14 +642,6 @@ class TypeVar(object):
|
||||
Return a derived type variable that has the same number of vector lanes
|
||||
as this one, but the lanes are half the width.
|
||||
"""
|
||||
ts = self.get_typeset()
|
||||
if len(ts.ints) > 0:
|
||||
assert min(ts.ints) > 8, "Can't halve all integer types"
|
||||
if len(ts.floats) > 0:
|
||||
assert min(ts.floats) > 32, "Can't halve all float types"
|
||||
if len(ts.bools) > 0:
|
||||
assert min(ts.bools) > 8, "Can't halve all boolean types"
|
||||
|
||||
return TypeVar.derived(self, self.HALFWIDTH)
|
||||
|
||||
def double_width(self):
|
||||
@@ -658,14 +650,6 @@ class TypeVar(object):
|
||||
Return a derived type variable that has the same number of vector lanes
|
||||
as this one, but the lanes are double the width.
|
||||
"""
|
||||
ts = self.get_typeset()
|
||||
if len(ts.ints) > 0:
|
||||
assert max(ts.ints) < MAX_BITS, "Can't double all integer types."
|
||||
if len(ts.floats) > 0:
|
||||
assert max(ts.floats) < MAX_BITS, "Can't double all float types."
|
||||
if len(ts.bools) > 0:
|
||||
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
|
||||
|
||||
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
||||
|
||||
def half_vector(self):
|
||||
@@ -674,9 +658,6 @@ class TypeVar(object):
|
||||
Return a derived type variable that has half the number of vector lanes
|
||||
as this one, with the same lane type.
|
||||
"""
|
||||
ts = self.get_typeset()
|
||||
assert min(ts.lanes) > 1, "Can't halve a scalar type"
|
||||
|
||||
return TypeVar.derived(self, self.HALFVECTOR)
|
||||
|
||||
def double_vector(self):
|
||||
@@ -685,9 +666,6 @@ class TypeVar(object):
|
||||
Return a derived type variable that has twice the number of vector
|
||||
lanes as this one, with the same lane type.
|
||||
"""
|
||||
ts = self.get_typeset()
|
||||
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
|
||||
|
||||
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
||||
|
||||
def singleton_type(self):
|
||||
@@ -744,15 +722,11 @@ class TypeVar(object):
|
||||
"""
|
||||
Constrain the range of types this variable can assume to a subset of
|
||||
those `other` can assume.
|
||||
|
||||
If this is a SAMEAS-derived type variable, constrain the base instead.
|
||||
"""
|
||||
a = self.strip_sameas()
|
||||
b = other.strip_sameas()
|
||||
if a is b:
|
||||
if self is other:
|
||||
return
|
||||
|
||||
a.constrain_types_by_ts(b.get_typeset())
|
||||
self.constrain_types_by_ts(other.get_typeset())
|
||||
|
||||
def get_typeset(self):
|
||||
# type: () -> TypeSet
|
||||
|
||||
@@ -106,14 +106,14 @@ class XForm(object):
|
||||
|
||||
# Sanity: The set of inferred free typevars should be a subset of the
|
||||
# TVs corresponding to Vars appearing in src
|
||||
self.free_typevars = self.ti.free_typevars()
|
||||
free_typevars = set(self.ti.free_typevars())
|
||||
src_vars = set(self.inputs).union(
|
||||
[x for x in self.defs if not x.is_temp()])
|
||||
src_tvs = set([v.get_typevar() for v in src_vars])
|
||||
if (not self.free_typevars.issubset(src_tvs)):
|
||||
if (not free_typevars.issubset(src_tvs)):
|
||||
raise AssertionError(
|
||||
"Some free vars don't appear in src - {}"
|
||||
.format(self.free_typevars.difference(src_tvs)))
|
||||
.format(free_typevars.difference(src_tvs)))
|
||||
|
||||
# Update the type vars for each Var to their inferred values
|
||||
for v in self.inputs + self.defs:
|
||||
|
||||
Reference in New Issue
Block a user