Cleanup typos; Remove SAMEAS; More descriptive rank comments; Introduce explicit sorting in free_typevars() (#111)

As per the comment in TypeEnv.normalize_tv about cancellation, whenever we create a TypeVar we must assert that there is no under/overflow. To make sure this always happen move the safety checks to TypeVar.derived() from the other helper methods
This commit is contained in:
d1m0
2017-07-05 15:47:44 -07:00
committed by Jakob Stoklund Olesen
parent fe127ab3eb
commit 83e55525d6
5 changed files with 66 additions and 90 deletions

View File

@@ -62,7 +62,7 @@ def agree(me, other):
if m[me[tv]] != other[tv]: if m[me[tv]] != other[tv]:
return False return False
# Tranlsate our constraints using m, and sort # Translate our constraints using m, and sort
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints] me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr]) me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
@@ -76,7 +76,7 @@ def agree(me, other):
def check_typing(got_or_err, expected, symtab=None): def check_typing(got_or_err, expected, symtab=None):
# type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa # type: (TypingOrError, Tuple[VarMap, ConstraintList], Dict[str, Var]) -> None # noqa
""" """
Check that a the typying we received (got_or_err) complies with the Check that a the typing we received (got_or_err) complies with the
expected typing (expected). If symtab is specified, substitute the Vars in expected typing (expected). If symtab is specified, substitute the Vars in
expected using symtab first (used when checking type inference on XForms) expected using symtab first (used when checking type inference on XForms)
""" """
@@ -409,7 +409,7 @@ class TestXForm(TypeCheckingBaseTest):
# If there are no free_typevars, this is a non-polymorphic pattern. # If there are no free_typevars, this is a non-polymorphic pattern.
# There should be only one possible concrete typing. # There should be only one possible concrete typing.
if (len(xform.free_typevars) == 0): if (len(xform.ti.free_typevars()) == 0):
assert len(concrete_typings_list) == 1 assert len(concrete_typings_list) == 1
continue continue
@@ -423,7 +423,7 @@ class TestXForm(TypeCheckingBaseTest):
theoretical_num_typings =\ theoretical_num_typings =\
reduce(lambda x, y: x*y, reduce(lambda x, y: x*y,
[tv.get_typeset().size() [tv.get_typeset().size()
for tv in xform.free_typevars], 1) for tv in xform.ti.free_typevars()], 1)
assert len(concrete_typings_list) < theoretical_num_typings assert len(concrete_typings_list) < theoretical_num_typings
# Check the validity of each individual concrete typing against the # Check the validity of each individual concrete typing against the

View File

@@ -127,7 +127,6 @@ class TestTypeSet(TestCase):
def test_preimage(self): def test_preimage(self):
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32)) t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
self.assertEqual(t, t.preimage(TypeVar.SAMEAS))
# LANEOF # LANEOF
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)), self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
@@ -217,12 +216,11 @@ class TestTypeVar(TestCase):
self.assertEqual(len(x.type_set.bools), 0) self.assertEqual(len(x.type_set.bools), 0)
def test_stress_constrain_types(self): def test_stress_constrain_types(self):
# Get all 49 possible derived vars of length 2. Since we have SAMEAS # Get all 43 possible derived vars of length up to 2
# this includes singly derived and non-derived vars funcs = [TypeVar.LANEOF,
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR, TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH] TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
v = list(product(*[funcs, funcs])) v = [()] + [(x,) for x in funcs] + list(product(*[funcs, funcs]))
# For each pair of derived variables # For each pair of derived variables
for (i1, i2) in product(v, v): for (i1, i2) in product(v, v):

View File

@@ -29,11 +29,19 @@ class TypeEnv(object):
:attribute constraints: a list of accumulated constraints - tuples :attribute constraints: a list of accumulated constraints - tuples
(tv1, tv2)) where tv1 and tv2 are equal (tv1, tv2)) where tv1 and tv2 are equal
:attribute ranks: dictionary recording the (optional) ranks for tvs. :attribute ranks: dictionary recording the (optional) ranks for tvs.
tvs corresponding to real variables have explicitly 'rank' is a partial ordering on TVs based on their
specified ranks. origin. See comments in rank() and register().
:attribute vars: a set containing all known Vars :attribute vars: a set containing all known Vars
:attribute idx: counter used to get fresh ids :attribute idx: counter used to get fresh ids
""" """
RANK_DERIVED = 5
RANK_INPUT = 4
RANK_INTERMEDIATE = 3
RANK_OUTPUT = 2
RANK_TEMP = 1
RANK_INTERNAL = 0
def __init__(self, arg=None): def __init__(self, arg=None):
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None # type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
self.ranks = {} # type: Dict[TypeVar, int] self.ranks = {} # type: Dict[TypeVar, int]
@@ -104,9 +112,10 @@ class TypeEnv(object):
Get the rank of tv in the partial order. TVs directly associated with a Get the rank of tv in the partial order. TVs directly associated with a
Var get their rank from the Var (see register()). Var get their rank from the Var (see register()).
Internally generated non-derived TVs implicitly get the lowest rank (0) Internally generated non-derived TVs implicitly get the lowest rank (0)
Internal derived variables get the highest rank. Derived variables get the highest rank.
""" """
default_rank = 5 if tv.is_derived else 0 default_rank = TypeEnv.RANK_DERIVED if tv.is_derived else\
TypeEnv.RANK_INTERNAL
return self.ranks.get(tv, default_rank) return self.ranks.get(tv, default_rank)
def register(self, v): def register(self, v):
@@ -118,25 +127,26 @@ class TypeEnv(object):
self.vars.add(v) self.vars.add(v)
if v.is_input(): if v.is_input():
r = 4 r = TypeEnv.RANK_INPUT
elif v.is_intermediate(): elif v.is_intermediate():
r = 3 r = TypeEnv.RANK_INTERMEDIATE
elif v.is_output(): elif v.is_output():
r = 2 r = TypeEnv.RANK_OUTPUT
else: else:
assert(v.is_temp()) assert(v.is_temp())
r = 1 r = TypeEnv.RANK_TEMP
self.ranks[v.get_typevar()] = r self.ranks[v.get_typevar()] = r
def free_typevars(self): def free_typevars(self):
# type: () -> Set[TypeVar] # type: () -> List[TypeVar]
""" """
Get the free typevars in the current type env. Get the free typevars in the current type env.
""" """
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()]) tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
# Filter out None here due to singleton type vars # Filter out None here due to singleton type vars
return set(filter(lambda x: x is not None, tvs)) return sorted(filter(lambda x: x is not None, tvs),
key=lambda x: x.name)
def normalize(self): def normalize(self):
# type: () -> None # type: () -> None
@@ -178,7 +188,7 @@ class TypeEnv(object):
s.add(a) s.add(a)
children[b] = s children[b] = s
for r in list(self.free_typevars()): for r in self.free_typevars():
while (r not in source_tvs and r in children and while (r not in source_tvs and r in children and
len(children[r]) == 1): len(children[r]) == 1):
child = list(children[r])[0] child = list(children[r])[0]
@@ -359,9 +369,6 @@ def normalize_tv(tv):
# type: (TypeVar) -> TypeVar # type: (TypeVar) -> TypeVar
""" """
Normalize a (potentially derived) TV using the following rules: Normalize a (potentially derived) TV using the following rules:
- collapse SAMEAS
SAMEAS(base) -> base
- vector and width derived functions commute - vector and width derived functions commute
{HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) -> {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
{HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base)) {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
@@ -378,10 +385,6 @@ def normalize_tv(tv):
df = tv.derived_func df = tv.derived_func
# Collapse SAMEAS edges
if (df == TypeVar.SAMEAS):
return normalize_tv(tv.base)
if (tv.base.is_derived): if (tv.base.is_derived):
base_df = tv.base.derived_func base_df = tv.base.derived_func
@@ -393,8 +396,9 @@ def normalize_tv(tv):
TypeVar.derived(tv.base.base, df), base_df)) TypeVar.derived(tv.base.base, df), base_df))
# Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR # Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
# cancel each other. TODO: Does this cancellation hide type # cancel each other. Note: This doesn't hide any over/underflows,
# overflow/underflow? # since we 1) assert the safety of each TV in the chain upon its
# creation, and 2) the base typeset is only allowed to shrink.
if (df, base_df) in \ if (df, base_df) in \
[(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR), [(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),

View File

@@ -350,9 +350,7 @@ class TypeSet(object):
""" """
Return the image of self across the derived function func Return the image of self across the derived function func
""" """
if (func == TypeVar.SAMEAS): if (func == TypeVar.LANEOF):
return self
elif (func == TypeVar.LANEOF):
return self.lane_of() return self.lane_of()
elif (func == TypeVar.ASBOOL): elif (func == TypeVar.ASBOOL):
return self.as_bool() return self.as_bool()
@@ -376,9 +374,7 @@ class TypeSet(object):
if (self.size() == 0): if (self.size() == 0):
return self return self
if (func == TypeVar.SAMEAS): if (func == TypeVar.LANEOF):
return self
elif (func == TypeVar.LANEOF):
new = self.copy() new = self.copy()
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)]) new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
return new return new
@@ -388,6 +384,9 @@ class TypeSet(object):
if 1 not in self.bools: if 1 not in self.bools:
new.ints = self.bools.difference(set([1])) new.ints = self.bools.difference(set([1]))
new.floats = self.bools.intersection(set([32, 64])) new.floats = self.bools.intersection(set([32, 64]))
# If b1 is not in our typeset, than lanes=1 cannot be in the
# pre-image, as as_bool() of scalars is always b1.
new.lanes = self.lanes.difference(set([1]))
else: else:
new.ints = set([2**x for x in range(3, 7)]) new.ints = set([2**x for x in range(3, 7)])
new.floats = set([32, 64]) new.floats = set([32, 64])
@@ -553,7 +552,6 @@ class TypeVar(object):
# The names here must match the method names on `ir::types::Type`. # The names here must match the method names on `ir::types::Type`.
# The camel_case of the names must match `enum OperandConstraint` in # The camel_case of the names must match `enum OperandConstraint` in
# `instructions.rs`. # `instructions.rs`.
SAMEAS = 'same_as'
LANEOF = 'lane_of' LANEOF = 'lane_of'
ASBOOL = 'as_bool' ASBOOL = 'as_bool'
HALFWIDTH = 'half_width' HALFWIDTH = 'half_width'
@@ -565,7 +563,6 @@ class TypeVar(object):
def is_bijection(func): def is_bijection(func):
# type: (str) -> bool # type: (str) -> bool
return func in [ return func in [
TypeVar.SAMEAS,
TypeVar.HALFWIDTH, TypeVar.HALFWIDTH,
TypeVar.DOUBLEWIDTH, TypeVar.DOUBLEWIDTH,
TypeVar.HALFVECTOR, TypeVar.HALFVECTOR,
@@ -575,7 +572,6 @@ class TypeVar(object):
def inverse_func(func): def inverse_func(func):
# type: (str) -> str # type: (str) -> str
return { return {
TypeVar.SAMEAS: TypeVar.SAMEAS,
TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
@@ -586,6 +582,31 @@ class TypeVar(object):
def derived(base, derived_func): def derived(base, derived_func):
# type: (TypeVar, str) -> TypeVar # type: (TypeVar, str) -> TypeVar
"""Create a type variable that is a function of another.""" """Create a type variable that is a function of another."""
# Safety checks to avoid over/underflows.
ts = base.get_typeset()
if derived_func == TypeVar.HALFWIDTH:
if len(ts.ints) > 0:
assert min(ts.ints) > 8, "Can't halve all integer types"
if len(ts.floats) > 0:
assert min(ts.floats) > 32, "Can't halve all float types"
if len(ts.bools) > 0:
assert min(ts.bools) > 8, "Can't halve all boolean types"
elif derived_func == TypeVar.DOUBLEWIDTH:
if len(ts.ints) > 0:
assert max(ts.ints) < MAX_BITS,\
"Can't double all integer types."
if len(ts.floats) > 0:
assert max(ts.floats) < MAX_BITS,\
"Can't double all float types."
if len(ts.bools) > 0:
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
elif derived_func == TypeVar.HALFVECTOR:
assert min(ts.lanes) > 1, "Can't halve a scalar type"
elif derived_func == TypeVar.DOUBLEVECTOR:
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
return TypeVar(None, None, base=base, derived_func=derived_func) return TypeVar(None, None, base=base, derived_func=derived_func)
@staticmethod @staticmethod
@@ -596,27 +617,6 @@ class TypeVar(object):
tv.type_set = ts tv.type_set = ts
return tv return tv
def change_to_derived(self, base, derived_func):
# type: (TypeVar, str) -> None
"""Change this type variable into a derived one."""
self.type_set = None
self.is_derived = True
self.base = base
self.derived_func = derived_func
def strip_sameas(self):
# type: () -> TypeVar
"""
Strip any `SAMEAS` functions from this typevar.
Also rewrite any `SAMEAS` functions nested under this typevar.
"""
if self.is_derived:
self.base = self.base.strip_sameas()
if self.derived_func == self.SAMEAS:
return self.base
return self
def lane_of(self): def lane_of(self):
# type: () -> TypeVar # type: () -> TypeVar
""" """
@@ -642,14 +642,6 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are half the width. as this one, but the lanes are half the width.
""" """
ts = self.get_typeset()
if len(ts.ints) > 0:
assert min(ts.ints) > 8, "Can't halve all integer types"
if len(ts.floats) > 0:
assert min(ts.floats) > 32, "Can't halve all float types"
if len(ts.bools) > 0:
assert min(ts.bools) > 8, "Can't halve all boolean types"
return TypeVar.derived(self, self.HALFWIDTH) return TypeVar.derived(self, self.HALFWIDTH)
def double_width(self): def double_width(self):
@@ -658,14 +650,6 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are double the width. as this one, but the lanes are double the width.
""" """
ts = self.get_typeset()
if len(ts.ints) > 0:
assert max(ts.ints) < MAX_BITS, "Can't double all integer types."
if len(ts.floats) > 0:
assert max(ts.floats) < MAX_BITS, "Can't double all float types."
if len(ts.bools) > 0:
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
return TypeVar.derived(self, self.DOUBLEWIDTH) return TypeVar.derived(self, self.DOUBLEWIDTH)
def half_vector(self): def half_vector(self):
@@ -674,9 +658,6 @@ class TypeVar(object):
Return a derived type variable that has half the number of vector lanes Return a derived type variable that has half the number of vector lanes
as this one, with the same lane type. as this one, with the same lane type.
""" """
ts = self.get_typeset()
assert min(ts.lanes) > 1, "Can't halve a scalar type"
return TypeVar.derived(self, self.HALFVECTOR) return TypeVar.derived(self, self.HALFVECTOR)
def double_vector(self): def double_vector(self):
@@ -685,9 +666,6 @@ class TypeVar(object):
Return a derived type variable that has twice the number of vector Return a derived type variable that has twice the number of vector
lanes as this one, with the same lane type. lanes as this one, with the same lane type.
""" """
ts = self.get_typeset()
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
return TypeVar.derived(self, self.DOUBLEVECTOR) return TypeVar.derived(self, self.DOUBLEVECTOR)
def singleton_type(self): def singleton_type(self):
@@ -744,15 +722,11 @@ class TypeVar(object):
""" """
Constrain the range of types this variable can assume to a subset of Constrain the range of types this variable can assume to a subset of
those `other` can assume. those `other` can assume.
If this is a SAMEAS-derived type variable, constrain the base instead.
""" """
a = self.strip_sameas() if self is other:
b = other.strip_sameas()
if a is b:
return return
a.constrain_types_by_ts(b.get_typeset()) self.constrain_types_by_ts(other.get_typeset())
def get_typeset(self): def get_typeset(self):
# type: () -> TypeSet # type: () -> TypeSet

View File

@@ -106,14 +106,14 @@ class XForm(object):
# Sanity: The set of inferred free typevars should be a subset of the # Sanity: The set of inferred free typevars should be a subset of the
# TVs corresponding to Vars appearing in src # TVs corresponding to Vars appearing in src
self.free_typevars = self.ti.free_typevars() free_typevars = set(self.ti.free_typevars())
src_vars = set(self.inputs).union( src_vars = set(self.inputs).union(
[x for x in self.defs if not x.is_temp()]) [x for x in self.defs if not x.is_temp()])
src_tvs = set([v.get_typevar() for v in src_vars]) src_tvs = set([v.get_typevar() for v in src_vars])
if (not self.free_typevars.issubset(src_tvs)): if (not free_typevars.issubset(src_tvs)):
raise AssertionError( raise AssertionError(
"Some free vars don't appear in src - {}" "Some free vars don't appear in src - {}"
.format(self.free_typevars.difference(src_tvs))) .format(free_typevars.difference(src_tvs)))
# Update the type vars for each Var to their inferred values # Update the type vars for each Var to their inferred values
for v in self.inputs + self.defs: for v in self.inputs + self.defs: