diff --git a/lib/cretonne/meta/cdsl/ast.py b/lib/cretonne/meta/cdsl/ast.py index d6c0f42952..38d633e17d 100644 --- a/lib/cretonne/meta/cdsl/ast.py +++ b/lib/cretonne/meta/cdsl/ast.py @@ -247,7 +247,8 @@ class Var(Atom): 'typeof_{}'.format(self), 'Type of the pattern variable `{}`'.format(self), ints=True, floats=True, bools=True, - scalars=True, simd=True, bitvecs=True) + scalars=True, simd=True, bitvecs=True, + specials=True) self.original_typevar = tv self.typevar = tv return self.typevar diff --git a/lib/cretonne/meta/cdsl/typevar.py b/lib/cretonne/meta/cdsl/typevar.py index f818719a4d..8cda58785e 100644 --- a/lib/cretonne/meta/cdsl/typevar.py +++ b/lib/cretonne/meta/cdsl/typevar.py @@ -7,16 +7,17 @@ polymorphic by using type variables. from __future__ import absolute_import import math from . import types, is_power_of_two -from copy import deepcopy +from copy import copy try: from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa if TYPE_CHECKING: from srcgen import Formatter # noqa - from .types import ValueType # noqa Interval = Tuple[int, int] # An Interval where `True` means 'everything' BoolInterval = Union[bool, Interval] + # Set of special types: None, False, True, or iterable. + SpecialSpec = Union[bool, Iterable[types.SpecialType]] except ImportError: pass @@ -88,7 +89,7 @@ def decode_interval(intv, full_range, default=None): An explicit interval """ if isinstance(intv, tuple): - # mypy buig here: 'builtins.None' object is not iterable + # mypy bug here: 'builtins.None' object is not iterable lo, hi = intv assert is_power_of_two(lo) assert is_power_of_two(hi) @@ -175,7 +176,7 @@ class TypeSet(object): widths. :param bitvecs : `(min, max)` inclusive range of permitted bitvector widths. - :param specials: Sequence of speical types to appear in the set. + :param specials: Sequence of special types to appear in the set. """ def __init__( @@ -185,7 +186,7 @@ class TypeSet(object): floats=None, # type: BoolInterval bools=None, # type: BoolInterval bitvecs=None, # type: BoolInterval - specials=None # type: Iterable[types.SpecialType] + specials=None # type: SpecialSpec ): # type: (...) -> None self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1)) @@ -195,15 +196,27 @@ class TypeSet(object): self.bools = set(filter(legal_bool, self.bools)) self.bitvecs = interval_to_set(decode_interval(bitvecs, (1, MAX_BITVEC))) - self.specials = set(specials) if specials else set() + # Allow specials=None, specials=True, specials=(...) + self.specials = set() # type: Set[types.SpecialType] + if isinstance(specials, bool): + if specials: + self.specials = set(types.ValueType.all_special_types) + elif specials: + self.specials = set(specials) def copy(self): # type: (TypeSet) -> TypeSet """ - Return a copy of our self. deepcopy is sufficient and safe here, since - TypeSet contains only sets of numbers. + Return a copy of our self. """ - return deepcopy(self) + n = TypeSet() + n.lanes = copy(self.lanes) + n.ints = copy(self.ints) + n.floats = copy(self.floats) + n.bools = copy(self.bools) + n.bitvecs = copy(self.bitvecs) + n.specials = copy(self.specials) + return n def typeset_key(self): # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple, Tuple] @@ -213,7 +226,7 @@ class TypeSet(object): tuple(sorted(list(self.floats))), tuple(sorted(list(self.bools))), tuple(sorted(list(self.bitvecs))), - tuple(sorted(list(self.specials)))) + tuple(sorted(s.name for s in self.specials))) def __hash__(self): # type: () -> int @@ -301,7 +314,8 @@ class TypeSet(object): self.ints.issubset(other.ints) and \ self.floats.issubset(other.floats) and \ self.bools.issubset(other.bools) and \ - self.bitvecs.issubset(other.bitvecs) + self.bitvecs.issubset(other.bitvecs) and \ + self.specials.issubset(other.specials) def lane_of(self): # type: () -> TypeSet @@ -340,6 +354,7 @@ class TypeSet(object): new.floats = set([x//2 for x in self.floats if x > 32]) new.bools = set([x//2 for x in self.bools if x > 8]) new.bitvecs = set([x//2 for x in self.bitvecs if x > 1]) + new.specials = set() return new @@ -354,6 +369,7 @@ class TypeSet(object): new.bools = set(filter(legal_bool, set([x*2 for x in self.bools if x < MAX_BITS]))) new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC]) + new.specials = set() return new @@ -365,6 +381,7 @@ class TypeSet(object): new = self.copy() new.bitvecs = set() new.lanes = set([x//2 for x in self.lanes if x > 1]) + new.specials = set() return new @@ -376,6 +393,7 @@ class TypeSet(object): new = self.copy() new.bitvecs = set() new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES]) + new.specials = set() return new @@ -394,6 +412,7 @@ class TypeSet(object): new.floats = set() new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars for nlanes in self.lanes]) + new.specials = set() return new @@ -576,7 +595,7 @@ class TypeVar(object): bitvecs=False, # type: BoolInterval base=None, # type: TypeVar derived_func=None, # type: str - specials=None # type: Iterable[types.SpecialType] + specials=None # type: SpecialSpec ): # type: (...) -> None self.name = name @@ -603,7 +622,7 @@ class TypeVar(object): def singleton(typ): # type: (types.ValueType) -> TypeVar """Create a type variable that can only assume a single type.""" - scalar = None # type: ValueType + scalar = None # type: types.ValueType if isinstance(typ, types.VectorType): scalar = typ.base lanes = (typ.lanes, typ.lanes) @@ -806,7 +825,7 @@ class TypeVar(object): return TypeVar.derived(self, self.TOBITVEC) def singleton_type(self): - # type: () -> ValueType + # type: () -> types.ValueType """ If the associated typeset has a single type return it. Otherwise return None