Fix the handling of special types in type variables.
- Allow the syntax "specials=True" to indicate that a type variable can assume all special types. Use this for the unconstrained type variable created in ast.py. - Fix TypeSet.copy() to avoid deepcopy() which doesn't do the right thing for the self.specials set. - Fix TypeSet.typeset_key() to just use the name of special types instead of the full SpecialType objects.
This commit is contained in:
@@ -247,7 +247,8 @@ class Var(Atom):
|
|||||||
'typeof_{}'.format(self),
|
'typeof_{}'.format(self),
|
||||||
'Type of the pattern variable `{}`'.format(self),
|
'Type of the pattern variable `{}`'.format(self),
|
||||||
ints=True, floats=True, bools=True,
|
ints=True, floats=True, bools=True,
|
||||||
scalars=True, simd=True, bitvecs=True)
|
scalars=True, simd=True, bitvecs=True,
|
||||||
|
specials=True)
|
||||||
self.original_typevar = tv
|
self.original_typevar = tv
|
||||||
self.typevar = tv
|
self.typevar = tv
|
||||||
return self.typevar
|
return self.typevar
|
||||||
|
|||||||
@@ -7,16 +7,17 @@ polymorphic by using type variables.
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import math
|
import math
|
||||||
from . import types, is_power_of_two
|
from . import types, is_power_of_two
|
||||||
from copy import deepcopy
|
from copy import copy
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from srcgen import Formatter # noqa
|
from srcgen import Formatter # noqa
|
||||||
from .types import ValueType # noqa
|
|
||||||
Interval = Tuple[int, int]
|
Interval = Tuple[int, int]
|
||||||
# An Interval where `True` means 'everything'
|
# An Interval where `True` means 'everything'
|
||||||
BoolInterval = Union[bool, Interval]
|
BoolInterval = Union[bool, Interval]
|
||||||
|
# Set of special types: None, False, True, or iterable.
|
||||||
|
SpecialSpec = Union[bool, Iterable[types.SpecialType]]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -88,7 +89,7 @@ def decode_interval(intv, full_range, default=None):
|
|||||||
An explicit interval
|
An explicit interval
|
||||||
"""
|
"""
|
||||||
if isinstance(intv, tuple):
|
if isinstance(intv, tuple):
|
||||||
# mypy buig here: 'builtins.None' object is not iterable
|
# mypy bug here: 'builtins.None' object is not iterable
|
||||||
lo, hi = intv
|
lo, hi = intv
|
||||||
assert is_power_of_two(lo)
|
assert is_power_of_two(lo)
|
||||||
assert is_power_of_two(hi)
|
assert is_power_of_two(hi)
|
||||||
@@ -175,7 +176,7 @@ class TypeSet(object):
|
|||||||
widths.
|
widths.
|
||||||
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
|
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
|
||||||
widths.
|
widths.
|
||||||
:param specials: Sequence of speical types to appear in the set.
|
:param specials: Sequence of special types to appear in the set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -185,7 +186,7 @@ class TypeSet(object):
|
|||||||
floats=None, # type: BoolInterval
|
floats=None, # type: BoolInterval
|
||||||
bools=None, # type: BoolInterval
|
bools=None, # type: BoolInterval
|
||||||
bitvecs=None, # type: BoolInterval
|
bitvecs=None, # type: BoolInterval
|
||||||
specials=None # type: Iterable[types.SpecialType]
|
specials=None # type: SpecialSpec
|
||||||
):
|
):
|
||||||
# type: (...) -> None
|
# type: (...) -> None
|
||||||
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
||||||
@@ -195,15 +196,27 @@ class TypeSet(object):
|
|||||||
self.bools = set(filter(legal_bool, self.bools))
|
self.bools = set(filter(legal_bool, self.bools))
|
||||||
self.bitvecs = interval_to_set(decode_interval(bitvecs,
|
self.bitvecs = interval_to_set(decode_interval(bitvecs,
|
||||||
(1, MAX_BITVEC)))
|
(1, MAX_BITVEC)))
|
||||||
self.specials = set(specials) if specials else set()
|
# Allow specials=None, specials=True, specials=(...)
|
||||||
|
self.specials = set() # type: Set[types.SpecialType]
|
||||||
|
if isinstance(specials, bool):
|
||||||
|
if specials:
|
||||||
|
self.specials = set(types.ValueType.all_special_types)
|
||||||
|
elif specials:
|
||||||
|
self.specials = set(specials)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
# type: (TypeSet) -> TypeSet
|
# type: (TypeSet) -> TypeSet
|
||||||
"""
|
"""
|
||||||
Return a copy of our self. deepcopy is sufficient and safe here, since
|
Return a copy of our self.
|
||||||
TypeSet contains only sets of numbers.
|
|
||||||
"""
|
"""
|
||||||
return deepcopy(self)
|
n = TypeSet()
|
||||||
|
n.lanes = copy(self.lanes)
|
||||||
|
n.ints = copy(self.ints)
|
||||||
|
n.floats = copy(self.floats)
|
||||||
|
n.bools = copy(self.bools)
|
||||||
|
n.bitvecs = copy(self.bitvecs)
|
||||||
|
n.specials = copy(self.specials)
|
||||||
|
return n
|
||||||
|
|
||||||
def typeset_key(self):
|
def typeset_key(self):
|
||||||
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple]
|
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple]
|
||||||
@@ -213,7 +226,7 @@ class TypeSet(object):
|
|||||||
tuple(sorted(list(self.floats))),
|
tuple(sorted(list(self.floats))),
|
||||||
tuple(sorted(list(self.bools))),
|
tuple(sorted(list(self.bools))),
|
||||||
tuple(sorted(list(self.bitvecs))),
|
tuple(sorted(list(self.bitvecs))),
|
||||||
tuple(sorted(list(self.specials))))
|
tuple(sorted(s.name for s in self.specials)))
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
# type: () -> int
|
# type: () -> int
|
||||||
@@ -301,7 +314,8 @@ class TypeSet(object):
|
|||||||
self.ints.issubset(other.ints) and \
|
self.ints.issubset(other.ints) and \
|
||||||
self.floats.issubset(other.floats) and \
|
self.floats.issubset(other.floats) and \
|
||||||
self.bools.issubset(other.bools) and \
|
self.bools.issubset(other.bools) and \
|
||||||
self.bitvecs.issubset(other.bitvecs)
|
self.bitvecs.issubset(other.bitvecs) and \
|
||||||
|
self.specials.issubset(other.specials)
|
||||||
|
|
||||||
def lane_of(self):
|
def lane_of(self):
|
||||||
# type: () -> TypeSet
|
# type: () -> TypeSet
|
||||||
@@ -340,6 +354,7 @@ class TypeSet(object):
|
|||||||
new.floats = set([x//2 for x in self.floats if x > 32])
|
new.floats = set([x//2 for x in self.floats if x > 32])
|
||||||
new.bools = set([x//2 for x in self.bools if x > 8])
|
new.bools = set([x//2 for x in self.bools if x > 8])
|
||||||
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
|
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
|
||||||
|
new.specials = set()
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -354,6 +369,7 @@ class TypeSet(object):
|
|||||||
new.bools = set(filter(legal_bool,
|
new.bools = set(filter(legal_bool,
|
||||||
set([x*2 for x in self.bools if x < MAX_BITS])))
|
set([x*2 for x in self.bools if x < MAX_BITS])))
|
||||||
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
|
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
|
||||||
|
new.specials = set()
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -365,6 +381,7 @@ class TypeSet(object):
|
|||||||
new = self.copy()
|
new = self.copy()
|
||||||
new.bitvecs = set()
|
new.bitvecs = set()
|
||||||
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
||||||
|
new.specials = set()
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -376,6 +393,7 @@ class TypeSet(object):
|
|||||||
new = self.copy()
|
new = self.copy()
|
||||||
new.bitvecs = set()
|
new.bitvecs = set()
|
||||||
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
||||||
|
new.specials = set()
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -394,6 +412,7 @@ class TypeSet(object):
|
|||||||
new.floats = set()
|
new.floats = set()
|
||||||
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
|
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
|
||||||
for nlanes in self.lanes])
|
for nlanes in self.lanes])
|
||||||
|
new.specials = set()
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -576,7 +595,7 @@ class TypeVar(object):
|
|||||||
bitvecs=False, # type: BoolInterval
|
bitvecs=False, # type: BoolInterval
|
||||||
base=None, # type: TypeVar
|
base=None, # type: TypeVar
|
||||||
derived_func=None, # type: str
|
derived_func=None, # type: str
|
||||||
specials=None # type: Iterable[types.SpecialType]
|
specials=None # type: SpecialSpec
|
||||||
):
|
):
|
||||||
# type: (...) -> None
|
# type: (...) -> None
|
||||||
self.name = name
|
self.name = name
|
||||||
@@ -603,7 +622,7 @@ class TypeVar(object):
|
|||||||
def singleton(typ):
|
def singleton(typ):
|
||||||
# type: (types.ValueType) -> TypeVar
|
# type: (types.ValueType) -> TypeVar
|
||||||
"""Create a type variable that can only assume a single type."""
|
"""Create a type variable that can only assume a single type."""
|
||||||
scalar = None # type: ValueType
|
scalar = None # type: types.ValueType
|
||||||
if isinstance(typ, types.VectorType):
|
if isinstance(typ, types.VectorType):
|
||||||
scalar = typ.base
|
scalar = typ.base
|
||||||
lanes = (typ.lanes, typ.lanes)
|
lanes = (typ.lanes, typ.lanes)
|
||||||
@@ -806,7 +825,7 @@ class TypeVar(object):
|
|||||||
return TypeVar.derived(self, self.TOBITVEC)
|
return TypeVar.derived(self, self.TOBITVEC)
|
||||||
|
|
||||||
def singleton_type(self):
|
def singleton_type(self):
|
||||||
# type: () -> ValueType
|
# type: () -> types.ValueType
|
||||||
"""
|
"""
|
||||||
If the associated typeset has a single type return it. Otherwise return
|
If the associated typeset has a single type return it. Otherwise return
|
||||||
None
|
None
|
||||||
|
|||||||
Reference in New Issue
Block a user