Fix the handling of special types in type variables.

- Allow the syntax "specials=True" to indicate that a type variable can
  assume all special types. Use this for the unconstrained type variable
  created in ast.py.
- Fix TypeSet.copy() to avoid deepcopy() which doesn't do the right
  thing for the self.specials set.
- Fix TypeSet.typeset_key() to just use the name of special types
  instead of the full SpecialType objects.
This commit is contained in:
Jakob Stoklund Olesen
2018-01-16 10:26:13 -08:00
parent 85aab278dd
commit ce4cc8ce12
2 changed files with 35 additions and 15 deletions

View File

@@ -247,7 +247,8 @@ class Var(Atom):
'typeof_{}'.format(self), 'typeof_{}'.format(self),
'Type of the pattern variable `{}`'.format(self), 'Type of the pattern variable `{}`'.format(self),
ints=True, floats=True, bools=True, ints=True, floats=True, bools=True,
scalars=True, simd=True, bitvecs=True) scalars=True, simd=True, bitvecs=True,
specials=True)
self.original_typevar = tv self.original_typevar = tv
self.typevar = tv self.typevar = tv
return self.typevar return self.typevar

View File

@@ -7,16 +7,17 @@ polymorphic by using type variables.
from __future__ import absolute_import from __future__ import absolute_import
import math import math
from . import types, is_power_of_two from . import types, is_power_of_two
from copy import deepcopy from copy import copy
try: try:
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
from srcgen import Formatter # noqa from srcgen import Formatter # noqa
from .types import ValueType # noqa
Interval = Tuple[int, int] Interval = Tuple[int, int]
# An Interval where `True` means 'everything' # An Interval where `True` means 'everything'
BoolInterval = Union[bool, Interval] BoolInterval = Union[bool, Interval]
# Set of special types: None, False, True, or iterable.
SpecialSpec = Union[bool, Iterable[types.SpecialType]]
except ImportError: except ImportError:
pass pass
@@ -88,7 +89,7 @@ def decode_interval(intv, full_range, default=None):
An explicit interval An explicit interval
""" """
if isinstance(intv, tuple): if isinstance(intv, tuple):
# mypy buig here: 'builtins.None' object is not iterable # mypy bug here: 'builtins.None' object is not iterable
lo, hi = intv lo, hi = intv
assert is_power_of_two(lo) assert is_power_of_two(lo)
assert is_power_of_two(hi) assert is_power_of_two(hi)
@@ -175,7 +176,7 @@ class TypeSet(object):
widths. widths.
:param bitvecs : `(min, max)` inclusive range of permitted bitvector :param bitvecs : `(min, max)` inclusive range of permitted bitvector
widths. widths.
:param specials: Sequence of speical types to appear in the set. :param specials: Sequence of special types to appear in the set.
""" """
def __init__( def __init__(
@@ -185,7 +186,7 @@ class TypeSet(object):
floats=None, # type: BoolInterval floats=None, # type: BoolInterval
bools=None, # type: BoolInterval bools=None, # type: BoolInterval
bitvecs=None, # type: BoolInterval bitvecs=None, # type: BoolInterval
specials=None # type: Iterable[types.SpecialType] specials=None # type: SpecialSpec
): ):
# type: (...) -> None # type: (...) -> None
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1)) self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
@@ -195,15 +196,27 @@ class TypeSet(object):
self.bools = set(filter(legal_bool, self.bools)) self.bools = set(filter(legal_bool, self.bools))
self.bitvecs = interval_to_set(decode_interval(bitvecs, self.bitvecs = interval_to_set(decode_interval(bitvecs,
(1, MAX_BITVEC))) (1, MAX_BITVEC)))
self.specials = set(specials) if specials else set() # Allow specials=None, specials=True, specials=(...)
self.specials = set() # type: Set[types.SpecialType]
if isinstance(specials, bool):
if specials:
self.specials = set(types.ValueType.all_special_types)
elif specials:
self.specials = set(specials)
def copy(self): def copy(self):
# type: (TypeSet) -> TypeSet # type: (TypeSet) -> TypeSet
""" """
Return a copy of our self. deepcopy is sufficient and safe here, since Return a copy of our self.
TypeSet contains only sets of numbers.
""" """
return deepcopy(self) n = TypeSet()
n.lanes = copy(self.lanes)
n.ints = copy(self.ints)
n.floats = copy(self.floats)
n.bools = copy(self.bools)
n.bitvecs = copy(self.bitvecs)
n.specials = copy(self.specials)
return n
def typeset_key(self): def typeset_key(self):
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple] # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple]
@@ -213,7 +226,7 @@ class TypeSet(object):
tuple(sorted(list(self.floats))), tuple(sorted(list(self.floats))),
tuple(sorted(list(self.bools))), tuple(sorted(list(self.bools))),
tuple(sorted(list(self.bitvecs))), tuple(sorted(list(self.bitvecs))),
tuple(sorted(list(self.specials)))) tuple(sorted(s.name for s in self.specials)))
def __hash__(self): def __hash__(self):
# type: () -> int # type: () -> int
@@ -301,7 +314,8 @@ class TypeSet(object):
self.ints.issubset(other.ints) and \ self.ints.issubset(other.ints) and \
self.floats.issubset(other.floats) and \ self.floats.issubset(other.floats) and \
self.bools.issubset(other.bools) and \ self.bools.issubset(other.bools) and \
self.bitvecs.issubset(other.bitvecs) self.bitvecs.issubset(other.bitvecs) and \
self.specials.issubset(other.specials)
def lane_of(self): def lane_of(self):
# type: () -> TypeSet # type: () -> TypeSet
@@ -340,6 +354,7 @@ class TypeSet(object):
new.floats = set([x//2 for x in self.floats if x > 32]) new.floats = set([x//2 for x in self.floats if x > 32])
new.bools = set([x//2 for x in self.bools if x > 8]) new.bools = set([x//2 for x in self.bools if x > 8])
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1]) new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
new.specials = set()
return new return new
@@ -354,6 +369,7 @@ class TypeSet(object):
new.bools = set(filter(legal_bool, new.bools = set(filter(legal_bool,
set([x*2 for x in self.bools if x < MAX_BITS]))) set([x*2 for x in self.bools if x < MAX_BITS])))
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC]) new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
new.specials = set()
return new return new
@@ -365,6 +381,7 @@ class TypeSet(object):
new = self.copy() new = self.copy()
new.bitvecs = set() new.bitvecs = set()
new.lanes = set([x//2 for x in self.lanes if x > 1]) new.lanes = set([x//2 for x in self.lanes if x > 1])
new.specials = set()
return new return new
@@ -376,6 +393,7 @@ class TypeSet(object):
new = self.copy() new = self.copy()
new.bitvecs = set() new.bitvecs = set()
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES]) new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
new.specials = set()
return new return new
@@ -394,6 +412,7 @@ class TypeSet(object):
new.floats = set() new.floats = set()
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
for nlanes in self.lanes]) for nlanes in self.lanes])
new.specials = set()
return new return new
@@ -576,7 +595,7 @@ class TypeVar(object):
bitvecs=False, # type: BoolInterval bitvecs=False, # type: BoolInterval
base=None, # type: TypeVar base=None, # type: TypeVar
derived_func=None, # type: str derived_func=None, # type: str
specials=None # type: Iterable[types.SpecialType] specials=None # type: SpecialSpec
): ):
# type: (...) -> None # type: (...) -> None
self.name = name self.name = name
@@ -603,7 +622,7 @@ class TypeVar(object):
def singleton(typ): def singleton(typ):
# type: (types.ValueType) -> TypeVar # type: (types.ValueType) -> TypeVar
"""Create a type variable that can only assume a single type.""" """Create a type variable that can only assume a single type."""
scalar = None # type: ValueType scalar = None # type: types.ValueType
if isinstance(typ, types.VectorType): if isinstance(typ, types.VectorType):
scalar = typ.base scalar = typ.base
lanes = (typ.lanes, typ.lanes) lanes = (typ.lanes, typ.lanes)
@@ -806,7 +825,7 @@ class TypeVar(object):
return TypeVar.derived(self, self.TOBITVEC) return TypeVar.derived(self, self.TOBITVEC)
def singleton_type(self): def singleton_type(self):
# type: () -> ValueType # type: () -> types.ValueType
""" """
If the associated typeset has a single type return it. Otherwise return If the associated typeset has a single type return it. Otherwise return
None None