Cleanup ValueType.get_names to with_bits form previous PR; Add computation of inverse image of typeset across a derived function - TypeSet.map_inverse; Change TypeVar.constrain_type to perform a more-general computation using inverse images of TypeSets; Tests for map_inverse;
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
7c298078c8
commit
c073d919f4
@@ -4,6 +4,8 @@ from doctest import DocTestSuite
|
|||||||
from . import typevar
|
from . import typevar
|
||||||
from .typevar import TypeSet, TypeVar
|
from .typevar import TypeSet, TypeVar
|
||||||
from base.types import i32, i16, b1, f64
|
from base.types import i32, i16, b1, f64
|
||||||
|
from itertools import product
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
|
||||||
def load_tests(loader, tests, ignore):
|
def load_tests(loader, tests, ignore):
|
||||||
@@ -123,6 +125,58 @@ class TestTypeSet(TestCase):
|
|||||||
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
|
self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(),
|
||||||
i32.by(4))
|
i32.by(4))
|
||||||
|
|
||||||
|
def test_map_inverse(self):
|
||||||
|
t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32))
|
||||||
|
self.assertEqual(t, t.map_inverse(TypeVar.SAMEAS))
|
||||||
|
|
||||||
|
# LANEOF
|
||||||
|
self.assertEqual(TypeSet(lanes=True, ints=(8, 8), floats=(32, 32)),
|
||||||
|
t.map_inverse(TypeVar.LANEOF))
|
||||||
|
# Inverse of empty set is still empty across LANEOF
|
||||||
|
self.assertEqual(TypeSet(),
|
||||||
|
TypeSet().map_inverse(TypeVar.LANEOF))
|
||||||
|
|
||||||
|
# ASBOOL
|
||||||
|
t = TypeSet(lanes=(1, 4), bools=(1, 64))
|
||||||
|
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
|
||||||
|
TypeSet(lanes=(1, 4), ints=True, bools=True,
|
||||||
|
floats=True))
|
||||||
|
|
||||||
|
# Inverse image across ASBOOL of TS not involving b1 cannot have
|
||||||
|
# lanes=1
|
||||||
|
t = TypeSet(lanes=(1, 4), bools=(16, 32))
|
||||||
|
self.assertEqual(t.map_inverse(TypeVar.ASBOOL),
|
||||||
|
TypeSet(lanes=(2, 4), ints=(16, 32), bools=(16, 32),
|
||||||
|
floats=(32, 32)))
|
||||||
|
|
||||||
|
# Half/Double Vector
|
||||||
|
t = TypeSet(lanes=(1, 1), ints=(8, 8))
|
||||||
|
t1 = TypeSet(lanes=(256, 256), ints=(8, 8))
|
||||||
|
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR).size(), 0)
|
||||||
|
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR).size(), 0)
|
||||||
|
|
||||||
|
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 32))
|
||||||
|
t1 = TypeSet(lanes=(64, 256), bools=(1, 32))
|
||||||
|
|
||||||
|
self.assertEqual(t.map_inverse(TypeVar.DOUBLEVECTOR),
|
||||||
|
TypeSet(lanes=(1, 8), ints=(8, 16), floats=(32, 32)))
|
||||||
|
self.assertEqual(t1.map_inverse(TypeVar.HALFVECTOR),
|
||||||
|
TypeSet(lanes=(128, 256), bools=(1, 32)))
|
||||||
|
|
||||||
|
# Half/Double Width
|
||||||
|
t = TypeSet(ints=(8, 8), floats=(32, 32), bools=(1, 8))
|
||||||
|
t1 = TypeSet(ints=(64, 64), floats=(64, 64), bools=(64, 64))
|
||||||
|
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH).size(), 0)
|
||||||
|
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH).size(), 0)
|
||||||
|
|
||||||
|
t = TypeSet(lanes=(1, 16), ints=(8, 16), floats=(32, 64))
|
||||||
|
t1 = TypeSet(lanes=(64, 256), bools=(1, 64))
|
||||||
|
|
||||||
|
self.assertEqual(t.map_inverse(TypeVar.DOUBLEWIDTH),
|
||||||
|
TypeSet(lanes=(1, 16), ints=(8, 8), floats=(32, 32)))
|
||||||
|
self.assertEqual(t1.map_inverse(TypeVar.HALFWIDTH),
|
||||||
|
TypeSet(lanes=(64, 256), bools=(16, 64)))
|
||||||
|
|
||||||
|
|
||||||
class TestTypeVar(TestCase):
|
class TestTypeVar(TestCase):
|
||||||
def test_functions(self):
|
def test_functions(self):
|
||||||
@@ -164,3 +218,57 @@ class TestTypeVar(TestCase):
|
|||||||
self.assertEqual(max(x.type_set.lanes), 4)
|
self.assertEqual(max(x.type_set.lanes), 4)
|
||||||
self.assertEqual(len(x.type_set.floats), 0)
|
self.assertEqual(len(x.type_set.floats), 0)
|
||||||
self.assertEqual(len(x.type_set.bools), 0)
|
self.assertEqual(len(x.type_set.bools), 0)
|
||||||
|
|
||||||
|
def test_stress_constrain_types(self):
|
||||||
|
# Get all 49 possible derived vars of lentgh 2. Since we have SAMEAS
|
||||||
|
# this includes singly derived and non-derived vars
|
||||||
|
funcs = [TypeVar.SAMEAS, TypeVar.LANEOF,
|
||||||
|
TypeVar.ASBOOL, TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR,
|
||||||
|
TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
|
||||||
|
v = list(product(*[funcs, funcs]))
|
||||||
|
|
||||||
|
# For each pair of derived variables
|
||||||
|
for (i1, i2) in product(v, v):
|
||||||
|
# Compute the derived sets for each starting with a full typeset
|
||||||
|
full_ts = TypeSet(lanes=True, floats=True, ints=True, bools=True)
|
||||||
|
ts1 = reduce(lambda ts, func: ts.map(func), i1, full_ts)
|
||||||
|
ts2 = reduce(lambda ts, func: ts.map(func), i2, full_ts)
|
||||||
|
|
||||||
|
# Compute intersection
|
||||||
|
intersect = ts1.copy()
|
||||||
|
intersect &= ts2
|
||||||
|
|
||||||
|
# Propagate instersections backward
|
||||||
|
ts1_src = reduce(lambda ts, func: ts.map_inverse(func),
|
||||||
|
reversed(i1),
|
||||||
|
intersect)
|
||||||
|
ts2_src = reduce(lambda ts, func: ts.map_inverse(func),
|
||||||
|
reversed(i2),
|
||||||
|
intersect)
|
||||||
|
|
||||||
|
# If the intersection or its propagated forms are empty, then these
|
||||||
|
# two variables can never overlap. For example x.double_vector and
|
||||||
|
# x.lane_of.
|
||||||
|
if (intersect.size() == 0 or ts1_src.size() == 0 or
|
||||||
|
ts2_src.size() == 0):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Should be safe to create derived tvs from ts1_src and ts2_src
|
||||||
|
tv1 = reduce(lambda tv, func: TypeVar.derived(tv, func),
|
||||||
|
i1,
|
||||||
|
TypeVar.from_typeset(ts1_src))
|
||||||
|
|
||||||
|
tv2 = reduce(lambda tv, func: TypeVar.derived(tv, func),
|
||||||
|
i2,
|
||||||
|
TypeVar.from_typeset(ts2_src))
|
||||||
|
|
||||||
|
# The typesets of the two derived variables should be subsets of
|
||||||
|
# the intersection we computed originally
|
||||||
|
assert tv1.get_typeset().issubset(intersect)
|
||||||
|
assert tv2.get_typeset().issubset(intersect)
|
||||||
|
|
||||||
|
# In the absence of AS_BOOL map(map_inverse(f)) == f so the
|
||||||
|
# typesets of tv1 and tv2 should be exactly intersection
|
||||||
|
assert (tv1.get_typeset() == tv2.get_typeset() and
|
||||||
|
tv1.get_typeset() == intersect) or\
|
||||||
|
TypeVar.ASBOOL in set(i1 + i2)
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from __future__ import absolute_import
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Dict, List # noqa
|
from typing import Dict, List, cast, TYPE_CHECKING # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
TYPE_CHECKING = False
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -97,7 +98,7 @@ class VectorType(ValueType):
|
|||||||
# type: (ScalarType, int) -> None
|
# type: (ScalarType, int) -> None
|
||||||
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
|
assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types'
|
||||||
super(VectorType, self).__init__(
|
super(VectorType, self).__init__(
|
||||||
name=VectorType.get_name(base, lanes),
|
name='{}x{}'.format(base.name, lanes),
|
||||||
membytes=lanes*base.membytes,
|
membytes=lanes*base.membytes,
|
||||||
doc="""
|
doc="""
|
||||||
A SIMD vector with {} lanes containing a `{}` each.
|
A SIMD vector with {} lanes containing a `{}` each.
|
||||||
@@ -111,11 +112,6 @@ class VectorType(ValueType):
|
|||||||
return ('VectorType(base={}, lanes={})'
|
return ('VectorType(base={}, lanes={})'
|
||||||
.format(self.base.name, self.lanes))
|
.format(self.base.name, self.lanes))
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_name(base, lanes):
|
|
||||||
# type: (ValueType, int) -> str
|
|
||||||
return '{}x{}'.format(base.name, lanes)
|
|
||||||
|
|
||||||
|
|
||||||
class IntType(ScalarType):
|
class IntType(ScalarType):
|
||||||
"""A concrete scalar integer type."""
|
"""A concrete scalar integer type."""
|
||||||
@@ -124,7 +120,7 @@ class IntType(ScalarType):
|
|||||||
# type: (int) -> None
|
# type: (int) -> None
|
||||||
assert bits > 0, 'IntType must have positive number of bits'
|
assert bits > 0, 'IntType must have positive number of bits'
|
||||||
super(IntType, self).__init__(
|
super(IntType, self).__init__(
|
||||||
name=IntType.get_name(bits),
|
name='i{:d}'.format(bits),
|
||||||
membytes=bits // 8,
|
membytes=bits // 8,
|
||||||
doc="An integer type with {} bits.".format(bits))
|
doc="An integer type with {} bits.".format(bits))
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
@@ -134,9 +130,13 @@ class IntType(ScalarType):
|
|||||||
return 'IntType(bits={})'.format(self.bits)
|
return 'IntType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name(bits):
|
def with_bits(bits):
|
||||||
# type: (int) -> str
|
# type: (int) -> IntType
|
||||||
return 'i{:d}'.format(bits)
|
typ = ValueType.by_name('i{:d}'.format(bits))
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
return cast(IntType, typ)
|
||||||
|
else:
|
||||||
|
return typ
|
||||||
|
|
||||||
|
|
||||||
class FloatType(ScalarType):
|
class FloatType(ScalarType):
|
||||||
@@ -146,7 +146,7 @@ class FloatType(ScalarType):
|
|||||||
# type: (int, str) -> None
|
# type: (int, str) -> None
|
||||||
assert bits > 0, 'FloatType must have positive number of bits'
|
assert bits > 0, 'FloatType must have positive number of bits'
|
||||||
super(FloatType, self).__init__(
|
super(FloatType, self).__init__(
|
||||||
name=FloatType.get_name(bits),
|
name='f{:d}'.format(bits),
|
||||||
membytes=bits // 8,
|
membytes=bits // 8,
|
||||||
doc=doc)
|
doc=doc)
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
@@ -156,9 +156,13 @@ class FloatType(ScalarType):
|
|||||||
return 'FloatType(bits={})'.format(self.bits)
|
return 'FloatType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name(bits):
|
def with_bits(bits):
|
||||||
# type: (int) -> str
|
# type: (int) -> FloatType
|
||||||
return 'f{:d}'.format(bits)
|
typ = ValueType.by_name('f{:d}'.format(bits))
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
return cast(FloatType, typ)
|
||||||
|
else:
|
||||||
|
return typ
|
||||||
|
|
||||||
|
|
||||||
class BoolType(ScalarType):
|
class BoolType(ScalarType):
|
||||||
@@ -168,7 +172,7 @@ class BoolType(ScalarType):
|
|||||||
# type: (int) -> None
|
# type: (int) -> None
|
||||||
assert bits > 0, 'BoolType must have positive number of bits'
|
assert bits > 0, 'BoolType must have positive number of bits'
|
||||||
super(BoolType, self).__init__(
|
super(BoolType, self).__init__(
|
||||||
name=BoolType.get_name(bits),
|
name='b{:d}'.format(bits),
|
||||||
membytes=bits // 8,
|
membytes=bits // 8,
|
||||||
doc="A boolean type with {} bits.".format(bits))
|
doc="A boolean type with {} bits.".format(bits))
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
@@ -178,6 +182,10 @@ class BoolType(ScalarType):
|
|||||||
return 'BoolType(bits={})'.format(self.bits)
|
return 'BoolType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name(bits):
|
def with_bits(bits):
|
||||||
# type: (int) -> str
|
# type: (int) -> BoolType
|
||||||
return 'b{:d}'.format(bits)
|
typ = ValueType.by_name('b{:d}'.format(bits))
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
return cast(BoolType, typ)
|
||||||
|
else:
|
||||||
|
return typ
|
||||||
|
|||||||
@@ -8,18 +8,17 @@ from __future__ import absolute_import
|
|||||||
import math
|
import math
|
||||||
from . import types, is_power_of_two
|
from . import types, is_power_of_two
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from .types import ValueType, IntType, FloatType, BoolType
|
from .types import IntType, FloatType, BoolType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||||
from typing import cast
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from srcgen import Formatter # noqa
|
from srcgen import Formatter # noqa
|
||||||
|
from .types import ValueType # noqa
|
||||||
Interval = Tuple[int, int]
|
Interval = Tuple[int, int]
|
||||||
# An Interval where `True` means 'everything'
|
# An Interval where `True` means 'everything'
|
||||||
BoolInterval = Union[bool, Interval]
|
BoolInterval = Union[bool, Interval]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
TYPE_CHECKING = False
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
MAX_LANES = 256
|
MAX_LANES = 256
|
||||||
@@ -263,6 +262,16 @@ class TypeSet(object):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def issubset(self, other):
|
||||||
|
# type: (TypeSet) -> bool
|
||||||
|
"""
|
||||||
|
Return true iff self is a subset of other
|
||||||
|
"""
|
||||||
|
return self.lanes.issubset(other.lanes) and \
|
||||||
|
self.ints.issubset(other.ints) and \
|
||||||
|
self.floats.issubset(other.floats) and \
|
||||||
|
self.bools.issubset(other.bools)
|
||||||
|
|
||||||
def lane_of(self):
|
def lane_of(self):
|
||||||
# type: () -> TypeSet
|
# type: () -> TypeSet
|
||||||
"""
|
"""
|
||||||
@@ -292,9 +301,9 @@ class TypeSet(object):
|
|||||||
Return a TypeSet describing the image of self across halfwidth
|
Return a TypeSet describing the image of self across halfwidth
|
||||||
"""
|
"""
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
new.ints = set([x/2 for x in self.ints if x > 8])
|
new.ints = set([x//2 for x in self.ints if x > 8])
|
||||||
new.floats = set([x/2 for x in self.floats if x > 32])
|
new.floats = set([x//2 for x in self.floats if x > 32])
|
||||||
new.bools = set([x/2 for x in self.bools if x > 8])
|
new.bools = set([x//2 for x in self.bools if x > 8])
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -317,7 +326,7 @@ class TypeSet(object):
|
|||||||
Return a TypeSet describing the image of self across halfvector
|
Return a TypeSet describing the image of self across halfvector
|
||||||
"""
|
"""
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
new.lanes = set([x/2 for x in self.lanes if x > 1])
|
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -331,6 +340,67 @@ class TypeSet(object):
|
|||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
def map(self, func):
|
||||||
|
# type: (str) -> TypeSet
|
||||||
|
"""
|
||||||
|
Return the image of self across the derived function func
|
||||||
|
"""
|
||||||
|
if (func == TypeVar.SAMEAS):
|
||||||
|
return self
|
||||||
|
elif (func == TypeVar.LANEOF):
|
||||||
|
return self.lane_of()
|
||||||
|
elif (func == TypeVar.ASBOOL):
|
||||||
|
return self.as_bool()
|
||||||
|
elif (func == TypeVar.HALFWIDTH):
|
||||||
|
return self.half_width()
|
||||||
|
elif (func == TypeVar.DOUBLEWIDTH):
|
||||||
|
return self.double_width()
|
||||||
|
elif (func == TypeVar.HALFVECTOR):
|
||||||
|
return self.half_vector()
|
||||||
|
elif (func == TypeVar.DOUBLEVECTOR):
|
||||||
|
return self.double_vector()
|
||||||
|
else:
|
||||||
|
assert False, "Unknown derived function: " + func
|
||||||
|
|
||||||
|
def map_inverse(self, func):
|
||||||
|
# type: (str) -> TypeSet
|
||||||
|
"""
|
||||||
|
Return the inverse image of self across the derived function func
|
||||||
|
"""
|
||||||
|
# The inverse of the empty set is always empty
|
||||||
|
if (self.size() == 0):
|
||||||
|
return self
|
||||||
|
|
||||||
|
if (func == TypeVar.SAMEAS):
|
||||||
|
return self
|
||||||
|
elif (func == TypeVar.LANEOF):
|
||||||
|
new = self.copy()
|
||||||
|
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
|
||||||
|
return new
|
||||||
|
elif (func == TypeVar.ASBOOL):
|
||||||
|
new = self.copy()
|
||||||
|
new.ints = self.bools.difference(set([1]))
|
||||||
|
new.floats = self.bools.intersection(set([32, 64]))
|
||||||
|
|
||||||
|
if 1 not in self.bools:
|
||||||
|
try:
|
||||||
|
# If the range doesn't have b1, then the domain can't
|
||||||
|
# include scalars, as as_bool(scalar)=b1
|
||||||
|
new.lanes.remove(1)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return new
|
||||||
|
elif (func == TypeVar.HALFWIDTH):
|
||||||
|
return self.double_width()
|
||||||
|
elif (func == TypeVar.DOUBLEWIDTH):
|
||||||
|
return self.half_width()
|
||||||
|
elif (func == TypeVar.HALFVECTOR):
|
||||||
|
return self.double_vector()
|
||||||
|
elif (func == TypeVar.DOUBLEVECTOR):
|
||||||
|
return self.half_vector()
|
||||||
|
else:
|
||||||
|
assert False, "Unknown derived function: " + func
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
# type: () -> int
|
# type: () -> int
|
||||||
"""
|
"""
|
||||||
@@ -346,23 +416,18 @@ class TypeSet(object):
|
|||||||
typesets containing 1 type.
|
typesets containing 1 type.
|
||||||
"""
|
"""
|
||||||
assert self.size() == 1
|
assert self.size() == 1
|
||||||
|
scalar_type = None # type: types.ScalarType
|
||||||
if len(self.ints) > 0:
|
if len(self.ints) > 0:
|
||||||
bits = tuple(self.ints)[0]
|
scalar_type = IntType.with_bits(tuple(self.ints)[0])
|
||||||
scalar_type = ValueType.by_name(IntType.get_name(bits))
|
|
||||||
elif len(self.floats) > 0:
|
elif len(self.floats) > 0:
|
||||||
bits = tuple(self.floats)[0]
|
scalar_type = FloatType.with_bits(tuple(self.floats)[0])
|
||||||
scalar_type = ValueType.by_name(FloatType.get_name(bits))
|
|
||||||
else:
|
else:
|
||||||
bits = tuple(self.bools)[0]
|
scalar_type = BoolType.with_bits(tuple(self.bools)[0])
|
||||||
scalar_type = ValueType.by_name(BoolType.get_name(bits))
|
|
||||||
|
|
||||||
nlanes = tuple(self.lanes)[0]
|
nlanes = tuple(self.lanes)[0]
|
||||||
|
|
||||||
if nlanes == 1:
|
if nlanes == 1:
|
||||||
return scalar_type
|
return scalar_type
|
||||||
else:
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
return cast(types.ScalarType, scalar_type).by(nlanes)
|
|
||||||
else:
|
else:
|
||||||
return scalar_type.by(nlanes)
|
return scalar_type.by(nlanes)
|
||||||
|
|
||||||
@@ -483,6 +548,14 @@ class TypeVar(object):
|
|||||||
"""Create a type variable that is a function of another."""
|
"""Create a type variable that is a function of another."""
|
||||||
return TypeVar(None, None, base=base, derived_func=derived_func)
|
return TypeVar(None, None, base=base, derived_func=derived_func)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_typeset(ts):
|
||||||
|
# type: (TypeSet) -> TypeVar
|
||||||
|
""" Create a type variable from a type set."""
|
||||||
|
tv = TypeVar(None, None)
|
||||||
|
tv.type_set = ts
|
||||||
|
return tv
|
||||||
|
|
||||||
def change_to_derived(self, base, derived_func):
|
def change_to_derived(self, base, derived_func):
|
||||||
# type: (TypeVar, str) -> None
|
# type: (TypeVar, str) -> None
|
||||||
"""Change this type variable into a derived one."""
|
"""Change this type variable into a derived one."""
|
||||||
@@ -615,6 +688,17 @@ class TypeVar(object):
|
|||||||
else:
|
else:
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
def constrain_types_by_ts(self, ts):
|
||||||
|
# type: (TypeSet) -> None
|
||||||
|
"""
|
||||||
|
Constrain the range of types this variable can assume to a subset of
|
||||||
|
those in the typeset ts.
|
||||||
|
"""
|
||||||
|
if not self.is_derived:
|
||||||
|
self.type_set &= ts
|
||||||
|
else:
|
||||||
|
self.base.constrain_types_by_ts(ts.map_inverse(self.derived_func))
|
||||||
|
|
||||||
def constrain_types(self, other):
|
def constrain_types(self, other):
|
||||||
# type: (TypeVar) -> None
|
# type: (TypeVar) -> None
|
||||||
"""
|
"""
|
||||||
@@ -628,18 +712,7 @@ class TypeVar(object):
|
|||||||
if a is b:
|
if a is b:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not a.is_derived and not b.is_derived:
|
a.constrain_types_by_ts(b.get_typeset())
|
||||||
a.type_set &= b.type_set
|
|
||||||
return
|
|
||||||
|
|
||||||
# TODO: Implement constraints for derived type variables.
|
|
||||||
#
|
|
||||||
# If a and b are both derived with the same derived_func, we could say
|
|
||||||
# `a.base.constrain_types(b.base)`, but unless the derived_func is
|
|
||||||
# injective, that may constrain `a.base` more than necessary.
|
|
||||||
#
|
|
||||||
# For the fully general case, we would need to compute an image typeset
|
|
||||||
# for `b` and propagate a `a.derived_func` pre-image to `a.base`.
|
|
||||||
|
|
||||||
def get_typeset(self):
|
def get_typeset(self):
|
||||||
# type: () -> TypeSet
|
# type: () -> TypeSet
|
||||||
|
|||||||
Reference in New Issue
Block a user