Add image computation of typesets; Remove TypeVar.singleton_type - instead derive singleton type from typeset; (#104)

This commit is contained in:
d1m0
2017-06-23 11:57:24 -07:00
committed by Jakob Stoklund Olesen
parent 9487b885da
commit 6a9438d274
7 changed files with 294 additions and 49 deletions

View File

@@ -2,15 +2,10 @@
The base.types module predefines all the Cretonne scalar types. The base.types module predefines all the Cretonne scalar types.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from cdsl.types import ScalarType, IntType, FloatType, BoolType from cdsl.types import IntType, FloatType, BoolType
#: Boolean. #: Boolean.
b1 = ScalarType( b1 = BoolType(1) #: 1-bit bool. Type is abstract (can't be stored in mem)
'b1', 0,
"""
A boolean value that is either true or false.
""")
b8 = BoolType(8) #: 8-bit bool. b8 = BoolType(8) #: 8-bit bool.
b16 = BoolType(16) #: 16-bit bool. b16 = BoolType(16) #: 16-bit bool.
b32 = BoolType(32) #: 32-bit bool. b32 = BoolType(32) #: 32-bit bool.

View File

@@ -186,7 +186,7 @@ class Instruction(object):
try: try:
opnum = self.value_opnums[self.format.typevar_operand] opnum = self.value_opnums[self.format.typevar_operand]
tv = self.ins[opnum].typevar tv = self.ins[opnum].typevar
if tv is tv.free_typevar(): if tv is tv.free_typevar() or tv.singleton_type() is not None:
self.other_typevars = self._verify_ctrl_typevar(tv) self.other_typevars = self._verify_ctrl_typevar(tv)
self.ctrl_typevar = tv self.ctrl_typevar = tv
self.use_typevar_operand = True self.use_typevar_operand = True

View File

@@ -3,7 +3,7 @@ from unittest import TestCase
from doctest import DocTestSuite from doctest import DocTestSuite
from . import typevar from . import typevar
from .typevar import TypeSet, TypeVar from .typevar import TypeSet, TypeVar
from base.types import i32 from base.types import i32, i16, b1, f64
def load_tests(loader, tests, ignore): def load_tests(loader, tests, ignore):
@@ -45,6 +45,84 @@ class TestTypeSet(TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
a in s a in s
def test_forward_images(self):
a = TypeSet(lanes=(2, 8), ints=(8, 8), floats=(32, 32))
b = TypeSet(lanes=(1, 8), ints=(8, 8), floats=(32, 32))
self.assertEqual(a.lane_of(), TypeSet(ints=(8, 8), floats=(32, 32)))
c = TypeSet(lanes=(2, 8))
c.bools = set([8, 32])
# Test case with disjoint intervals
self.assertEqual(a.as_bool(), c)
# For as_bool check b1 is present when 1 \in lanes
d = TypeSet(lanes=(1, 8))
d.bools = set([1, 8, 32])
self.assertEqual(b.as_bool(), d)
self.assertEqual(TypeSet(lanes=(1, 32)).half_vector(),
TypeSet(lanes=(1, 16)))
self.assertEqual(TypeSet(lanes=(1, 32)).double_vector(),
TypeSet(lanes=(2, 64)))
self.assertEqual(TypeSet(lanes=(128, 256)).double_vector(),
TypeSet(lanes=(256, 256)))
self.assertEqual(TypeSet(ints=(8, 32)).half_width(),
TypeSet(ints=(8, 16)))
self.assertEqual(TypeSet(ints=(8, 32)).double_width(),
TypeSet(ints=(16, 64)))
self.assertEqual(TypeSet(ints=(32, 64)).double_width(),
TypeSet(ints=(64, 64)))
# Should produce an empty ts
self.assertEqual(TypeSet(floats=(32, 32)).half_width(),
TypeSet())
self.assertEqual(TypeSet(floats=(32, 64)).half_width(),
TypeSet(floats=(32, 32)))
self.assertEqual(TypeSet(floats=(32, 32)).double_width(),
TypeSet(floats=(64, 64)))
self.assertEqual(TypeSet(floats=(32, 64)).double_width(),
TypeSet(floats=(64, 64)))
# Bools have trickier behavior around b1 (since b2, b4 don't exist)
self.assertEqual(TypeSet(bools=(1, 8)).half_width(),
TypeSet())
t = TypeSet()
t.bools = set([8, 16])
self.assertEqual(TypeSet(bools=(1, 32)).half_width(), t)
# double_width() of bools={1, 8, 16} must not include 2 or 8
t.bools = set([16, 32])
self.assertEqual(TypeSet(bools=(1, 16)).double_width(), t)
self.assertEqual(TypeSet(bools=(32, 64)).double_width(),
TypeSet(bools=(64, 64)))
def test_get_singleton(self):
# Raise error when calling get_singleton() on non-singleton TS
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
with self.assertRaises(AssertionError):
t.get_singleton()
t = TypeSet(lanes=(1, 2), floats=(32, 32))
with self.assertRaises(AssertionError):
t.get_singleton()
self.assertEqual(TypeSet(ints=(16, 16)).get_singleton(), i16)
self.assertEqual(TypeSet(floats=(64, 64)).get_singleton(), f64)
self.assertEqual(TypeSet(bools=(1, 1)).get_singleton(), b1)
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
i32.by(4))
class TestTypeVar(TestCase): class TestTypeVar(TestCase):
def test_functions(self): def test_functions(self):

View File

@@ -97,7 +97,7 @@ class VectorType(ValueType):
# type: (ScalarType, int) -> None # type: (ScalarType, int) -> None
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types' assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
super(VectorType, self).__init__( super(VectorType, self).__init__(
name='{}x{}'.format(base.name, lanes), name=VectorType.get_name(base, lanes),
membytes=lanes*base.membytes, membytes=lanes*base.membytes,
doc=""" doc="""
A SIMD vector with {} lanes containing a `{}` each. A SIMD vector with {} lanes containing a `{}` each.
@@ -111,6 +111,11 @@ class VectorType(ValueType):
return ('VectorType(base={}, lanes={})' return ('VectorType(base={}, lanes={})'
.format(self.base.name, self.lanes)) .format(self.base.name, self.lanes))
@staticmethod
def get_name(base, lanes):
# type: (ValueType, int) -> str
return '{}x{}'.format(base.name, lanes)
class IntType(ScalarType): class IntType(ScalarType):
"""A concrete scalar integer type.""" """A concrete scalar integer type."""
@@ -119,7 +124,7 @@ class IntType(ScalarType):
# type: (int) -> None # type: (int) -> None
assert bits > 0, 'IntType must have positive number of bits' assert bits > 0, 'IntType must have positive number of bits'
super(IntType, self).__init__( super(IntType, self).__init__(
name='i{:d}'.format(bits), name=IntType.get_name(bits),
membytes=bits // 8, membytes=bits // 8,
doc="An integer type with {} bits.".format(bits)) doc="An integer type with {} bits.".format(bits))
self.bits = bits self.bits = bits
@@ -128,6 +133,11 @@ class IntType(ScalarType):
# type: () -> str # type: () -> str
return 'IntType(bits={})'.format(self.bits) return 'IntType(bits={})'.format(self.bits)
@staticmethod
def get_name(bits):
# type: (int) -> str
return 'i{:d}'.format(bits)
class FloatType(ScalarType): class FloatType(ScalarType):
"""A concrete scalar floating point type.""" """A concrete scalar floating point type."""
@@ -136,7 +146,7 @@ class FloatType(ScalarType):
# type: (int, str) -> None # type: (int, str) -> None
assert bits > 0, 'FloatType must have positive number of bits' assert bits > 0, 'FloatType must have positive number of bits'
super(FloatType, self).__init__( super(FloatType, self).__init__(
name='f{:d}'.format(bits), name=FloatType.get_name(bits),
membytes=bits // 8, membytes=bits // 8,
doc=doc) doc=doc)
self.bits = bits self.bits = bits
@@ -145,6 +155,11 @@ class FloatType(ScalarType):
# type: () -> str # type: () -> str
return 'FloatType(bits={})'.format(self.bits) return 'FloatType(bits={})'.format(self.bits)
@staticmethod
def get_name(bits):
# type: (int) -> str
return 'f{:d}'.format(bits)
class BoolType(ScalarType): class BoolType(ScalarType):
"""A concrete scalar boolean type.""" """A concrete scalar boolean type."""
@@ -153,7 +168,7 @@ class BoolType(ScalarType):
# type: (int) -> None # type: (int) -> None
assert bits > 0, 'BoolType must have positive number of bits' assert bits > 0, 'BoolType must have positive number of bits'
super(BoolType, self).__init__( super(BoolType, self).__init__(
name='b{:d}'.format(bits), name=BoolType.get_name(bits),
membytes=bits // 8, membytes=bits // 8,
doc="A boolean type with {} bits.".format(bits)) doc="A boolean type with {} bits.".format(bits))
self.bits = bits self.bits = bits
@@ -161,3 +176,8 @@ class BoolType(ScalarType):
def __repr__(self): def __repr__(self):
# type: () -> str # type: () -> str
return 'BoolType(bits={})'.format(self.bits) return 'BoolType(bits={})'.format(self.bits)
@staticmethod
def get_name(bits):
# type: (int) -> str
return 'b{:d}'.format(bits)

View File

@@ -7,15 +7,19 @@ polymorphic by using type variables.
from __future__ import absolute_import from __future__ import absolute_import
import math import math
from . import types, is_power_of_two from . import types, is_power_of_two
from copy import deepcopy
from .types import ValueType, IntType, FloatType, BoolType
try: try:
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
from typing import cast
if TYPE_CHECKING: if TYPE_CHECKING:
from srcgen import Formatter # noqa from srcgen import Formatter # noqa
Interval = Tuple[int, int] Interval = Tuple[int, int]
# An Interval where `True` means 'everything' # An Interval where `True` means 'everything'
BoolInterval = Union[bool, Interval] BoolInterval = Union[bool, Interval]
except ImportError: except ImportError:
TYPE_CHECKING = False
pass pass
MAX_LANES = 256 MAX_LANES = 256
@@ -112,6 +116,16 @@ def interval_to_set(intv):
return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)]) return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)])
def legal_bool(bits):
# type: (int) -> bool
"""
True iff bits is a legal bit width for a bool type.
bits == 1 || bits \in { 8, 16, .. MAX_BITS }
"""
return bits == 1 or \
(bits >= 8 and bits <= MAX_BITS and is_power_of_two(bits))
class TypeSet(object): class TypeSet(object):
""" """
A set of types. A set of types.
@@ -165,7 +179,15 @@ class TypeSet(object):
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS))) self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
self.floats = interval_to_set(decode_interval(floats, (32, 64))) self.floats = interval_to_set(decode_interval(floats, (32, 64)))
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS))) self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
self.bools = set(filter(lambda x: x == 1 or x >= 8, self.bools)) self.bools = set(filter(legal_bool, self.bools))
def copy(self):
# type: (TypeSet) -> TypeSet
"""
Return a copy of our self. deepcopy is sufficient and safe here, since
TypeSet contains only sets of numbers.
"""
return deepcopy(self)
def typeset_key(self): def typeset_key(self):
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple] # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
@@ -241,6 +263,109 @@ class TypeSet(object):
return self return self
def lane_of(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across lane_of
"""
new = self.copy()
new.lanes = set([1])
return new
def as_bool(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across as_bool
"""
new = self.copy()
new.ints = set()
new.floats = set()
new.bools = self.ints.union(self.floats).union(self.bools)
if 1 in self.lanes:
new.bools.add(1)
return new
def half_width(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across halfwidth
"""
new = self.copy()
new.ints = set([x/2 for x in self.ints if x > 8])
new.floats = set([x/2 for x in self.floats if x > 32])
new.bools = set([x/2 for x in self.bools if x > 8])
return new
def double_width(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across doublewidth
"""
new = self.copy()
new.ints = set([x*2 for x in self.ints if x < MAX_BITS])
new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
new.bools = set(filter(legal_bool,
set([x*2 for x in self.bools if x < MAX_BITS])))
return new
def half_vector(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across halfvector
"""
new = self.copy()
new.lanes = set([x/2 for x in self.lanes if x > 1])
return new
def double_vector(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across doublevector
"""
new = self.copy()
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
return new
def size(self):
# type: () -> int
"""
Return the number of concrete types represented by this typeset
"""
return len(self.lanes) * (len(self.ints) + len(self.floats) +
len(self.bools))
def get_singleton(self):
# type: () -> types.ValueType
"""
Return the singleton type represented by self. Can only call on
typesets containing 1 type.
"""
assert self.size() == 1
if len(self.ints) > 0:
bits = tuple(self.ints)[0]
scalar_type = ValueType.by_name(IntType.get_name(bits))
elif len(self.floats) > 0:
bits = tuple(self.floats)[0]
scalar_type = ValueType.by_name(FloatType.get_name(bits))
else:
bits = tuple(self.bools)[0]
scalar_type = ValueType.by_name(BoolType.get_name(bits))
nlanes = tuple(self.lanes)[0]
if nlanes == 1:
return scalar_type
else:
if TYPE_CHECKING:
return cast(types.ScalarType, scalar_type).by(nlanes)
else:
return scalar_type.by(nlanes)
class TypeVar(object): class TypeVar(object):
""" """
@@ -271,7 +396,6 @@ class TypeVar(object):
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa # type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa
self.name = name self.name = name
self.__doc__ = doc self.__doc__ = doc
self.singleton_type = None # type: types.ValueType
self.is_derived = isinstance(base, TypeVar) self.is_derived = isinstance(base, TypeVar)
if base: if base:
assert self.is_derived assert self.is_derived
@@ -313,7 +437,6 @@ class TypeVar(object):
tv = TypeVar( tv = TypeVar(
typ.name, typ.__doc__, typ.name, typ.__doc__,
ints, floats, bools, simd=lanes) ints, floats, bools, simd=lanes)
tv.singleton_type = typ
return tv return tv
def __str__(self): def __str__(self):
@@ -406,14 +529,13 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are half the width. as this one, but the lanes are half the width.
""" """
if not self.is_derived: ts = self.get_typeset()
ts = self.type_set if len(ts.ints) > 0:
if len(ts.ints) > 0: assert min(ts.ints) > 8, "Can't halve all integer types"
assert min(ts.ints) > 8, "Can't halve all integer types" if len(ts.floats) > 0:
if len(ts.floats) > 0: assert min(ts.floats) > 32, "Can't halve all float types"
assert min(ts.floats) > 32, "Can't halve all float types" if len(ts.bools) > 0:
if len(ts.bools) > 0: assert min(ts.bools) > 8, "Can't halve all boolean types"
assert min(ts.bools) > 8, "Can't halve all boolean types"
return TypeVar.derived(self, self.HALFWIDTH) return TypeVar.derived(self, self.HALFWIDTH)
@@ -423,16 +545,13 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are double the width. as this one, but the lanes are double the width.
""" """
if not self.is_derived: ts = self.get_typeset()
ts = self.type_set if len(ts.ints) > 0:
if len(ts.ints) > 0: assert max(ts.ints) < MAX_BITS, "Can't double all integer types."
assert max(ts.ints) < MAX_BITS,\ if len(ts.floats) > 0:
"Can't double all integer types." assert max(ts.floats) < MAX_BITS, "Can't double all float types."
if len(ts.floats) > 0: if len(ts.bools) > 0:
assert max(ts.floats) < MAX_BITS,\ assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
"Can't double all float types."
if len(ts.bools) > 0:
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
return TypeVar.derived(self, self.DOUBLEWIDTH) return TypeVar.derived(self, self.DOUBLEWIDTH)
@@ -442,9 +561,8 @@ class TypeVar(object):
Return a derived type variable that has half the number of vector lanes Return a derived type variable that has half the number of vector lanes
as this one, with the same lane type. as this one, with the same lane type.
""" """
if not self.is_derived: ts = self.get_typeset()
ts = self.type_set assert min(ts.lanes) > 1, "Can't halve a scalar type"
assert min(ts.lanes) > 1, "Can't halve a scalar type"
return TypeVar.derived(self, self.HALFVECTOR) return TypeVar.derived(self, self.HALFVECTOR)
@@ -454,12 +572,23 @@ class TypeVar(object):
Return a derived type variable that has twice the number of vector Return a derived type variable that has twice the number of vector
lanes as this one, with the same lane type. lanes as this one, with the same lane type.
""" """
if not self.is_derived: ts = self.get_typeset()
ts = self.type_set assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
return TypeVar.derived(self, self.DOUBLEVECTOR) return TypeVar.derived(self, self.DOUBLEVECTOR)
def singleton_type(self):
# type: () -> ValueType
"""
If the associated typeset has a single type return it. Otherwise return
None
"""
ts = self.get_typeset()
if ts.size() != 1:
return None
return ts.get_singleton()
def free_typevar(self): def free_typevar(self):
# type: () -> TypeVar # type: () -> TypeVar
""" """
@@ -467,7 +596,7 @@ class TypeVar(object):
""" """
if self.is_derived: if self.is_derived:
return self.base return self.base
elif self.singleton_type: elif self.singleton_type() is not None:
# A singleton type variable is not a proper free variable. # A singleton type variable is not a proper free variable.
return None return None
else: else:
@@ -481,8 +610,8 @@ class TypeVar(object):
if self.is_derived: if self.is_derived:
return '{}.{}()'.format( return '{}.{}()'.format(
self.base.rust_expr(), self.derived_func) self.base.rust_expr(), self.derived_func)
elif self.singleton_type: elif self.singleton_type():
return self.singleton_type.rust_name() return self.singleton_type().rust_name()
else: else:
return self.name return self.name
@@ -501,9 +630,6 @@ class TypeVar(object):
if not a.is_derived and not b.is_derived: if not a.is_derived and not b.is_derived:
a.type_set &= b.type_set a.type_set &= b.type_set
# TODO: What if a.type_set becomes empty?
if not a.singleton_type:
a.singleton_type = b.singleton_type
return return
# TODO: Implement constraints for derived type variables. # TODO: Implement constraints for derived type variables.
@@ -514,3 +640,29 @@ class TypeVar(object):
# #
# For the fully general case, we would need to compute an image typeset # For the fully general case, we would need to compute an image typeset
# for `b` and propagate a `a.derived_func` pre-image to `a.base`. # for `b` and propagate a `a.derived_func` pre-image to `a.base`.
def get_typeset(self):
# type: () -> TypeSet
"""
Returns the typeset for this TV. If the TV is derived, computes it
recursively from the derived function and the base's typeset.
"""
if not self.is_derived:
return self.type_set
else:
if (self.derived_func == TypeVar.SAMEAS):
return self.base.get_typeset()
elif (self.derived_func == TypeVar.LANEOF):
return self.base.get_typeset().lane_of()
elif (self.derived_func == TypeVar.ASBOOL):
return self.base.get_typeset().as_bool()
elif (self.derived_func == TypeVar.HALFWIDTH):
return self.base.get_typeset().half_width()
elif (self.derived_func == TypeVar.DOUBLEWIDTH):
return self.base.get_typeset().double_width()
elif (self.derived_func == TypeVar.HALFVECTOR):
return self.base.get_typeset().half_vector()
elif (self.derived_func == TypeVar.DOUBLEVECTOR):
return self.base.get_typeset().double_vector()
else:
assert False, "Unknown derived function: " + self.derived_func

View File

@@ -254,7 +254,7 @@ class XForm(object):
# Some variables have a fixed type which appears as a type variable # Some variables have a fixed type which appears as a type variable
# with a singleton_type field set. That's allowed for temps too. # with a singleton_type field set. That's allowed for temps too.
for v in fvars: for v in fvars:
if v.is_temp() and not v.typevar.singleton_type: if v.is_temp() and not v.typevar.singleton_type():
raise AssertionError( raise AssertionError(
"Cannot determine type of temp '{}' in xform:\n{}" "Cannot determine type of temp '{}' in xform:\n{}"
.format(v, self)) .format(v, self))

View File

@@ -321,8 +321,8 @@ def get_constraint(op, ctrl_typevar, type_sets):
tv = op.typevar tv = op.typevar
# A concrete value type. # A concrete value type.
if tv.singleton_type: if tv.singleton_type():
return 'Concrete({})'.format(tv.singleton_type.rust_name()) return 'Concrete({})'.format(tv.singleton_type().rust_name())
if tv.free_typevar() is not ctrl_typevar: if tv.free_typevar() is not ctrl_typevar:
assert not tv.is_derived assert not tv.is_derived