Add type annotations to TypeVar
This commit is contained in:
@@ -6,11 +6,13 @@ polymorphic by using type variables.
|
|||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import math
|
import math
|
||||||
from . import value
|
from . import OperandKind, value # noqa
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Tuple # noqa
|
from typing import Tuple, Union # noqa
|
||||||
Interval = Tuple[int, int]
|
Interval = Tuple[int, int]
|
||||||
|
# An Interval where `True` means 'everything'
|
||||||
|
BoolInterval = Union[bool, Interval]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -47,6 +49,34 @@ def intersect(a, b):
|
|||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_interval(intv, full_range, default=None):
|
||||||
|
# type: (BoolInterval, Interval, int) -> Interval
|
||||||
|
"""
|
||||||
|
Decode an interval specification which can take the following values:
|
||||||
|
|
||||||
|
True
|
||||||
|
Use the `full_range`.
|
||||||
|
`False` or `None`
|
||||||
|
An empty interval
|
||||||
|
(lo, hi)
|
||||||
|
An explicit interval
|
||||||
|
"""
|
||||||
|
if isinstance(intv, tuple):
|
||||||
|
# mypy buig here: 'builtins.None' object is not iterable
|
||||||
|
lo, hi = intv # type: ignore
|
||||||
|
assert is_power_of_two(lo)
|
||||||
|
assert is_power_of_two(hi)
|
||||||
|
assert lo <= hi
|
||||||
|
assert lo >= full_range[0]
|
||||||
|
assert hi <= full_range[1]
|
||||||
|
return intv
|
||||||
|
|
||||||
|
if intv:
|
||||||
|
return full_range
|
||||||
|
else:
|
||||||
|
return (default, default)
|
||||||
|
|
||||||
|
|
||||||
class TypeSet(object):
|
class TypeSet(object):
|
||||||
"""
|
"""
|
||||||
A set of types.
|
A set of types.
|
||||||
@@ -95,55 +125,12 @@ class TypeSet(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
|
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
|
||||||
# type: (Interval, Interval, Interval, Interval) -> None
|
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
||||||
if lanes:
|
self.min_lanes, self.max_lanes = decode_interval(
|
||||||
if lanes is True:
|
lanes, (1, MAX_LANES), 1)
|
||||||
lanes = (1, MAX_LANES)
|
self.min_int, self.max_int = decode_interval(ints, (8, MAX_BITS))
|
||||||
self.min_lanes, self.max_lanes = lanes
|
self.min_float, self.max_float = decode_interval(floats, (32, 64))
|
||||||
assert is_power_of_two(self.min_lanes)
|
self.min_bool, self.max_bool = decode_interval(bools, (1, MAX_BITS))
|
||||||
assert is_power_of_two(self.max_lanes)
|
|
||||||
assert self.max_lanes <= MAX_LANES
|
|
||||||
else:
|
|
||||||
self.min_lanes = 1
|
|
||||||
self.max_lanes = 1
|
|
||||||
assert self.min_lanes <= self.max_lanes
|
|
||||||
|
|
||||||
if ints:
|
|
||||||
if ints is True:
|
|
||||||
ints = (8, MAX_BITS)
|
|
||||||
self.min_int, self.max_int = ints
|
|
||||||
assert is_power_of_two(self.min_int)
|
|
||||||
assert is_power_of_two(self.max_int)
|
|
||||||
assert self.max_int <= MAX_BITS
|
|
||||||
assert self.min_int <= self.max_int
|
|
||||||
else:
|
|
||||||
self.min_int = None
|
|
||||||
self.max_int = None
|
|
||||||
|
|
||||||
if floats:
|
|
||||||
if floats is True:
|
|
||||||
floats = (32, 64)
|
|
||||||
self.min_float, self.max_float = floats
|
|
||||||
assert is_power_of_two(self.min_float)
|
|
||||||
assert self.min_float >= 32
|
|
||||||
assert is_power_of_two(self.max_float)
|
|
||||||
assert self.max_float <= 64
|
|
||||||
assert self.min_float <= self.max_float
|
|
||||||
else:
|
|
||||||
self.min_float = None
|
|
||||||
self.max_float = None
|
|
||||||
|
|
||||||
if bools:
|
|
||||||
if bools is True:
|
|
||||||
bools = (1, MAX_BITS)
|
|
||||||
self.min_bool, self.max_bool = bools
|
|
||||||
assert is_power_of_two(self.min_bool)
|
|
||||||
assert is_power_of_two(self.max_bool)
|
|
||||||
assert self.max_bool <= MAX_BITS
|
|
||||||
assert self.min_bool <= self.max_bool
|
|
||||||
else:
|
|
||||||
self.min_bool = None
|
|
||||||
self.max_bool = None
|
|
||||||
|
|
||||||
def typeset_key(self):
|
def typeset_key(self):
|
||||||
# type: () -> Tuple[int, int, int, int, int, int, int, int]
|
# type: () -> Tuple[int, int, int, int, int, int, int, int]
|
||||||
@@ -253,6 +240,7 @@ class TypeVar(object):
|
|||||||
ints=False, floats=False, bools=False,
|
ints=False, floats=False, bools=False,
|
||||||
scalars=True, simd=False,
|
scalars=True, simd=False,
|
||||||
base=None, derived_func=None):
|
base=None, derived_func=None):
|
||||||
|
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa
|
||||||
self.name = name
|
self.name = name
|
||||||
self.__doc__ = doc
|
self.__doc__ = doc
|
||||||
self.is_derived = isinstance(base, TypeVar)
|
self.is_derived = isinstance(base, TypeVar)
|
||||||
@@ -264,25 +252,19 @@ class TypeVar(object):
|
|||||||
self.name = '{}({})'.format(derived_func, base.name)
|
self.name = '{}({})'.format(derived_func, base.name)
|
||||||
else:
|
else:
|
||||||
min_lanes = 1 if scalars else 2
|
min_lanes = 1 if scalars else 2
|
||||||
if simd:
|
lanes = decode_interval(simd, (min_lanes, MAX_LANES), 1)
|
||||||
if simd is True:
|
|
||||||
max_lanes = MAX_LANES
|
|
||||||
else:
|
|
||||||
min_lanes, max_lanes = simd
|
|
||||||
assert not scalars or min_lanes <= 2
|
|
||||||
else:
|
|
||||||
max_lanes = 1
|
|
||||||
|
|
||||||
self.type_set = TypeSet(
|
self.type_set = TypeSet(
|
||||||
lanes=(min_lanes, max_lanes),
|
lanes=lanes,
|
||||||
ints=ints,
|
ints=ints,
|
||||||
floats=floats,
|
floats=floats,
|
||||||
bools=bools)
|
bools=bools)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
return "`{}`".format(self.name)
|
return "`{}`".format(self.name)
|
||||||
|
|
||||||
def lane_of(self):
|
def lane_of(self):
|
||||||
|
# type: () -> TypeVar
|
||||||
"""
|
"""
|
||||||
Return a derived type variable that is the scalar lane type of this
|
Return a derived type variable that is the scalar lane type of this
|
||||||
type variable.
|
type variable.
|
||||||
@@ -293,6 +275,7 @@ class TypeVar(object):
|
|||||||
return TypeVar(None, None, base=self, derived_func='LaneOf')
|
return TypeVar(None, None, base=self, derived_func='LaneOf')
|
||||||
|
|
||||||
def as_bool(self):
|
def as_bool(self):
|
||||||
|
# type: () -> TypeVar
|
||||||
"""
|
"""
|
||||||
Return a derived type variable that has the same vector geometry as
|
Return a derived type variable that has the same vector geometry as
|
||||||
this type variable, but with boolean lanes. Scalar types map to `b1`.
|
this type variable, but with boolean lanes. Scalar types map to `b1`.
|
||||||
@@ -300,6 +283,7 @@ class TypeVar(object):
|
|||||||
return TypeVar(None, None, base=self, derived_func='AsBool')
|
return TypeVar(None, None, base=self, derived_func='AsBool')
|
||||||
|
|
||||||
def half_width(self):
|
def half_width(self):
|
||||||
|
# type: () -> TypeVar
|
||||||
"""
|
"""
|
||||||
Return a derived type variable that has the same number of vector lanes
|
Return a derived type variable that has the same number of vector lanes
|
||||||
as this one, but the lanes are half the width.
|
as this one, but the lanes are half the width.
|
||||||
@@ -315,6 +299,7 @@ class TypeVar(object):
|
|||||||
return TypeVar(None, None, base=self, derived_func='HalfWidth')
|
return TypeVar(None, None, base=self, derived_func='HalfWidth')
|
||||||
|
|
||||||
def double_width(self):
|
def double_width(self):
|
||||||
|
# type: () -> TypeVar
|
||||||
"""
|
"""
|
||||||
Return a derived type variable that has the same number of vector lanes
|
Return a derived type variable that has the same number of vector lanes
|
||||||
as this one, but the lanes are double the width.
|
as this one, but the lanes are double the width.
|
||||||
@@ -330,12 +315,14 @@ class TypeVar(object):
|
|||||||
return TypeVar(None, None, base=self, derived_func='DoubleWidth')
|
return TypeVar(None, None, base=self, derived_func='DoubleWidth')
|
||||||
|
|
||||||
def operand_kind(self):
|
def operand_kind(self):
|
||||||
|
# type: () -> OperandKind
|
||||||
# When a `TypeVar` object is used to describe the type of an `Operand`
|
# When a `TypeVar` object is used to describe the type of an `Operand`
|
||||||
# in an instruction definition, the kind of that operand is an SSA
|
# in an instruction definition, the kind of that operand is an SSA
|
||||||
# value.
|
# value.
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def free_typevar(self):
|
def free_typevar(self):
|
||||||
|
# type: () -> TypeVar
|
||||||
if self.is_derived:
|
if self.is_derived:
|
||||||
return self.base
|
return self.base
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user