Fix the handling of special types in type variables.
- Allow the syntax "specials=True" to indicate that a type variable can assume all special types. Use this for the unconstrained type variable created in ast.py. - Fix TypeSet.copy() to avoid deepcopy() which doesn't do the right thing for the self.specials set. - Fix TypeSet.typeset_key() to just use the name of special types instead of the full SpecialType objects.
This commit is contained in:
@@ -247,7 +247,8 @@ class Var(Atom):
|
||||
'typeof_{}'.format(self),
|
||||
'Type of the pattern variable `{}`'.format(self),
|
||||
ints=True, floats=True, bools=True,
|
||||
scalars=True, simd=True, bitvecs=True)
|
||||
scalars=True, simd=True, bitvecs=True,
|
||||
specials=True)
|
||||
self.original_typevar = tv
|
||||
self.typevar = tv
|
||||
return self.typevar
|
||||
|
||||
@@ -7,16 +7,17 @@ polymorphic by using type variables.
|
||||
from __future__ import absolute_import
|
||||
import math
|
||||
from . import types, is_power_of_two
|
||||
from copy import deepcopy
|
||||
from copy import copy
|
||||
|
||||
try:
|
||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from srcgen import Formatter # noqa
|
||||
from .types import ValueType # noqa
|
||||
Interval = Tuple[int, int]
|
||||
# An Interval where `True` means 'everything'
|
||||
BoolInterval = Union[bool, Interval]
|
||||
# Set of special types: None, False, True, or iterable.
|
||||
SpecialSpec = Union[bool, Iterable[types.SpecialType]]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -88,7 +89,7 @@ def decode_interval(intv, full_range, default=None):
|
||||
An explicit interval
|
||||
"""
|
||||
if isinstance(intv, tuple):
|
||||
# mypy buig here: 'builtins.None' object is not iterable
|
||||
# mypy bug here: 'builtins.None' object is not iterable
|
||||
lo, hi = intv
|
||||
assert is_power_of_two(lo)
|
||||
assert is_power_of_two(hi)
|
||||
@@ -175,7 +176,7 @@ class TypeSet(object):
|
||||
widths.
|
||||
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
|
||||
widths.
|
||||
:param specials: Sequence of speical types to appear in the set.
|
||||
:param specials: Sequence of special types to appear in the set.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -185,7 +186,7 @@ class TypeSet(object):
|
||||
floats=None, # type: BoolInterval
|
||||
bools=None, # type: BoolInterval
|
||||
bitvecs=None, # type: BoolInterval
|
||||
specials=None # type: Iterable[types.SpecialType]
|
||||
specials=None # type: SpecialSpec
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
||||
@@ -195,15 +196,27 @@ class TypeSet(object):
|
||||
self.bools = set(filter(legal_bool, self.bools))
|
||||
self.bitvecs = interval_to_set(decode_interval(bitvecs,
|
||||
(1, MAX_BITVEC)))
|
||||
self.specials = set(specials) if specials else set()
|
||||
# Allow specials=None, specials=True, specials=(...)
|
||||
self.specials = set() # type: Set[types.SpecialType]
|
||||
if isinstance(specials, bool):
|
||||
if specials:
|
||||
self.specials = set(types.ValueType.all_special_types)
|
||||
elif specials:
|
||||
self.specials = set(specials)
|
||||
|
||||
def copy(self):
|
||||
# type: (TypeSet) -> TypeSet
|
||||
"""
|
||||
Return a copy of our self. deepcopy is sufficient and safe here, since
|
||||
TypeSet contains only sets of numbers.
|
||||
Return a copy of our self.
|
||||
"""
|
||||
return deepcopy(self)
|
||||
n = TypeSet()
|
||||
n.lanes = copy(self.lanes)
|
||||
n.ints = copy(self.ints)
|
||||
n.floats = copy(self.floats)
|
||||
n.bools = copy(self.bools)
|
||||
n.bitvecs = copy(self.bitvecs)
|
||||
n.specials = copy(self.specials)
|
||||
return n
|
||||
|
||||
def typeset_key(self):
|
||||
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple]
|
||||
@@ -213,7 +226,7 @@ class TypeSet(object):
|
||||
tuple(sorted(list(self.floats))),
|
||||
tuple(sorted(list(self.bools))),
|
||||
tuple(sorted(list(self.bitvecs))),
|
||||
tuple(sorted(list(self.specials))))
|
||||
tuple(sorted(s.name for s in self.specials)))
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
@@ -301,7 +314,8 @@ class TypeSet(object):
|
||||
self.ints.issubset(other.ints) and \
|
||||
self.floats.issubset(other.floats) and \
|
||||
self.bools.issubset(other.bools) and \
|
||||
self.bitvecs.issubset(other.bitvecs)
|
||||
self.bitvecs.issubset(other.bitvecs) and \
|
||||
self.specials.issubset(other.specials)
|
||||
|
||||
def lane_of(self):
|
||||
# type: () -> TypeSet
|
||||
@@ -340,6 +354,7 @@ class TypeSet(object):
|
||||
new.floats = set([x//2 for x in self.floats if x > 32])
|
||||
new.bools = set([x//2 for x in self.bools if x > 8])
|
||||
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
|
||||
new.specials = set()
|
||||
|
||||
return new
|
||||
|
||||
@@ -354,6 +369,7 @@ class TypeSet(object):
|
||||
new.bools = set(filter(legal_bool,
|
||||
set([x*2 for x in self.bools if x < MAX_BITS])))
|
||||
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
|
||||
new.specials = set()
|
||||
|
||||
return new
|
||||
|
||||
@@ -365,6 +381,7 @@ class TypeSet(object):
|
||||
new = self.copy()
|
||||
new.bitvecs = set()
|
||||
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
||||
new.specials = set()
|
||||
|
||||
return new
|
||||
|
||||
@@ -376,6 +393,7 @@ class TypeSet(object):
|
||||
new = self.copy()
|
||||
new.bitvecs = set()
|
||||
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
||||
new.specials = set()
|
||||
|
||||
return new
|
||||
|
||||
@@ -394,6 +412,7 @@ class TypeSet(object):
|
||||
new.floats = set()
|
||||
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
|
||||
for nlanes in self.lanes])
|
||||
new.specials = set()
|
||||
|
||||
return new
|
||||
|
||||
@@ -576,7 +595,7 @@ class TypeVar(object):
|
||||
bitvecs=False, # type: BoolInterval
|
||||
base=None, # type: TypeVar
|
||||
derived_func=None, # type: str
|
||||
specials=None # type: Iterable[types.SpecialType]
|
||||
specials=None # type: SpecialSpec
|
||||
):
|
||||
# type: (...) -> None
|
||||
self.name = name
|
||||
@@ -603,7 +622,7 @@ class TypeVar(object):
|
||||
def singleton(typ):
|
||||
# type: (types.ValueType) -> TypeVar
|
||||
"""Create a type variable that can only assume a single type."""
|
||||
scalar = None # type: ValueType
|
||||
scalar = None # type: types.ValueType
|
||||
if isinstance(typ, types.VectorType):
|
||||
scalar = typ.base
|
||||
lanes = (typ.lanes, typ.lanes)
|
||||
@@ -806,7 +825,7 @@ class TypeVar(object):
|
||||
return TypeVar.derived(self, self.TOBITVEC)
|
||||
|
||||
def singleton_type(self):
|
||||
# type: () -> ValueType
|
||||
# type: () -> types.ValueType
|
||||
"""
|
||||
If the associated typeset has a single type return it. Otherwise return
|
||||
None
|
||||
|
||||
Reference in New Issue
Block a user