Add image computation of typesets; Remove TypeVar.singleton_type - instead derive singleton type from typeset; (#104)
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
9487b885da
commit
6a9438d274
@@ -2,15 +2,10 @@
|
|||||||
The base.types module predefines all the Cretonne scalar types.
|
The base.types module predefines all the Cretonne scalar types.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from cdsl.types import ScalarType, IntType, FloatType, BoolType
|
from cdsl.types import IntType, FloatType, BoolType
|
||||||
|
|
||||||
#: Boolean.
|
#: Boolean.
|
||||||
b1 = ScalarType(
|
b1 = BoolType(1) #: 1-bit bool. Type is abstract (can't be stored in mem)
|
||||||
'b1', 0,
|
|
||||||
"""
|
|
||||||
A boolean value that is either true or false.
|
|
||||||
""")
|
|
||||||
|
|
||||||
b8 = BoolType(8) #: 8-bit bool.
|
b8 = BoolType(8) #: 8-bit bool.
|
||||||
b16 = BoolType(16) #: 16-bit bool.
|
b16 = BoolType(16) #: 16-bit bool.
|
||||||
b32 = BoolType(32) #: 32-bit bool.
|
b32 = BoolType(32) #: 32-bit bool.
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class Instruction(object):
|
|||||||
try:
|
try:
|
||||||
opnum = self.value_opnums[self.format.typevar_operand]
|
opnum = self.value_opnums[self.format.typevar_operand]
|
||||||
tv = self.ins[opnum].typevar
|
tv = self.ins[opnum].typevar
|
||||||
if tv is tv.free_typevar():
|
if tv is tv.free_typevar() or tv.singleton_type() is not None:
|
||||||
self.other_typevars = self._verify_ctrl_typevar(tv)
|
self.other_typevars = self._verify_ctrl_typevar(tv)
|
||||||
self.ctrl_typevar = tv
|
self.ctrl_typevar = tv
|
||||||
self.use_typevar_operand = True
|
self.use_typevar_operand = True
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from unittest import TestCase
|
|||||||
from doctest import DocTestSuite
|
from doctest import DocTestSuite
|
||||||
from . import typevar
|
from . import typevar
|
||||||
from .typevar import TypeSet, TypeVar
|
from .typevar import TypeSet, TypeVar
|
||||||
from base.types import i32
|
from base.types import i32, i16, b1, f64
|
||||||
|
|
||||||
|
|
||||||
def load_tests(loader, tests, ignore):
|
def load_tests(loader, tests, ignore):
|
||||||
@@ -45,6 +45,84 @@ class TestTypeSet(TestCase):
|
|||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
a in s
|
a in s
|
||||||
|
|
||||||
|
def test_forward_images(self):
|
||||||
|
a = TypeSet(lanes=(2, 8), ints=(8, 8), floats=(32, 32))
|
||||||
|
b = TypeSet(lanes=(1, 8), ints=(8, 8), floats=(32, 32))
|
||||||
|
self.assertEqual(a.lane_of(), TypeSet(ints=(8, 8), floats=(32, 32)))
|
||||||
|
|
||||||
|
c = TypeSet(lanes=(2, 8))
|
||||||
|
c.bools = set([8, 32])
|
||||||
|
|
||||||
|
# Test case with disjoint intervals
|
||||||
|
self.assertEqual(a.as_bool(), c)
|
||||||
|
|
||||||
|
# For as_bool check b1 is present when 1 \in lanes
|
||||||
|
d = TypeSet(lanes=(1, 8))
|
||||||
|
d.bools = set([1, 8, 32])
|
||||||
|
self.assertEqual(b.as_bool(), d)
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(lanes=(1, 32)).half_vector(),
|
||||||
|
TypeSet(lanes=(1, 16)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(lanes=(1, 32)).double_vector(),
|
||||||
|
TypeSet(lanes=(2, 64)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(lanes=(128, 256)).double_vector(),
|
||||||
|
TypeSet(lanes=(256, 256)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(ints=(8, 32)).half_width(),
|
||||||
|
TypeSet(ints=(8, 16)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(ints=(8, 32)).double_width(),
|
||||||
|
TypeSet(ints=(16, 64)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(ints=(32, 64)).double_width(),
|
||||||
|
TypeSet(ints=(64, 64)))
|
||||||
|
|
||||||
|
# Should produce an empty ts
|
||||||
|
self.assertEqual(TypeSet(floats=(32, 32)).half_width(),
|
||||||
|
TypeSet())
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(floats=(32, 64)).half_width(),
|
||||||
|
TypeSet(floats=(32, 32)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(floats=(32, 32)).double_width(),
|
||||||
|
TypeSet(floats=(64, 64)))
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(floats=(32, 64)).double_width(),
|
||||||
|
TypeSet(floats=(64, 64)))
|
||||||
|
|
||||||
|
# Bools have trickier behavior around b1 (since b2, b4 don't exist)
|
||||||
|
self.assertEqual(TypeSet(bools=(1, 8)).half_width(),
|
||||||
|
TypeSet())
|
||||||
|
|
||||||
|
t = TypeSet()
|
||||||
|
t.bools = set([8, 16])
|
||||||
|
self.assertEqual(TypeSet(bools=(1, 32)).half_width(), t)
|
||||||
|
|
||||||
|
# double_width() of bools={1, 8, 16} must not include 2 or 8
|
||||||
|
t.bools = set([16, 32])
|
||||||
|
self.assertEqual(TypeSet(bools=(1, 16)).double_width(), t)
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(bools=(32, 64)).double_width(),
|
||||||
|
TypeSet(bools=(64, 64)))
|
||||||
|
|
||||||
|
def test_get_singleton(self):
|
||||||
|
# Raise error when calling get_singleton() on non-singleton TS
|
||||||
|
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
t.get_singleton()
|
||||||
|
t = TypeSet(lanes=(1, 2), floats=(32, 32))
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
t.get_singleton()
|
||||||
|
|
||||||
|
self.assertEqual(TypeSet(ints=(16, 16)).get_singleton(), i16)
|
||||||
|
self.assertEqual(TypeSet(floats=(64, 64)).get_singleton(), f64)
|
||||||
|
self.assertEqual(TypeSet(bools=(1, 1)).get_singleton(), b1)
|
||||||
|
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
|
||||||
|
i32.by(4))
|
||||||
|
|
||||||
|
|
||||||
class TestTypeVar(TestCase):
|
class TestTypeVar(TestCase):
|
||||||
def test_functions(self):
|
def test_functions(self):
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class VectorType(ValueType):
|
|||||||
# type: (ScalarType, int) -> None
|
# type: (ScalarType, int) -> None
|
||||||
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
|
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
|
||||||
super(VectorType, self).__init__(
|
super(VectorType, self).__init__(
|
||||||
name='{}x{}'.format(base.name, lanes),
|
name=VectorType.get_name(base, lanes),
|
||||||
membytes=lanes*base.membytes,
|
membytes=lanes*base.membytes,
|
||||||
doc="""
|
doc="""
|
||||||
A SIMD vector with {} lanes containing a `{}` each.
|
A SIMD vector with {} lanes containing a `{}` each.
|
||||||
@@ -111,6 +111,11 @@ class VectorType(ValueType):
|
|||||||
return ('VectorType(base={}, lanes={})'
|
return ('VectorType(base={}, lanes={})'
|
||||||
.format(self.base.name, self.lanes))
|
.format(self.base.name, self.lanes))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name(base, lanes):
|
||||||
|
# type: (ValueType, int) -> str
|
||||||
|
return '{}x{}'.format(base.name, lanes)
|
||||||
|
|
||||||
|
|
||||||
class IntType(ScalarType):
|
class IntType(ScalarType):
|
||||||
"""A concrete scalar integer type."""
|
"""A concrete scalar integer type."""
|
||||||
@@ -119,7 +124,7 @@ class IntType(ScalarType):
|
|||||||
# type: (int) -> None
|
# type: (int) -> None
|
||||||
assert bits > 0, 'IntType must have positive number of bits'
|
assert bits > 0, 'IntType must have positive number of bits'
|
||||||
super(IntType, self).__init__(
|
super(IntType, self).__init__(
|
||||||
name='i{:d}'.format(bits),
|
name=IntType.get_name(bits),
|
||||||
membytes=bits // 8,
|
membytes=bits // 8,
|
||||||
doc="An integer type with {} bits.".format(bits))
|
doc="An integer type with {} bits.".format(bits))
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
@@ -128,6 +133,11 @@ class IntType(ScalarType):
|
|||||||
# type: () -> str
|
# type: () -> str
|
||||||
return 'IntType(bits={})'.format(self.bits)
|
return 'IntType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name(bits):
|
||||||
|
# type: (int) -> str
|
||||||
|
return 'i{:d}'.format(bits)
|
||||||
|
|
||||||
|
|
||||||
class FloatType(ScalarType):
|
class FloatType(ScalarType):
|
||||||
"""A concrete scalar floating point type."""
|
"""A concrete scalar floating point type."""
|
||||||
@@ -136,7 +146,7 @@ class FloatType(ScalarType):
|
|||||||
# type: (int, str) -> None
|
# type: (int, str) -> None
|
||||||
assert bits > 0, 'FloatType must have positive number of bits'
|
assert bits > 0, 'FloatType must have positive number of bits'
|
||||||
super(FloatType, self).__init__(
|
super(FloatType, self).__init__(
|
||||||
name='f{:d}'.format(bits),
|
name=FloatType.get_name(bits),
|
||||||
membytes=bits // 8,
|
membytes=bits // 8,
|
||||||
doc=doc)
|
doc=doc)
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
@@ -145,6 +155,11 @@ class FloatType(ScalarType):
|
|||||||
# type: () -> str
|
# type: () -> str
|
||||||
return 'FloatType(bits={})'.format(self.bits)
|
return 'FloatType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name(bits):
|
||||||
|
# type: (int) -> str
|
||||||
|
return 'f{:d}'.format(bits)
|
||||||
|
|
||||||
|
|
||||||
class BoolType(ScalarType):
|
class BoolType(ScalarType):
|
||||||
"""A concrete scalar boolean type."""
|
"""A concrete scalar boolean type."""
|
||||||
@@ -153,7 +168,7 @@ class BoolType(ScalarType):
|
|||||||
# type: (int) -> None
|
# type: (int) -> None
|
||||||
assert bits > 0, 'BoolType must have positive number of bits'
|
assert bits > 0, 'BoolType must have positive number of bits'
|
||||||
super(BoolType, self).__init__(
|
super(BoolType, self).__init__(
|
||||||
name='b{:d}'.format(bits),
|
name=BoolType.get_name(bits),
|
||||||
membytes=bits // 8,
|
membytes=bits // 8,
|
||||||
doc="A boolean type with {} bits.".format(bits))
|
doc="A boolean type with {} bits.".format(bits))
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
@@ -161,3 +176,8 @@ class BoolType(ScalarType):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# type: () -> str
|
# type: () -> str
|
||||||
return 'BoolType(bits={})'.format(self.bits)
|
return 'BoolType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name(bits):
|
||||||
|
# type: (int) -> str
|
||||||
|
return 'b{:d}'.format(bits)
|
||||||
|
|||||||
@@ -7,15 +7,19 @@ polymorphic by using type variables.
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import math
|
import math
|
||||||
from . import types, is_power_of_two
|
from . import types, is_power_of_two
|
||||||
|
from copy import deepcopy
|
||||||
|
from .types import ValueType, IntType, FloatType, BoolType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||||
|
from typing import cast
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from srcgen import Formatter # noqa
|
from srcgen import Formatter # noqa
|
||||||
Interval = Tuple[int, int]
|
Interval = Tuple[int, int]
|
||||||
# An Interval where `True` means 'everything'
|
# An Interval where `True` means 'everything'
|
||||||
BoolInterval = Union[bool, Interval]
|
BoolInterval = Union[bool, Interval]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
TYPE_CHECKING = False
|
||||||
pass
|
pass
|
||||||
|
|
||||||
MAX_LANES = 256
|
MAX_LANES = 256
|
||||||
@@ -112,6 +116,16 @@ def interval_to_set(intv):
|
|||||||
return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)])
|
return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)])
|
||||||
|
|
||||||
|
|
||||||
|
def legal_bool(bits):
|
||||||
|
# type: (int) -> bool
|
||||||
|
"""
|
||||||
|
True iff bits is a legal bit width for a bool type.
|
||||||
|
bits == 1 || bits \in { 8, 16, .. MAX_BITS }
|
||||||
|
"""
|
||||||
|
return bits == 1 or \
|
||||||
|
(bits >= 8 and bits <= MAX_BITS and is_power_of_two(bits))
|
||||||
|
|
||||||
|
|
||||||
class TypeSet(object):
|
class TypeSet(object):
|
||||||
"""
|
"""
|
||||||
A set of types.
|
A set of types.
|
||||||
@@ -165,7 +179,15 @@ class TypeSet(object):
|
|||||||
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
|
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
|
||||||
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
|
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
|
||||||
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
|
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
|
||||||
self.bools = set(filter(lambda x: x == 1 or x >= 8, self.bools))
|
self.bools = set(filter(legal_bool, self.bools))
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
# type: (TypeSet) -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a copy of our self. deepcopy is sufficient and safe here, since
|
||||||
|
TypeSet contains only sets of numbers.
|
||||||
|
"""
|
||||||
|
return deepcopy(self)
|
||||||
|
|
||||||
def typeset_key(self):
|
def typeset_key(self):
|
||||||
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
|
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
|
||||||
@@ -241,6 +263,109 @@ class TypeSet(object):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def lane_of(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across lane_of
|
||||||
|
"""
|
||||||
|
new = self.copy()
|
||||||
|
new.lanes = set([1])
|
||||||
|
return new
|
||||||
|
|
||||||
|
def as_bool(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across as_bool
|
||||||
|
"""
|
||||||
|
new = self.copy()
|
||||||
|
new.ints = set()
|
||||||
|
new.floats = set()
|
||||||
|
new.bools = self.ints.union(self.floats).union(self.bools)
|
||||||
|
|
||||||
|
if 1 in self.lanes:
|
||||||
|
new.bools.add(1)
|
||||||
|
return new
|
||||||
|
|
||||||
|
def half_width(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across halfwidth
|
||||||
|
"""
|
||||||
|
new = self.copy()
|
||||||
|
new.ints = set([x/2 for x in self.ints if x > 8])
|
||||||
|
new.floats = set([x/2 for x in self.floats if x > 32])
|
||||||
|
new.bools = set([x/2 for x in self.bools if x > 8])
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
def double_width(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across doublewidth
|
||||||
|
"""
|
||||||
|
new = self.copy()
|
||||||
|
new.ints = set([x*2 for x in self.ints if x < MAX_BITS])
|
||||||
|
new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
|
||||||
|
new.bools = set(filter(legal_bool,
|
||||||
|
set([x*2 for x in self.bools if x < MAX_BITS])))
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
def half_vector(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across halfvector
|
||||||
|
"""
|
||||||
|
new = self.copy()
|
||||||
|
new.lanes = set([x/2 for x in self.lanes if x > 1])
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
def double_vector(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across doublevector
|
||||||
|
"""
|
||||||
|
new = self.copy()
|
||||||
|
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""
|
||||||
|
Return the number of concrete types represented by this typeset
|
||||||
|
"""
|
||||||
|
return len(self.lanes) * (len(self.ints) + len(self.floats) +
|
||||||
|
len(self.bools))
|
||||||
|
|
||||||
|
def get_singleton(self):
|
||||||
|
# type: () -> types.ValueType
|
||||||
|
"""
|
||||||
|
Return the singleton type represented by self. Can only call on
|
||||||
|
typesets containing 1 type.
|
||||||
|
"""
|
||||||
|
assert self.size() == 1
|
||||||
|
if len(self.ints) > 0:
|
||||||
|
bits = tuple(self.ints)[0]
|
||||||
|
scalar_type = ValueType.by_name(IntType.get_name(bits))
|
||||||
|
elif len(self.floats) > 0:
|
||||||
|
bits = tuple(self.floats)[0]
|
||||||
|
scalar_type = ValueType.by_name(FloatType.get_name(bits))
|
||||||
|
else:
|
||||||
|
bits = tuple(self.bools)[0]
|
||||||
|
scalar_type = ValueType.by_name(BoolType.get_name(bits))
|
||||||
|
|
||||||
|
nlanes = tuple(self.lanes)[0]
|
||||||
|
|
||||||
|
if nlanes == 1:
|
||||||
|
return scalar_type
|
||||||
|
else:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
return cast(types.ScalarType, scalar_type).by(nlanes)
|
||||||
|
else:
|
||||||
|
return scalar_type.by(nlanes)
|
||||||
|
|
||||||
|
|
||||||
class TypeVar(object):
|
class TypeVar(object):
|
||||||
"""
|
"""
|
||||||
@@ -271,7 +396,6 @@ class TypeVar(object):
|
|||||||
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa
|
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa
|
||||||
self.name = name
|
self.name = name
|
||||||
self.__doc__ = doc
|
self.__doc__ = doc
|
||||||
self.singleton_type = None # type: types.ValueType
|
|
||||||
self.is_derived = isinstance(base, TypeVar)
|
self.is_derived = isinstance(base, TypeVar)
|
||||||
if base:
|
if base:
|
||||||
assert self.is_derived
|
assert self.is_derived
|
||||||
@@ -313,7 +437,6 @@ class TypeVar(object):
|
|||||||
tv = TypeVar(
|
tv = TypeVar(
|
||||||
typ.name, typ.__doc__,
|
typ.name, typ.__doc__,
|
||||||
ints, floats, bools, simd=lanes)
|
ints, floats, bools, simd=lanes)
|
||||||
tv.singleton_type = typ
|
|
||||||
return tv
|
return tv
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -406,14 +529,13 @@ class TypeVar(object):
|
|||||||
Return a derived type variable that has the same number of vector lanes
|
Return a derived type variable that has the same number of vector lanes
|
||||||
as this one, but the lanes are half the width.
|
as this one, but the lanes are half the width.
|
||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
ts = self.get_typeset()
|
||||||
ts = self.type_set
|
if len(ts.ints) > 0:
|
||||||
if len(ts.ints) > 0:
|
assert min(ts.ints) > 8, "Can't halve all integer types"
|
||||||
assert min(ts.ints) > 8, "Can't halve all integer types"
|
if len(ts.floats) > 0:
|
||||||
if len(ts.floats) > 0:
|
assert min(ts.floats) > 32, "Can't halve all float types"
|
||||||
assert min(ts.floats) > 32, "Can't halve all float types"
|
if len(ts.bools) > 0:
|
||||||
if len(ts.bools) > 0:
|
assert min(ts.bools) > 8, "Can't halve all boolean types"
|
||||||
assert min(ts.bools) > 8, "Can't halve all boolean types"
|
|
||||||
|
|
||||||
return TypeVar.derived(self, self.HALFWIDTH)
|
return TypeVar.derived(self, self.HALFWIDTH)
|
||||||
|
|
||||||
@@ -423,16 +545,13 @@ class TypeVar(object):
|
|||||||
Return a derived type variable that has the same number of vector lanes
|
Return a derived type variable that has the same number of vector lanes
|
||||||
as this one, but the lanes are double the width.
|
as this one, but the lanes are double the width.
|
||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
ts = self.get_typeset()
|
||||||
ts = self.type_set
|
if len(ts.ints) > 0:
|
||||||
if len(ts.ints) > 0:
|
assert max(ts.ints) < MAX_BITS, "Can't double all integer types."
|
||||||
assert max(ts.ints) < MAX_BITS,\
|
if len(ts.floats) > 0:
|
||||||
"Can't double all integer types."
|
assert max(ts.floats) < MAX_BITS, "Can't double all float types."
|
||||||
if len(ts.floats) > 0:
|
if len(ts.bools) > 0:
|
||||||
assert max(ts.floats) < MAX_BITS,\
|
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
|
||||||
"Can't double all float types."
|
|
||||||
if len(ts.bools) > 0:
|
|
||||||
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
|
|
||||||
|
|
||||||
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
||||||
|
|
||||||
@@ -442,9 +561,8 @@ class TypeVar(object):
|
|||||||
Return a derived type variable that has half the number of vector lanes
|
Return a derived type variable that has half the number of vector lanes
|
||||||
as this one, with the same lane type.
|
as this one, with the same lane type.
|
||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
ts = self.get_typeset()
|
||||||
ts = self.type_set
|
assert min(ts.lanes) > 1, "Can't halve a scalar type"
|
||||||
assert min(ts.lanes) > 1, "Can't halve a scalar type"
|
|
||||||
|
|
||||||
return TypeVar.derived(self, self.HALFVECTOR)
|
return TypeVar.derived(self, self.HALFVECTOR)
|
||||||
|
|
||||||
@@ -454,12 +572,23 @@ class TypeVar(object):
|
|||||||
Return a derived type variable that has twice the number of vector
|
Return a derived type variable that has twice the number of vector
|
||||||
lanes as this one, with the same lane type.
|
lanes as this one, with the same lane type.
|
||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
ts = self.get_typeset()
|
||||||
ts = self.type_set
|
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
|
||||||
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
|
|
||||||
|
|
||||||
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
||||||
|
|
||||||
|
def singleton_type(self):
|
||||||
|
# type: () -> ValueType
|
||||||
|
"""
|
||||||
|
If the associated typeset has a single type return it. Otherwise return
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
ts = self.get_typeset()
|
||||||
|
if ts.size() != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ts.get_singleton()
|
||||||
|
|
||||||
def free_typevar(self):
|
def free_typevar(self):
|
||||||
# type: () -> TypeVar
|
# type: () -> TypeVar
|
||||||
"""
|
"""
|
||||||
@@ -467,7 +596,7 @@ class TypeVar(object):
|
|||||||
"""
|
"""
|
||||||
if self.is_derived:
|
if self.is_derived:
|
||||||
return self.base
|
return self.base
|
||||||
elif self.singleton_type:
|
elif self.singleton_type() is not None:
|
||||||
# A singleton type variable is not a proper free variable.
|
# A singleton type variable is not a proper free variable.
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
@@ -481,8 +610,8 @@ class TypeVar(object):
|
|||||||
if self.is_derived:
|
if self.is_derived:
|
||||||
return '{}.{}()'.format(
|
return '{}.{}()'.format(
|
||||||
self.base.rust_expr(), self.derived_func)
|
self.base.rust_expr(), self.derived_func)
|
||||||
elif self.singleton_type:
|
elif self.singleton_type():
|
||||||
return self.singleton_type.rust_name()
|
return self.singleton_type().rust_name()
|
||||||
else:
|
else:
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
@@ -501,9 +630,6 @@ class TypeVar(object):
|
|||||||
|
|
||||||
if not a.is_derived and not b.is_derived:
|
if not a.is_derived and not b.is_derived:
|
||||||
a.type_set &= b.type_set
|
a.type_set &= b.type_set
|
||||||
# TODO: What if a.type_set becomes empty?
|
|
||||||
if not a.singleton_type:
|
|
||||||
a.singleton_type = b.singleton_type
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: Implement constraints for derived type variables.
|
# TODO: Implement constraints for derived type variables.
|
||||||
@@ -514,3 +640,29 @@ class TypeVar(object):
|
|||||||
#
|
#
|
||||||
# For the fully general case, we would need to compute an image typeset
|
# For the fully general case, we would need to compute an image typeset
|
||||||
# for `b` and propagate a `a.derived_func` pre-image to `a.base`.
|
# for `b` and propagate a `a.derived_func` pre-image to `a.base`.
|
||||||
|
|
||||||
|
def get_typeset(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Returns the typeset for this TV. If the TV is derived, computes it
|
||||||
|
recursively from the derived function and the base's typeset.
|
||||||
|
"""
|
||||||
|
if not self.is_derived:
|
||||||
|
return self.type_set
|
||||||
|
else:
|
||||||
|
if (self.derived_func == TypeVar.SAMEAS):
|
||||||
|
return self.base.get_typeset()
|
||||||
|
elif (self.derived_func == TypeVar.LANEOF):
|
||||||
|
return self.base.get_typeset().lane_of()
|
||||||
|
elif (self.derived_func == TypeVar.ASBOOL):
|
||||||
|
return self.base.get_typeset().as_bool()
|
||||||
|
elif (self.derived_func == TypeVar.HALFWIDTH):
|
||||||
|
return self.base.get_typeset().half_width()
|
||||||
|
elif (self.derived_func == TypeVar.DOUBLEWIDTH):
|
||||||
|
return self.base.get_typeset().double_width()
|
||||||
|
elif (self.derived_func == TypeVar.HALFVECTOR):
|
||||||
|
return self.base.get_typeset().half_vector()
|
||||||
|
elif (self.derived_func == TypeVar.DOUBLEVECTOR):
|
||||||
|
return self.base.get_typeset().double_vector()
|
||||||
|
else:
|
||||||
|
assert False, "Unknown derived function: " + self.derived_func
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ class XForm(object):
|
|||||||
# Some variables have a fixed type which appears as a type variable
|
# Some variables have a fixed type which appears as a type variable
|
||||||
# with a singleton_type field set. That's allowed for temps too.
|
# with a singleton_type field set. That's allowed for temps too.
|
||||||
for v in fvars:
|
for v in fvars:
|
||||||
if v.is_temp() and not v.typevar.singleton_type:
|
if v.is_temp() and not v.typevar.singleton_type():
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"Cannot determine type of temp '{}' in xform:\n{}"
|
"Cannot determine type of temp '{}' in xform:\n{}"
|
||||||
.format(v, self))
|
.format(v, self))
|
||||||
|
|||||||
@@ -321,8 +321,8 @@ def get_constraint(op, ctrl_typevar, type_sets):
|
|||||||
tv = op.typevar
|
tv = op.typevar
|
||||||
|
|
||||||
# A concrete value type.
|
# A concrete value type.
|
||||||
if tv.singleton_type:
|
if tv.singleton_type():
|
||||||
return 'Concrete({})'.format(tv.singleton_type.rust_name())
|
return 'Concrete({})'.format(tv.singleton_type().rust_name())
|
||||||
|
|
||||||
if tv.free_typevar() is not ctrl_typevar:
|
if tv.free_typevar() is not ctrl_typevar:
|
||||||
assert not tv.is_derived
|
assert not tv.is_derived
|
||||||
|
|||||||
Reference in New Issue
Block a user