Cleanup ValueType.get_names to with_bits form previous PR; Add computation of inverse image of typeset across a derived function - TypeSet.map_inverse; Change TypeVar.constrain_type to perform a more-general computation using inverse images of TypeSets; Tests for map_inverse;

This commit is contained in:
Dimo
2017-06-23 17:46:05 -07:00
committed by Jakob Stoklund Olesen
parent 7c298078c8
commit c073d919f4
3 changed files with 237 additions and 48 deletions

View File

@@ -4,6 +4,8 @@ from doctest import DocTestSuite
from . import typevar from . import typevar
from .typevar import TypeSet, TypeVar from .typevar import TypeSet, TypeVar
from base.types import i32, i16, b1, f64 from base.types import i32, i16, b1, f64
from itertools import product
from functools import reduce
def load_tests(loader, tests, ignore): def load_tests(loader, tests, ignore):
@@ -123,6 +125,58 @@ class TestTypeSet(TestCase):
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(), self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
i32.by(4)) i32.by(4))
def test_map_inverse(self):
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
self.assertEqual(t, t.map_inverse(TypeVar.SAMEAS))
# LANEOF
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
t.map_inverse(TypeVar.LANEOF))
# Inverse of empty set is still empty across LANEOF
self.assertEqual(TypeSet(),
TypeSet().map_inverse(TypeVar.LANEOF))
# ASBOOL
t = TypeSet(lanes=(1, 4), bools=(1, 64))
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
TypeSet(lanes=(1, 4), ints=True, bools=True,
floats=True))
# Inverse image across ASBOOL of TS not involving b1 cannot have
# lanes=1
t = TypeSet(lanes=(1, 4), bools=(16, 32))
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
TypeSet(lanes=(2, 4), ints=(16, 32), bools=(16, 32),
floats=(32, 32)))
# Half/Double Vector
t = TypeSet(lanes=(1, 1), ints=(8, 8))
t1 = TypeSet(lanes=(256, 256), ints=(8, 8))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR).size(), 0)
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR).size(), 0)
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32))
t1 = TypeSet(lanes=(64, 256), bools=(1, 32))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR),
TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32)))
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR),
TypeSet(lanes=(128, 256), bools=(1, 32)))
# Half/Double Width
t = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8))
t1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH).size(), 0)
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH).size(), 0)
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64))
t1 = TypeSet(lanes=(64, 256), bools=(1, 64))
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH),
TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32)))
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH),
TypeSet(lanes=(64, 256), bools=(16, 64)))
class TestTypeVar(TestCase): class TestTypeVar(TestCase):
def test_functions(self): def test_functions(self):
@@ -164,3 +218,57 @@ class TestTypeVar(TestCase):
self.assertEqual(max(x.type_set.lanes), 4) self.assertEqual(max(x.type_set.lanes), 4)
self.assertEqual(len(x.type_set.floats), 0) self.assertEqual(len(x.type_set.floats), 0)
self.assertEqual(len(x.type_set.bools), 0) self.assertEqual(len(x.type_set.bools), 0)
def test_stress_constrain_types(self):
# Get all 49 possible derived vars of lentgh 2. Since we have SAMEAS
# this includes singly derived and non-derived vars
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
v = list(product(*[funcs, funcs]))
# For each pair of derived variables
for (i1, i2) in product(v, v):
# Compute the derived sets for each starting with a full typeset
full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True)
ts1 = reduce(lambda ts, func: ts.map(func), i1, full_ts)
ts2 = reduce(lambda ts, func: ts.map(func), i2, full_ts)
# Compute intersection
intersect = ts1.copy()
intersect &= ts2
# Propagate instersections backward
ts1_src = reduce(lambda ts, func: ts.map_inverse(func),
reversed(i1),
intersect)
ts2_src = reduce(lambda ts, func: ts.map_inverse(func),
reversed(i2),
intersect)
# If the intersection or its propagated forms are empty, then these
# two variables can never overlap. For example x.double_vector and
# x.lane_of.
if (intersect.size() == 0 or ts1_src.size() == 0 or
ts2_src.size() == 0):
continue
# Should be safe to create derived tvs from ts1_src and ts2_src
tv1 = reduce(lambda tv, func: TypeVar.derived(tv, func),
i1,
TypeVar.from_typeset(ts1_src))
tv2 = reduce(lambda tv, func: TypeVar.derived(tv, func),
i2,
TypeVar.from_typeset(ts2_src))
# The typesets of the two derived variables should be subsets of
# the intersection we computed originally
assert tv1.get_typeset().issubset(intersect)
assert tv2.get_typeset().issubset(intersect)
# In the absence of AS_BOOL map(map_inverse(f)) == f so the
# typesets of tv1 and tv2 should be exactly intersection
assert (tv1.get_typeset() == tv2.get_typeset() and
tv1.get_typeset() == intersect) or\
TypeVar.ASBOOL in set(i1 + i2)

View File

@@ -3,8 +3,9 @@ from __future__ import absolute_import
import math import math
try: try:
from typing import Dict, List # noqa from typing import Dict, List, cast, TYPE_CHECKING # noqa
except ImportError: except ImportError:
TYPE_CHECKING = False
pass pass
@@ -97,7 +98,7 @@ class VectorType(ValueType):
# type: (ScalarType, int) -> None # type: (ScalarType, int) -> None
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types' assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
super(VectorType, self).__init__( super(VectorType, self).__init__(
name=VectorType.get_name(base, lanes), name='{}x{}'.format(base.name, lanes),
membytes=lanes*base.membytes, membytes=lanes*base.membytes,
doc=""" doc="""
A SIMD vector with {} lanes containing a `{}` each. A SIMD vector with {} lanes containing a `{}` each.
@@ -111,11 +112,6 @@ class VectorType(ValueType):
return ('VectorType(base={}, lanes={})' return ('VectorType(base={}, lanes={})'
.format(self.base.name, self.lanes)) .format(self.base.name, self.lanes))
@staticmethod
def get_name(base, lanes):
# type: (ValueType, int) -> str
return '{}x{}'.format(base.name, lanes)
class IntType(ScalarType): class IntType(ScalarType):
"""A concrete scalar integer type.""" """A concrete scalar integer type."""
@@ -124,7 +120,7 @@ class IntType(ScalarType):
# type: (int) -> None # type: (int) -> None
assert bits > 0, 'IntType must have positive number of bits' assert bits > 0, 'IntType must have positive number of bits'
super(IntType, self).__init__( super(IntType, self).__init__(
name=IntType.get_name(bits), name='i{:d}'.format(bits),
membytes=bits // 8, membytes=bits // 8,
doc="An integer type with {} bits.".format(bits)) doc="An integer type with {} bits.".format(bits))
self.bits = bits self.bits = bits
@@ -134,9 +130,13 @@ class IntType(ScalarType):
return 'IntType(bits={})'.format(self.bits) return 'IntType(bits={})'.format(self.bits)
@staticmethod @staticmethod
def get_name(bits): def with_bits(bits):
# type: (int) -> str # type: (int) -> IntType
return 'i{:d}'.format(bits) typ = ValueType.by_name('i{:d}'.format(bits))
if TYPE_CHECKING:
return cast(IntType, typ)
else:
return typ
class FloatType(ScalarType): class FloatType(ScalarType):
@@ -146,7 +146,7 @@ class FloatType(ScalarType):
# type: (int, str) -> None # type: (int, str) -> None
assert bits > 0, 'FloatType must have positive number of bits' assert bits > 0, 'FloatType must have positive number of bits'
super(FloatType, self).__init__( super(FloatType, self).__init__(
name=FloatType.get_name(bits), name='f{:d}'.format(bits),
membytes=bits // 8, membytes=bits // 8,
doc=doc) doc=doc)
self.bits = bits self.bits = bits
@@ -156,9 +156,13 @@ class FloatType(ScalarType):
return 'FloatType(bits={})'.format(self.bits) return 'FloatType(bits={})'.format(self.bits)
@staticmethod @staticmethod
def get_name(bits): def with_bits(bits):
# type: (int) -> str # type: (int) -> FloatType
return 'f{:d}'.format(bits) typ = ValueType.by_name('f{:d}'.format(bits))
if TYPE_CHECKING:
return cast(FloatType, typ)
else:
return typ
class BoolType(ScalarType): class BoolType(ScalarType):
@@ -168,7 +172,7 @@ class BoolType(ScalarType):
# type: (int) -> None # type: (int) -> None
assert bits > 0, 'BoolType must have positive number of bits' assert bits > 0, 'BoolType must have positive number of bits'
super(BoolType, self).__init__( super(BoolType, self).__init__(
name=BoolType.get_name(bits), name='b{:d}'.format(bits),
membytes=bits // 8, membytes=bits // 8,
doc="A boolean type with {} bits.".format(bits)) doc="A boolean type with {} bits.".format(bits))
self.bits = bits self.bits = bits
@@ -178,6 +182,10 @@ class BoolType(ScalarType):
return 'BoolType(bits={})'.format(self.bits) return 'BoolType(bits={})'.format(self.bits)
@staticmethod @staticmethod
def get_name(bits): def with_bits(bits):
# type: (int) -> str # type: (int) -> BoolType
return 'b{:d}'.format(bits) typ = ValueType.by_name('b{:d}'.format(bits))
if TYPE_CHECKING:
return cast(BoolType, typ)
else:
return typ

View File

@@ -8,18 +8,17 @@ from __future__ import absolute_import
import math import math
from . import types, is_power_of_two from . import types, is_power_of_two
from copy import deepcopy from copy import deepcopy
from .types import ValueType, IntType, FloatType, BoolType from .types import IntType, FloatType, BoolType
try: try:
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
from typing import cast
if TYPE_CHECKING: if TYPE_CHECKING:
from srcgen import Formatter # noqa from srcgen import Formatter # noqa
from .types import ValueType # noqa
Interval = Tuple[int, int] Interval = Tuple[int, int]
# An Interval where `True` means 'everything' # An Interval where `True` means 'everything'
BoolInterval = Union[bool, Interval] BoolInterval = Union[bool, Interval]
except ImportError: except ImportError:
TYPE_CHECKING = False
pass pass
MAX_LANES = 256 MAX_LANES = 256
@@ -263,6 +262,16 @@ class TypeSet(object):
return self return self
def issubset(self, other):
# type: (TypeSet) -> bool
"""
Return true iff self is a subset of other
"""
return self.lanes.issubset(other.lanes) and \
self.ints.issubset(other.ints) and \
self.floats.issubset(other.floats) and \
self.bools.issubset(other.bools)
def lane_of(self): def lane_of(self):
# type: () -> TypeSet # type: () -> TypeSet
""" """
@@ -292,9 +301,9 @@ class TypeSet(object):
Return a TypeSet describing the image of self across halfwidth Return a TypeSet describing the image of self across halfwidth
""" """
new = self.copy() new = self.copy()
new.ints = set([x/2 for x in self.ints if x > 8]) new.ints = set([x//2 for x in self.ints if x > 8])
new.floats = set([x/2 for x in self.floats if x > 32]) new.floats = set([x//2 for x in self.floats if x > 32])
new.bools = set([x/2 for x in self.bools if x > 8]) new.bools = set([x//2 for x in self.bools if x > 8])
return new return new
@@ -317,7 +326,7 @@ class TypeSet(object):
Return a TypeSet describing the image of self across halfvector Return a TypeSet describing the image of self across halfvector
""" """
new = self.copy() new = self.copy()
new.lanes = set([x/2 for x in self.lanes if x > 1]) new.lanes = set([x//2 for x in self.lanes if x > 1])
return new return new
@@ -331,6 +340,67 @@ class TypeSet(object):
return new return new
def map(self, func):
# type: (str) -> TypeSet
"""
Return the image of self across the derived function func
"""
if (func == TypeVar.SAMEAS):
return self
elif (func == TypeVar.LANEOF):
return self.lane_of()
elif (func == TypeVar.ASBOOL):
return self.as_bool()
elif (func == TypeVar.HALFWIDTH):
return self.half_width()
elif (func == TypeVar.DOUBLEWIDTH):
return self.double_width()
elif (func == TypeVar.HALFVECTOR):
return self.half_vector()
elif (func == TypeVar.DOUBLEVECTOR):
return self.double_vector()
else:
assert False, "Unknown derived function: " + func
def map_inverse(self, func):
# type: (str) -> TypeSet
"""
Return the inverse image of self across the derived function func
"""
# The inverse of the empty set is always empty
if (self.size() == 0):
return self
if (func == TypeVar.SAMEAS):
return self
elif (func == TypeVar.LANEOF):
new = self.copy()
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
return new
elif (func == TypeVar.ASBOOL):
new = self.copy()
new.ints = self.bools.difference(set([1]))
new.floats = self.bools.intersection(set([32, 64]))
if 1 not in self.bools:
try:
# If the range doesn't have b1, then the domain can't
# include scalars, as as_bool(scalar)=b1
new.lanes.remove(1)
except KeyError:
pass
return new
elif (func == TypeVar.HALFWIDTH):
return self.double_width()
elif (func == TypeVar.DOUBLEWIDTH):
return self.half_width()
elif (func == TypeVar.HALFVECTOR):
return self.double_vector()
elif (func == TypeVar.DOUBLEVECTOR):
return self.half_vector()
else:
assert False, "Unknown derived function: " + func
def size(self): def size(self):
# type: () -> int # type: () -> int
""" """
@@ -346,25 +416,20 @@ class TypeSet(object):
typesets containing 1 type. typesets containing 1 type.
""" """
assert self.size() == 1 assert self.size() == 1
scalar_type = None # type: types.ScalarType
if len(self.ints) > 0: if len(self.ints) > 0:
bits = tuple(self.ints)[0] scalar_type = IntType.with_bits(tuple(self.ints)[0])
scalar_type = ValueType.by_name(IntType.get_name(bits))
elif len(self.floats) > 0: elif len(self.floats) > 0:
bits = tuple(self.floats)[0] scalar_type = FloatType.with_bits(tuple(self.floats)[0])
scalar_type = ValueType.by_name(FloatType.get_name(bits))
else: else:
bits = tuple(self.bools)[0] scalar_type = BoolType.with_bits(tuple(self.bools)[0])
scalar_type = ValueType.by_name(BoolType.get_name(bits))
nlanes = tuple(self.lanes)[0] nlanes = tuple(self.lanes)[0]
if nlanes == 1: if nlanes == 1:
return scalar_type return scalar_type
else: else:
if TYPE_CHECKING: return scalar_type.by(nlanes)
return cast(types.ScalarType, scalar_type).by(nlanes)
else:
return scalar_type.by(nlanes)
class TypeVar(object): class TypeVar(object):
@@ -483,6 +548,14 @@ class TypeVar(object):
"""Create a type variable that is a function of another.""" """Create a type variable that is a function of another."""
return TypeVar(None, None, base=base, derived_func=derived_func) return TypeVar(None, None, base=base, derived_func=derived_func)
@staticmethod
def from_typeset(ts):
# type: (TypeSet) -> TypeVar
""" Create a type variable from a type set."""
tv = TypeVar(None, None)
tv.type_set = ts
return tv
def change_to_derived(self, base, derived_func): def change_to_derived(self, base, derived_func):
# type: (TypeVar, str) -> None # type: (TypeVar, str) -> None
"""Change this type variable into a derived one.""" """Change this type variable into a derived one."""
@@ -615,6 +688,17 @@ class TypeVar(object):
else: else:
return self.name return self.name
def constrain_types_by_ts(self, ts):
# type: (TypeSet) -> None
"""
Constrain the range of types this variable can assume to a subset of
those in the typeset ts.
"""
if not self.is_derived:
self.type_set &= ts
else:
self.base.constrain_types_by_ts(ts.map_inverse(self.derived_func))
def constrain_types(self, other): def constrain_types(self, other):
# type: (TypeVar) -> None # type: (TypeVar) -> None
""" """
@@ -628,18 +712,7 @@ class TypeVar(object):
if a is b: if a is b:
return return
if not a.is_derived and not b.is_derived: a.constrain_types_by_ts(b.get_typeset())
a.type_set &= b.type_set
return
# TODO: Implement constraints for derived type variables.
#
# If a and b are both derived with the same derived_func, we could say
# `a.base.constrain_types(b.base)`, but unless the derived_func is
# injective, that may constrain `a.base` more than necessary.
#
# For the fully general case, we would need to compute an image typeset
# for `b` and propagate a `a.derived_func` pre-image to `a.base`.
def get_typeset(self): def get_typeset(self):
# type: () -> TypeSet # type: () -> TypeSet