Cleanup typos; Remove SAMEAS; More descriptive rank comments; Introduce explicit sorting in free_typevars() (#111)

As per the comment in TypeEnv.normalize_tv about cancellation, whenever we create a TypeVar we must assert that there is no under/overflow. To make sure this always happen move the safety checks to TypeVar.derived() from the other helper methods
This commit is contained in:
d1m0
2017-07-05 15:47:44 -07:00
committed by Jakob Stoklund Olesen
parent fe127ab3eb
commit 83e55525d6
5 changed files with 66 additions and 90 deletions

View File

@@ -62,7 +62,7 @@ def agree(me, other):
if m[me[tv]] != other[tv]:
return False
# Tranlsate our constraints using m, and sort
# Translate our constraints using m, and sort
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
@@ -76,7 +76,7 @@ def agree(me, other):
def check_typing(got_or_err, expected, symtab=None):
# type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa
"""
Check that a the typying we received (got_or_err) complies with the
Check that a the typing we received (got_or_err) complies with the
expected typing (expected). If symtab is specified, substitute the Vars in
expected using symtab first (used when checking type inference on XForms)
"""
@@ -409,7 +409,7 @@ class TestXForm(TypeCheckingBaseTest):
# If there are no free_typevars, this is a non-polymorphic pattern.
# There should be only one possible concrete typing.
if (len(xform.free_typevars) == 0):
if (len(xform.ti.free_typevars()) == 0):
assert len(concrete_typings_list) == 1
continue
@@ -423,7 +423,7 @@ class TestXForm(TypeCheckingBaseTest):
theoretical_num_typings =\
reduce(lambda x, y: x*y,
[tv.get_typeset().size()
for tv in xform.free_typevars], 1)
for tv in xform.ti.free_typevars()], 1)
assert len(concrete_typings_list) < theoretical_num_typings
# Check the validity of each individual concrete typing against the

View File

@@ -127,7 +127,6 @@ class TestTypeSet(TestCase):
def test_preimage(self):
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
self.assertEqual(t, t.preimage(TypeVar.SAMEAS))
# LANEOF
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
@@ -217,12 +216,11 @@ class TestTypeVar(TestCase):
self.assertEqual(len(x.type_set.bools), 0)
def test_stress_constrain_types(self):
# Get all 49 possible derived vars of length 2. Since we have SAMEAS
# this includes singly derived and non-derived vars
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
# Get all 43 possible derived vars of length up to 2
funcs = [TypeVar.LANEOF,
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
v = list(product(*[funcs, funcs]))
v = [()] + [(x,) for x in funcs] + list(product(*[funcs, funcs]))
# For each pair of derived variables
for (i1, i2) in product(v, v):

View File

@@ -29,11 +29,19 @@ class TypeEnv(object):
:attribute constraints: a list of accumulated constraints - tuples
(tv1, tv2)) where tv1 and tv2 are equal
:attribute ranks: dictionary recording the (optional) ranks for tvs.
tvs corresponding to real variables have explicitly
specified ranks.
'rank' is a partial ordering on TVs based on their
origin. See comments in rank() and register().
:attribute vars: a set containing all known Vars
:attribute idx: counter used to get fresh ids
"""
RANK_DERIVED = 5
RANK_INPUT = 4
RANK_INTERMEDIATE = 3
RANK_OUTPUT = 2
RANK_TEMP = 1
RANK_INTERNAL = 0
def __init__(self, arg=None):
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
self.ranks = {} # type: Dict[TypeVar, int]
@@ -104,9 +112,10 @@ class TypeEnv(object):
Get the rank of tv in the partial order. TVs directly associated with a
Var get their rank from the Var (see register()).
Internally generated non-derived TVs implicitly get the lowest rank (0)
Internal derived variables get the highest rank.
Derived variables get the highest rank.
"""
default_rank = 5 if tv.is_derived else 0
default_rank = TypeEnv.RANK_DERIVED if tv.is_derived else\
TypeEnv.RANK_INTERNAL
return self.ranks.get(tv, default_rank)
def register(self, v):
@@ -118,25 +127,26 @@ class TypeEnv(object):
self.vars.add(v)
if v.is_input():
r = 4
r = TypeEnv.RANK_INPUT
elif v.is_intermediate():
r = 3
r = TypeEnv.RANK_INTERMEDIATE
elif v.is_output():
r = 2
r = TypeEnv.RANK_OUTPUT
else:
assert(v.is_temp())
r = 1
r = TypeEnv.RANK_TEMP
self.ranks[v.get_typevar()] = r
def free_typevars(self):
# type: () -> Set[TypeVar]
# type: () -> List[TypeVar]
"""
Get the free typevars in the current type env.
"""
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
# Filter out None here due to singleton type vars
return set(filter(lambda x: x is not None, tvs))
return sorted(filter(lambda x: x is not None, tvs),
key=lambda x: x.name)
def normalize(self):
# type: () -> None
@@ -178,7 +188,7 @@ class TypeEnv(object):
s.add(a)
children[b] = s
for r in list(self.free_typevars()):
for r in self.free_typevars():
while (r not in source_tvs and r in children and
len(children[r]) == 1):
child = list(children[r])[0]
@@ -359,9 +369,6 @@ def normalize_tv(tv):
# type: (TypeVar) -> TypeVar
"""
Normalize a (potentially derived) TV using the following rules:
- collapse SAMEAS
SAMEAS(base) -> base
- vector and width derived functions commute
{HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
{HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
@@ -378,10 +385,6 @@ def normalize_tv(tv):
df = tv.derived_func
# Collapse SAMEAS edges
if (df == TypeVar.SAMEAS):
return normalize_tv(tv.base)
if (tv.base.is_derived):
base_df = tv.base.derived_func
@@ -393,8 +396,9 @@ def normalize_tv(tv):
TypeVar.derived(tv.base.base, df), base_df))
# Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
# cancel each other. TODO: Does this cancellation hide type
# overflow/underflow?
# cancel each other. Note: This doesn't hide any over/underflows,
# since we 1) assert the safety of each TV in the chain upon its
# creation, and 2) the base typeset is only allowed to shrink.
if (df, base_df) in \
[(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),

View File

@@ -350,9 +350,7 @@ class TypeSet(object):
"""
Return the image of self across the derived function func
"""
if (func == TypeVar.SAMEAS):
return self
elif (func == TypeVar.LANEOF):
if (func == TypeVar.LANEOF):
return self.lane_of()
elif (func == TypeVar.ASBOOL):
return self.as_bool()
@@ -376,9 +374,7 @@ class TypeSet(object):
if (self.size() == 0):
return self
if (func == TypeVar.SAMEAS):
return self
elif (func == TypeVar.LANEOF):
if (func == TypeVar.LANEOF):
new = self.copy()
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
return new
@@ -388,6 +384,9 @@ class TypeSet(object):
if 1 not in self.bools:
new.ints = self.bools.difference(set([1]))
new.floats = self.bools.intersection(set([32, 64]))
# If b1 is not in our typeset, than lanes=1 cannot be in the
# pre-image, as as_bool() of scalars is always b1.
new.lanes = self.lanes.difference(set([1]))
else:
new.ints = set([2**x for x in range(3, 7)])
new.floats = set([32, 64])
@@ -553,7 +552,6 @@ class TypeVar(object):
# The names here must match the method names on `ir::types::Type`.
# The camel_case of the names must match `enum OperandConstraint` in
# `instructions.rs`.
SAMEAS = 'same_as'
LANEOF = 'lane_of'
ASBOOL = 'as_bool'
HALFWIDTH = 'half_width'
@@ -565,7 +563,6 @@ class TypeVar(object):
def is_bijection(func):
# type: (str) -> bool
return func in [
TypeVar.SAMEAS,
TypeVar.HALFWIDTH,
TypeVar.DOUBLEWIDTH,
TypeVar.HALFVECTOR,
@@ -575,7 +572,6 @@ class TypeVar(object):
def inverse_func(func):
# type: (str) -> str
return {
TypeVar.SAMEAS: TypeVar.SAMEAS,
TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
@@ -586,6 +582,31 @@ class TypeVar(object):
def derived(base, derived_func):
# type: (TypeVar, str) -> TypeVar
"""Create a type variable that is a function of another."""
# Safety checks to avoid over/underflows.
ts = base.get_typeset()
if derived_func == TypeVar.HALFWIDTH:
if len(ts.ints) > 0:
assert min(ts.ints) > 8, "Can't halve all integer types"
if len(ts.floats) > 0:
assert min(ts.floats) > 32, "Can't halve all float types"
if len(ts.bools) > 0:
assert min(ts.bools) > 8, "Can't halve all boolean types"
elif derived_func == TypeVar.DOUBLEWIDTH:
if len(ts.ints) > 0:
assert max(ts.ints) < MAX_BITS,\
"Can't double all integer types."
if len(ts.floats) > 0:
assert max(ts.floats) < MAX_BITS,\
"Can't double all float types."
if len(ts.bools) > 0:
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
elif derived_func == TypeVar.HALFVECTOR:
assert min(ts.lanes) > 1, "Can't halve a scalar type"
elif derived_func == TypeVar.DOUBLEVECTOR:
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
return TypeVar(None, None, base=base, derived_func=derived_func)
@staticmethod
@@ -596,27 +617,6 @@ class TypeVar(object):
tv.type_set = ts
return tv
def change_to_derived(self, base, derived_func):
# type: (TypeVar, str) -> None
"""Change this type variable into a derived one."""
self.type_set = None
self.is_derived = True
self.base = base
self.derived_func = derived_func
def strip_sameas(self):
# type: () -> TypeVar
"""
Strip any `SAMEAS` functions from this typevar.
Also rewrite any `SAMEAS` functions nested under this typevar.
"""
if self.is_derived:
self.base = self.base.strip_sameas()
if self.derived_func == self.SAMEAS:
return self.base
return self
def lane_of(self):
# type: () -> TypeVar
"""
@@ -642,14 +642,6 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are half the width.
"""
ts = self.get_typeset()
if len(ts.ints) > 0:
assert min(ts.ints) > 8, "Can't halve all integer types"
if len(ts.floats) > 0:
assert min(ts.floats) > 32, "Can't halve all float types"
if len(ts.bools) > 0:
assert min(ts.bools) > 8, "Can't halve all boolean types"
return TypeVar.derived(self, self.HALFWIDTH)
def double_width(self):
@@ -658,14 +650,6 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are double the width.
"""
ts = self.get_typeset()
if len(ts.ints) > 0:
assert max(ts.ints) < MAX_BITS, "Can't double all integer types."
if len(ts.floats) > 0:
assert max(ts.floats) < MAX_BITS, "Can't double all float types."
if len(ts.bools) > 0:
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
return TypeVar.derived(self, self.DOUBLEWIDTH)
def half_vector(self):
@@ -674,9 +658,6 @@ class TypeVar(object):
Return a derived type variable that has half the number of vector lanes
as this one, with the same lane type.
"""
ts = self.get_typeset()
assert min(ts.lanes) > 1, "Can't halve a scalar type"
return TypeVar.derived(self, self.HALFVECTOR)
def double_vector(self):
@@ -685,9 +666,6 @@ class TypeVar(object):
Return a derived type variable that has twice the number of vector
lanes as this one, with the same lane type.
"""
ts = self.get_typeset()
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
return TypeVar.derived(self, self.DOUBLEVECTOR)
def singleton_type(self):
@@ -744,15 +722,11 @@ class TypeVar(object):
"""
Constrain the range of types this variable can assume to a subset of
those `other` can assume.
If this is a SAMEAS-derived type variable, constrain the base instead.
"""
a = self.strip_sameas()
b = other.strip_sameas()
if a is b:
if self is other:
return
a.constrain_types_by_ts(b.get_typeset())
self.constrain_types_by_ts(other.get_typeset())
def get_typeset(self):
# type: () -> TypeSet

View File

@@ -106,14 +106,14 @@ class XForm(object):
# Sanity: The set of inferred free typevars should be a subset of the
# TVs corresponding to Vars appearing in src
self.free_typevars = self.ti.free_typevars()
free_typevars = set(self.ti.free_typevars())
src_vars = set(self.inputs).union(
[x for x in self.defs if not x.is_temp()])
src_tvs = set([v.get_typevar() for v in src_vars])
if (not self.free_typevars.issubset(src_tvs)):
if (not free_typevars.issubset(src_tvs)):
raise AssertionError(
"Some free vars don't appear in src - {}"
.format(self.free_typevars.difference(src_tvs)))
.format(free_typevars.difference(src_tvs)))
# Update the type vars for each Var to their inferred values
for v in self.inputs + self.defs: