Fix the handling of special types in type variables.

- Allow the syntax "specials=True" to indicate that a type variable can
  assume all special types. Use this for the unconstrained type variable
  created in ast.py.
- Fix TypeSet.copy() to avoid deepcopy() which doesn't do the right
  thing for the self.specials set.
- Fix TypeSet.typeset_key() to just use the name of special types
  instead of the full SpecialType objects.
This commit is contained in:
Jakob Stoklund Olesen
2018-01-16 10:26:13 -08:00
parent 85aab278dd
commit ce4cc8ce12
2 changed files with 35 additions and 15 deletions

View File

@@ -247,7 +247,8 @@ class Var(Atom):
'typeof_{}'.format(self),
'Type of the pattern variable `{}`'.format(self),
ints=True, floats=True, bools=True,
scalars=True, simd=True, bitvecs=True)
scalars=True, simd=True, bitvecs=True,
specials=True)
self.original_typevar = tv
self.typevar = tv
return self.typevar

View File

@@ -7,16 +7,17 @@ polymorphic by using type variables.
from __future__ import absolute_import
import math
from . import types, is_power_of_two
from copy import deepcopy
from copy import copy
try:
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
if TYPE_CHECKING:
from srcgen import Formatter # noqa
from .types import ValueType # noqa
Interval = Tuple[int, int]
# An Interval where `True` means 'everything'
BoolInterval = Union[bool, Interval]
# Set of special types: None, False, True, or iterable.
SpecialSpec = Union[bool, Iterable[types.SpecialType]]
except ImportError:
pass
@@ -88,7 +89,7 @@ def decode_interval(intv, full_range, default=None):
An explicit interval
"""
if isinstance(intv, tuple):
# mypy buig here: 'builtins.None' object is not iterable
# mypy bug here: 'builtins.None' object is not iterable
lo, hi = intv
assert is_power_of_two(lo)
assert is_power_of_two(hi)
@@ -175,7 +176,7 @@ class TypeSet(object):
widths.
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
widths.
:param specials: Sequence of speical types to appear in the set.
:param specials: Sequence of special types to appear in the set.
"""
def __init__(
@@ -185,7 +186,7 @@ class TypeSet(object):
floats=None, # type: BoolInterval
bools=None, # type: BoolInterval
bitvecs=None, # type: BoolInterval
specials=None # type: Iterable[types.SpecialType]
specials=None # type: SpecialSpec
):
# type: (...) -> None
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
@@ -195,15 +196,27 @@ class TypeSet(object):
self.bools = set(filter(legal_bool, self.bools))
self.bitvecs = interval_to_set(decode_interval(bitvecs,
(1, MAX_BITVEC)))
self.specials = set(specials) if specials else set()
# Allow specials=None, specials=True, specials=(...)
self.specials = set() # type: Set[types.SpecialType]
if isinstance(specials, bool):
if specials:
self.specials = set(types.ValueType.all_special_types)
elif specials:
self.specials = set(specials)
def copy(self):
# type: (TypeSet) -> TypeSet
"""
Return a copy of our self. deepcopy is sufficient and safe here, since
TypeSet contains only sets of numbers.
Return a copy of our self.
"""
return deepcopy(self)
n = TypeSet()
n.lanes = copy(self.lanes)
n.ints = copy(self.ints)
n.floats = copy(self.floats)
n.bools = copy(self.bools)
n.bitvecs = copy(self.bitvecs)
n.specials = copy(self.specials)
return n
def typeset_key(self):
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple]
@@ -213,7 +226,7 @@ class TypeSet(object):
tuple(sorted(list(self.floats))),
tuple(sorted(list(self.bools))),
tuple(sorted(list(self.bitvecs))),
tuple(sorted(list(self.specials))))
tuple(sorted(s.name for s in self.specials)))
def __hash__(self):
# type: () -> int
@@ -301,7 +314,8 @@ class TypeSet(object):
self.ints.issubset(other.ints) and \
self.floats.issubset(other.floats) and \
self.bools.issubset(other.bools) and \
self.bitvecs.issubset(other.bitvecs)
self.bitvecs.issubset(other.bitvecs) and \
self.specials.issubset(other.specials)
def lane_of(self):
# type: () -> TypeSet
@@ -340,6 +354,7 @@ class TypeSet(object):
new.floats = set([x//2 for x in self.floats if x > 32])
new.bools = set([x//2 for x in self.bools if x > 8])
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
new.specials = set()
return new
@@ -354,6 +369,7 @@ class TypeSet(object):
new.bools = set(filter(legal_bool,
set([x*2 for x in self.bools if x < MAX_BITS])))
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
new.specials = set()
return new
@@ -365,6 +381,7 @@ class TypeSet(object):
new = self.copy()
new.bitvecs = set()
new.lanes = set([x//2 for x in self.lanes if x > 1])
new.specials = set()
return new
@@ -376,6 +393,7 @@ class TypeSet(object):
new = self.copy()
new.bitvecs = set()
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
new.specials = set()
return new
@@ -394,6 +412,7 @@ class TypeSet(object):
new.floats = set()
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
for nlanes in self.lanes])
new.specials = set()
return new
@@ -576,7 +595,7 @@ class TypeVar(object):
bitvecs=False, # type: BoolInterval
base=None, # type: TypeVar
derived_func=None, # type: str
specials=None # type: Iterable[types.SpecialType]
specials=None # type: SpecialSpec
):
# type: (...) -> None
self.name = name
@@ -603,7 +622,7 @@ class TypeVar(object):
def singleton(typ):
# type: (types.ValueType) -> TypeVar
"""Create a type variable that can only assume a single type."""
scalar = None # type: ValueType
scalar = None # type: types.ValueType
if isinstance(typ, types.VectorType):
scalar = typ.base
lanes = (typ.lanes, typ.lanes)
@@ -806,7 +825,7 @@ class TypeVar(object):
return TypeVar.derived(self, self.TOBITVEC)
def singleton_type(self):
# type: () -> ValueType
# type: () -> types.ValueType
"""
If the associated typeset has a single type return it. Otherwise return
None