Cleanup ValueType.get_names to with_bits form previous PR; Add computation of inverse image of typeset across a derived function - TypeSet.map_inverse; Change TypeVar.constrain_type to perform a more-general computation using inverse images of TypeSets; Tests for map_inverse;
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
7c298078c8
commit
c073d919f4
@@ -4,6 +4,8 @@ from doctest import DocTestSuite
|
||||
from . import typevar
|
||||
from .typevar import TypeSet, TypeVar
|
||||
from base.types import i32, i16, b1, f64
|
||||
from itertools import product
|
||||
from functools import reduce
|
||||
|
||||
|
||||
def load_tests(loader, tests, ignore):
|
||||
@@ -123,6 +125,58 @@ class TestTypeSet(TestCase):
|
||||
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
|
||||
i32.by(4))
|
||||
|
||||
def test_map_inverse(self):
|
||||
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
|
||||
self.assertEqual(t, t.map_inverse(TypeVar.SAMEAS))
|
||||
|
||||
# LANEOF
|
||||
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
|
||||
t.map_inverse(TypeVar.LANEOF))
|
||||
# Inverse of empty set is still empty across LANEOF
|
||||
self.assertEqual(TypeSet(),
|
||||
TypeSet().map_inverse(TypeVar.LANEOF))
|
||||
|
||||
# ASBOOL
|
||||
t = TypeSet(lanes=(1, 4), bools=(1, 64))
|
||||
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
|
||||
TypeSet(lanes=(1, 4), ints=True, bools=True,
|
||||
floats=True))
|
||||
|
||||
# Inverse image across ASBOOL of TS not involving b1 cannot have
|
||||
# lanes=1
|
||||
t = TypeSet(lanes=(1, 4), bools=(16, 32))
|
||||
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
|
||||
TypeSet(lanes=(2, 4), ints=(16, 32), bools=(16, 32),
|
||||
floats=(32, 32)))
|
||||
|
||||
# Half/Double Vector
|
||||
t = TypeSet(lanes=(1, 1), ints=(8, 8))
|
||||
t1 = TypeSet(lanes=(256, 256), ints=(8, 8))
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR).size(), 0)
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR).size(), 0)
|
||||
|
||||
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32))
|
||||
t1 = TypeSet(lanes=(64, 256), bools=(1, 32))
|
||||
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR),
|
||||
TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32)))
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR),
|
||||
TypeSet(lanes=(128, 256), bools=(1, 32)))
|
||||
|
||||
# Half/Double Width
|
||||
t = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8))
|
||||
t1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64))
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH).size(), 0)
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH).size(), 0)
|
||||
|
||||
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64))
|
||||
t1 = TypeSet(lanes=(64, 256), bools=(1, 64))
|
||||
|
||||
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH),
|
||||
TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32)))
|
||||
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH),
|
||||
TypeSet(lanes=(64, 256), bools=(16, 64)))
|
||||
|
||||
|
||||
class TestTypeVar(TestCase):
|
||||
def test_functions(self):
|
||||
@@ -164,3 +218,57 @@ class TestTypeVar(TestCase):
|
||||
self.assertEqual(max(x.type_set.lanes), 4)
|
||||
self.assertEqual(len(x.type_set.floats), 0)
|
||||
self.assertEqual(len(x.type_set.bools), 0)
|
||||
|
||||
def test_stress_constrain_types(self):
|
||||
# Get all 49 possible derived vars of lentgh 2. Since we have SAMEAS
|
||||
# this includes singly derived and non-derived vars
|
||||
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
|
||||
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
|
||||
TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
|
||||
v = list(product(*[funcs, funcs]))
|
||||
|
||||
# For each pair of derived variables
|
||||
for (i1, i2) in product(v, v):
|
||||
# Compute the derived sets for each starting with a full typeset
|
||||
full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True)
|
||||
ts1 = reduce(lambda ts, func: ts.map(func), i1, full_ts)
|
||||
ts2 = reduce(lambda ts, func: ts.map(func), i2, full_ts)
|
||||
|
||||
# Compute intersection
|
||||
intersect = ts1.copy()
|
||||
intersect &= ts2
|
||||
|
||||
# Propagate instersections backward
|
||||
ts1_src = reduce(lambda ts, func: ts.map_inverse(func),
|
||||
reversed(i1),
|
||||
intersect)
|
||||
ts2_src = reduce(lambda ts, func: ts.map_inverse(func),
|
||||
reversed(i2),
|
||||
intersect)
|
||||
|
||||
# If the intersection or its propagated forms are empty, then these
|
||||
# two variables can never overlap. For example x.double_vector and
|
||||
# x.lane_of.
|
||||
if (intersect.size() == 0 or ts1_src.size() == 0 or
|
||||
ts2_src.size() == 0):
|
||||
continue
|
||||
|
||||
# Should be safe to create derived tvs from ts1_src and ts2_src
|
||||
tv1 = reduce(lambda tv, func: TypeVar.derived(tv, func),
|
||||
i1,
|
||||
TypeVar.from_typeset(ts1_src))
|
||||
|
||||
tv2 = reduce(lambda tv, func: TypeVar.derived(tv, func),
|
||||
i2,
|
||||
TypeVar.from_typeset(ts2_src))
|
||||
|
||||
# The typesets of the two derived variables should be subsets of
|
||||
# the intersection we computed originally
|
||||
assert tv1.get_typeset().issubset(intersect)
|
||||
assert tv2.get_typeset().issubset(intersect)
|
||||
|
||||
# In the absence of AS_BOOL map(map_inverse(f)) == f so the
|
||||
# typesets of tv1 and tv2 should be exactly intersection
|
||||
assert (tv1.get_typeset() == tv2.get_typeset() and
|
||||
tv1.get_typeset() == intersect) or\
|
||||
TypeVar.ASBOOL in set(i1 + i2)
|
||||
|
||||
@@ -3,8 +3,9 @@ from __future__ import absolute_import
|
||||
import math
|
||||
|
||||
try:
|
||||
from typing import Dict, List # noqa
|
||||
from typing import Dict, List, cast, TYPE_CHECKING # noqa
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
pass
|
||||
|
||||
|
||||
@@ -97,7 +98,7 @@ class VectorType(ValueType):
|
||||
# type: (ScalarType, int) -> None
|
||||
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
|
||||
super(VectorType, self).__init__(
|
||||
name=VectorType.get_name(base, lanes),
|
||||
name='{}x{}'.format(base.name, lanes),
|
||||
membytes=lanes*base.membytes,
|
||||
doc="""
|
||||
A SIMD vector with {} lanes containing a `{}` each.
|
||||
@@ -111,11 +112,6 @@ class VectorType(ValueType):
|
||||
return ('VectorType(base={}, lanes={})'
|
||||
.format(self.base.name, self.lanes))
|
||||
|
||||
@staticmethod
|
||||
def get_name(base, lanes):
|
||||
# type: (ValueType, int) -> str
|
||||
return '{}x{}'.format(base.name, lanes)
|
||||
|
||||
|
||||
class IntType(ScalarType):
|
||||
"""A concrete scalar integer type."""
|
||||
@@ -124,7 +120,7 @@ class IntType(ScalarType):
|
||||
# type: (int) -> None
|
||||
assert bits > 0, 'IntType must have positive number of bits'
|
||||
super(IntType, self).__init__(
|
||||
name=IntType.get_name(bits),
|
||||
name='i{:d}'.format(bits),
|
||||
membytes=bits // 8,
|
||||
doc="An integer type with {} bits.".format(bits))
|
||||
self.bits = bits
|
||||
@@ -134,9 +130,13 @@ class IntType(ScalarType):
|
||||
return 'IntType(bits={})'.format(self.bits)
|
||||
|
||||
@staticmethod
|
||||
def get_name(bits):
|
||||
# type: (int) -> str
|
||||
return 'i{:d}'.format(bits)
|
||||
def with_bits(bits):
|
||||
# type: (int) -> IntType
|
||||
typ = ValueType.by_name('i{:d}'.format(bits))
|
||||
if TYPE_CHECKING:
|
||||
return cast(IntType, typ)
|
||||
else:
|
||||
return typ
|
||||
|
||||
|
||||
class FloatType(ScalarType):
|
||||
@@ -146,7 +146,7 @@ class FloatType(ScalarType):
|
||||
# type: (int, str) -> None
|
||||
assert bits > 0, 'FloatType must have positive number of bits'
|
||||
super(FloatType, self).__init__(
|
||||
name=FloatType.get_name(bits),
|
||||
name='f{:d}'.format(bits),
|
||||
membytes=bits // 8,
|
||||
doc=doc)
|
||||
self.bits = bits
|
||||
@@ -156,9 +156,13 @@ class FloatType(ScalarType):
|
||||
return 'FloatType(bits={})'.format(self.bits)
|
||||
|
||||
@staticmethod
|
||||
def get_name(bits):
|
||||
# type: (int) -> str
|
||||
return 'f{:d}'.format(bits)
|
||||
def with_bits(bits):
|
||||
# type: (int) -> FloatType
|
||||
typ = ValueType.by_name('f{:d}'.format(bits))
|
||||
if TYPE_CHECKING:
|
||||
return cast(FloatType, typ)
|
||||
else:
|
||||
return typ
|
||||
|
||||
|
||||
class BoolType(ScalarType):
|
||||
@@ -168,7 +172,7 @@ class BoolType(ScalarType):
|
||||
# type: (int) -> None
|
||||
assert bits > 0, 'BoolType must have positive number of bits'
|
||||
super(BoolType, self).__init__(
|
||||
name=BoolType.get_name(bits),
|
||||
name='b{:d}'.format(bits),
|
||||
membytes=bits // 8,
|
||||
doc="A boolean type with {} bits.".format(bits))
|
||||
self.bits = bits
|
||||
@@ -178,6 +182,10 @@ class BoolType(ScalarType):
|
||||
return 'BoolType(bits={})'.format(self.bits)
|
||||
|
||||
@staticmethod
|
||||
def get_name(bits):
|
||||
# type: (int) -> str
|
||||
return 'b{:d}'.format(bits)
|
||||
def with_bits(bits):
|
||||
# type: (int) -> BoolType
|
||||
typ = ValueType.by_name('b{:d}'.format(bits))
|
||||
if TYPE_CHECKING:
|
||||
return cast(BoolType, typ)
|
||||
else:
|
||||
return typ
|
||||
|
||||
@@ -8,18 +8,17 @@ from __future__ import absolute_import
|
||||
import math
|
||||
from . import types, is_power_of_two
|
||||
from copy import deepcopy
|
||||
from .types import ValueType, IntType, FloatType, BoolType
|
||||
from .types import IntType, FloatType, BoolType
|
||||
|
||||
try:
|
||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||
from typing import cast
|
||||
if TYPE_CHECKING:
|
||||
from srcgen import Formatter # noqa
|
||||
from .types import ValueType # noqa
|
||||
Interval = Tuple[int, int]
|
||||
# An Interval where `True` means 'everything'
|
||||
BoolInterval = Union[bool, Interval]
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
pass
|
||||
|
||||
MAX_LANES = 256
|
||||
@@ -263,6 +262,16 @@ class TypeSet(object):
|
||||
|
||||
return self
|
||||
|
||||
def issubset(self, other):
|
||||
# type: (TypeSet) -> bool
|
||||
"""
|
||||
Return true iff self is a subset of other
|
||||
"""
|
||||
return self.lanes.issubset(other.lanes) and \
|
||||
self.ints.issubset(other.ints) and \
|
||||
self.floats.issubset(other.floats) and \
|
||||
self.bools.issubset(other.bools)
|
||||
|
||||
def lane_of(self):
|
||||
# type: () -> TypeSet
|
||||
"""
|
||||
@@ -292,9 +301,9 @@ class TypeSet(object):
|
||||
Return a TypeSet describing the image of self across halfwidth
|
||||
"""
|
||||
new = self.copy()
|
||||
new.ints = set([x/2 for x in self.ints if x > 8])
|
||||
new.floats = set([x/2 for x in self.floats if x > 32])
|
||||
new.bools = set([x/2 for x in self.bools if x > 8])
|
||||
new.ints = set([x//2 for x in self.ints if x > 8])
|
||||
new.floats = set([x//2 for x in self.floats if x > 32])
|
||||
new.bools = set([x//2 for x in self.bools if x > 8])
|
||||
|
||||
return new
|
||||
|
||||
@@ -317,7 +326,7 @@ class TypeSet(object):
|
||||
Return a TypeSet describing the image of self across halfvector
|
||||
"""
|
||||
new = self.copy()
|
||||
new.lanes = set([x/2 for x in self.lanes if x > 1])
|
||||
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
||||
|
||||
return new
|
||||
|
||||
@@ -331,6 +340,67 @@ class TypeSet(object):
|
||||
|
||||
return new
|
||||
|
||||
def map(self, func):
|
||||
# type: (str) -> TypeSet
|
||||
"""
|
||||
Return the image of self across the derived function func
|
||||
"""
|
||||
if (func == TypeVar.SAMEAS):
|
||||
return self
|
||||
elif (func == TypeVar.LANEOF):
|
||||
return self.lane_of()
|
||||
elif (func == TypeVar.ASBOOL):
|
||||
return self.as_bool()
|
||||
elif (func == TypeVar.HALFWIDTH):
|
||||
return self.half_width()
|
||||
elif (func == TypeVar.DOUBLEWIDTH):
|
||||
return self.double_width()
|
||||
elif (func == TypeVar.HALFVECTOR):
|
||||
return self.half_vector()
|
||||
elif (func == TypeVar.DOUBLEVECTOR):
|
||||
return self.double_vector()
|
||||
else:
|
||||
assert False, "Unknown derived function: " + func
|
||||
|
||||
def map_inverse(self, func):
|
||||
# type: (str) -> TypeSet
|
||||
"""
|
||||
Return the inverse image of self across the derived function func
|
||||
"""
|
||||
# The inverse of the empty set is always empty
|
||||
if (self.size() == 0):
|
||||
return self
|
||||
|
||||
if (func == TypeVar.SAMEAS):
|
||||
return self
|
||||
elif (func == TypeVar.LANEOF):
|
||||
new = self.copy()
|
||||
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
|
||||
return new
|
||||
elif (func == TypeVar.ASBOOL):
|
||||
new = self.copy()
|
||||
new.ints = self.bools.difference(set([1]))
|
||||
new.floats = self.bools.intersection(set([32, 64]))
|
||||
|
||||
if 1 not in self.bools:
|
||||
try:
|
||||
# If the range doesn't have b1, then the domain can't
|
||||
# include scalars, as as_bool(scalar)=b1
|
||||
new.lanes.remove(1)
|
||||
except KeyError:
|
||||
pass
|
||||
return new
|
||||
elif (func == TypeVar.HALFWIDTH):
|
||||
return self.double_width()
|
||||
elif (func == TypeVar.DOUBLEWIDTH):
|
||||
return self.half_width()
|
||||
elif (func == TypeVar.HALFVECTOR):
|
||||
return self.double_vector()
|
||||
elif (func == TypeVar.DOUBLEVECTOR):
|
||||
return self.half_vector()
|
||||
else:
|
||||
assert False, "Unknown derived function: " + func
|
||||
|
||||
def size(self):
|
||||
# type: () -> int
|
||||
"""
|
||||
@@ -346,23 +416,18 @@ class TypeSet(object):
|
||||
typesets containing 1 type.
|
||||
"""
|
||||
assert self.size() == 1
|
||||
scalar_type = None # type: types.ScalarType
|
||||
if len(self.ints) > 0:
|
||||
bits = tuple(self.ints)[0]
|
||||
scalar_type = ValueType.by_name(IntType.get_name(bits))
|
||||
scalar_type = IntType.with_bits(tuple(self.ints)[0])
|
||||
elif len(self.floats) > 0:
|
||||
bits = tuple(self.floats)[0]
|
||||
scalar_type = ValueType.by_name(FloatType.get_name(bits))
|
||||
scalar_type = FloatType.with_bits(tuple(self.floats)[0])
|
||||
else:
|
||||
bits = tuple(self.bools)[0]
|
||||
scalar_type = ValueType.by_name(BoolType.get_name(bits))
|
||||
scalar_type = BoolType.with_bits(tuple(self.bools)[0])
|
||||
|
||||
nlanes = tuple(self.lanes)[0]
|
||||
|
||||
if nlanes == 1:
|
||||
return scalar_type
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
return cast(types.ScalarType, scalar_type).by(nlanes)
|
||||
else:
|
||||
return scalar_type.by(nlanes)
|
||||
|
||||
@@ -483,6 +548,14 @@ class TypeVar(object):
|
||||
"""Create a type variable that is a function of another."""
|
||||
return TypeVar(None, None, base=base, derived_func=derived_func)
|
||||
|
||||
@staticmethod
|
||||
def from_typeset(ts):
|
||||
# type: (TypeSet) -> TypeVar
|
||||
""" Create a type variable from a type set."""
|
||||
tv = TypeVar(None, None)
|
||||
tv.type_set = ts
|
||||
return tv
|
||||
|
||||
def change_to_derived(self, base, derived_func):
|
||||
# type: (TypeVar, str) -> None
|
||||
"""Change this type variable into a derived one."""
|
||||
@@ -615,6 +688,17 @@ class TypeVar(object):
|
||||
else:
|
||||
return self.name
|
||||
|
||||
def constrain_types_by_ts(self, ts):
|
||||
# type: (TypeSet) -> None
|
||||
"""
|
||||
Constrain the range of types this variable can assume to a subset of
|
||||
those in the typeset ts.
|
||||
"""
|
||||
if not self.is_derived:
|
||||
self.type_set &= ts
|
||||
else:
|
||||
self.base.constrain_types_by_ts(ts.map_inverse(self.derived_func))
|
||||
|
||||
def constrain_types(self, other):
|
||||
# type: (TypeVar) -> None
|
||||
"""
|
||||
@@ -628,18 +712,7 @@ class TypeVar(object):
|
||||
if a is b:
|
||||
return
|
||||
|
||||
if not a.is_derived and not b.is_derived:
|
||||
a.type_set &= b.type_set
|
||||
return
|
||||
|
||||
# TODO: Implement constraints for derived type variables.
|
||||
#
|
||||
# If a and b are both derived with the same derived_func, we could say
|
||||
# `a.base.constrain_types(b.base)`, but unless the derived_func is
|
||||
# injective, that may constrain `a.base` more than necessary.
|
||||
#
|
||||
# For the fully general case, we would need to compute an image typeset
|
||||
# for `b` and propagate a `a.derived_func` pre-image to `a.base`.
|
||||
a.constrain_types_by_ts(b.get_typeset())
|
||||
|
||||
def get_typeset(self):
|
||||
# type: () -> TypeSet
|
||||
|
||||
Reference in New Issue
Block a user