Convert interval sets inside TypeSet/ValueTypeSet in general sets (#102)

* Convert TypeSet fields to sets; Add BitSet<T> type to rust; Encode ValueTypeSets using BitSet; (still need mypy cleanup)

* nits

* cleanup nits

* forgot mypy type annotations

* rustfmt fixes

* Round 1 comments: filer b2, b4; doc comments in python; move bitset in its own toplevel module; Use Into<u32>

* fixes

* Revert comment to appease rustfmt
This commit is contained in:
d1m0
2017-06-22 16:47:14 -07:00
committed by Jakob Stoklund Olesen
parent cf967642a3
commit 4ebc0e8587
5 changed files with 292 additions and 131 deletions

View File

@@ -40,7 +40,7 @@ class TestTypeSet(TestCase):
a = TypeSet(lanes=True, ints=True, floats=True)
s = set()
s.add(a)
a.max_int = 32
a.ints.remove(64)
# Can't rehash after modification.
with self.assertRaises(AssertionError):
a in s
@@ -71,14 +71,18 @@ class TestTypeVar(TestCase):
def test_singleton(self):
x = TypeVar.singleton(i32)
self.assertEqual(str(x), '`i32`')
self.assertEqual(x.type_set.min_int, 32)
self.assertEqual(x.type_set.max_int, 32)
self.assertEqual(x.type_set.min_lanes, 1)
self.assertEqual(x.type_set.max_lanes, 1)
self.assertEqual(min(x.type_set.ints), 32)
self.assertEqual(max(x.type_set.ints), 32)
self.assertEqual(min(x.type_set.lanes), 1)
self.assertEqual(max(x.type_set.lanes), 1)
self.assertEqual(len(x.type_set.floats), 0)
self.assertEqual(len(x.type_set.bools), 0)
x = TypeVar.singleton(i32.by(4))
self.assertEqual(str(x), '`i32x4`')
self.assertEqual(x.type_set.min_int, 32)
self.assertEqual(x.type_set.max_int, 32)
self.assertEqual(x.type_set.min_lanes, 4)
self.assertEqual(x.type_set.max_lanes, 4)
self.assertEqual(min(x.type_set.ints), 32)
self.assertEqual(max(x.type_set.ints), 32)
self.assertEqual(min(x.type_set.lanes), 4)
self.assertEqual(max(x.type_set.lanes), 4)
self.assertEqual(len(x.type_set.floats), 0)
self.assertEqual(len(x.type_set.bools), 0)

View File

@@ -9,7 +9,7 @@ import math
from . import types, is_power_of_two
try:
from typing import Tuple, Union, TYPE_CHECKING # noqa
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
if TYPE_CHECKING:
from srcgen import Formatter # noqa
Interval = Tuple[int, int]
@@ -46,6 +46,32 @@ def intersect(a, b):
return (None, None)
def is_empty(intv):
# type: (Interval) -> bool
return intv is None or intv is False or intv == (None, None)
def encode_bitset(vals, size):
# type: (Iterable[int], int) -> int
"""
Encode a set of values (each between 0 and size) as a bitset of width size.
"""
res = 0
assert is_power_of_two(size) and size <= 64
for v in vals:
assert 0 <= v and v < size
res |= 1 << v
return res
def pp_set(s):
# type: (Iterable[Any]) -> str
"""
Return a consistent string representation of a set (ordering is fixed)
"""
return '{' + ', '.join([repr(x) for x in sorted(s)]) + '}'
def decode_interval(intv, full_range, default=None):
# type: (BoolInterval, Interval, int) -> Interval
"""
@@ -74,6 +100,18 @@ def decode_interval(intv, full_range, default=None):
return (default, default)
def interval_to_set(intv):
# type: (Interval) -> Set
if is_empty(intv):
return set()
(lo, hi) = intv
assert is_power_of_two(lo)
assert is_power_of_two(hi)
assert lo <= hi
return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)])
class TypeSet(object):
"""
A set of types.
@@ -95,22 +133,22 @@ class TypeSet(object):
A typeset representing scalar integer types `i8` through `i32`:
>>> TypeSet(ints=(8, 32))
TypeSet(lanes=(1, 1), ints=(8, 32))
TypeSet(lanes={1}, ints={8, 16, 32})
Passing `True` instead of a range selects all available scalar types:
>>> TypeSet(ints=True)
TypeSet(lanes=(1, 1), ints=(8, 64))
TypeSet(lanes={1}, ints={8, 16, 32, 64})
>>> TypeSet(floats=True)
TypeSet(lanes=(1, 1), floats=(32, 64))
TypeSet(lanes={1}, floats={32, 64})
>>> TypeSet(bools=True)
TypeSet(lanes=(1, 1), bools=(1, 64))
TypeSet(lanes={1}, bools={1, 8, 16, 32, 64})
Similarly, passing `True` for the lanes selects all possible scalar and
vector types:
>>> TypeSet(lanes=True, ints=True)
TypeSet(lanes=(1, 256), ints=(8, 64))
TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={8, 16, 32, 64})
:param lanes: `(min, max)` inclusive range of permitted vector lane counts.
:param ints: `(min, max)` inclusive range of permitted scalar integer
@@ -123,19 +161,19 @@ class TypeSet(object):
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
self.min_lanes, self.max_lanes = decode_interval(
lanes, (1, MAX_LANES), 1)
self.min_int, self.max_int = decode_interval(ints, (8, MAX_BITS))
self.min_float, self.max_float = decode_interval(floats, (32, 64))
self.min_bool, self.max_bool = decode_interval(bools, (1, MAX_BITS))
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
self.bools = set(filter(lambda x: x == 1 or x >= 8, self.bools))
def typeset_key(self):
# type: () -> Tuple[int, int, int, int, int, int, int, int]
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
"""Key tuple used for hashing and equality."""
return (self.min_lanes, self.max_lanes,
self.min_int, self.max_int,
self.min_float, self.max_float,
self.min_bool, self.max_bool)
return (tuple(sorted(list(self.lanes))),
tuple(sorted(list(self.ints))),
tuple(sorted(list(self.floats))),
tuple(sorted(list(self.bools))))
def __hash__(self):
# type: () -> int
@@ -153,31 +191,29 @@ class TypeSet(object):
def __repr__(self):
# type: () -> str
s = 'TypeSet(lanes=({}, {})'.format(self.min_lanes, self.max_lanes)
if self.min_int is not None:
s += ', ints=({}, {})'.format(self.min_int, self.max_int)
if self.min_float is not None:
s += ', floats=({}, {})'.format(self.min_float, self.max_float)
if self.min_bool is not None:
s += ', bools=({}, {})'.format(self.min_bool, self.max_bool)
s = 'TypeSet(lanes={}'.format(pp_set(self.lanes))
if len(self.ints) > 0:
s += ', ints={}'.format(pp_set(self.ints))
if len(self.floats) > 0:
s += ', floats={}'.format(pp_set(self.floats))
if len(self.bools) > 0:
s += ', bools={}'.format(pp_set(self.bools))
return s + ')'
def emit_fields(self, fmt):
# type: (Formatter) -> None
"""Emit field initializers for this typeset."""
fmt.comment(repr(self))
fields = ('lanes', 'int', 'float', 'bool')
for field in fields:
min_val = getattr(self, 'min_' + field)
max_val = getattr(self, 'max_' + field)
if min_val is None:
fmt.line('min_{}: 0,'.format(field))
fmt.line('max_{}: 0,'.format(field))
else:
fmt.line('min_{}: {},'.format(
field, int_log2(min_val)))
fmt.line('max_{}: {},'.format(
field, int_log2(max_val) + 1))
fields = (('lanes', 16),
('ints', 8),
('floats', 8),
('bools', 8))
for (field, bits) in fields:
vals = [int_log2(x) for x in getattr(self, field)]
fmt.line('{}: BitSet::<u{}>({}),'
.format(field, bits, encode_bitset(vals, bits)))
def __iand__(self, other):
# type: (TypeSet) -> TypeSet
@@ -186,32 +222,22 @@ class TypeSet(object):
>>> a = TypeSet(lanes=True, ints=(16, 32))
>>> a
TypeSet(lanes=(1, 256), ints=(16, 32))
TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={16, 32})
>>> b = TypeSet(lanes=(4, 16), ints=True)
>>> a &= b
>>> a
TypeSet(lanes=(4, 16), ints=(16, 32))
TypeSet(lanes={4, 8, 16}, ints={16, 32})
>>> a = TypeSet(lanes=True, bools=(1, 8))
>>> b = TypeSet(lanes=True, bools=(16, 32))
>>> a &= b
>>> a
TypeSet(lanes=(1, 256))
TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256})
"""
self.min_lanes = max(self.min_lanes, other.min_lanes)
self.max_lanes = min(self.max_lanes, other.max_lanes)
self.min_int, self.max_int = intersect(
(self.min_int, self.max_int),
(other.min_int, other.max_int))
self.min_float, self.max_float = intersect(
(self.min_float, self.max_float),
(other.min_float, other.max_float))
self.min_bool, self.max_bool = intersect(
(self.min_bool, self.max_bool),
(other.min_bool, other.max_bool))
self.lanes.intersection_update(other.lanes)
self.ints.intersection_update(other.ints)
self.floats.intersection_update(other.floats)
self.bools.intersection_update(other.bools)
return self
@@ -382,12 +408,12 @@ class TypeVar(object):
"""
if not self.is_derived:
ts = self.type_set
if ts.min_int:
assert ts.min_int > 8, "Can't halve all integer types"
if ts.min_float:
assert ts.min_float > 32, "Can't halve all float types"
if ts.min_bool:
assert ts.min_bool > 8, "Can't halve all boolean types"
if len(ts.ints) > 0:
assert min(ts.ints) > 8, "Can't halve all integer types"
if len(ts.floats) > 0:
assert min(ts.floats) > 32, "Can't halve all float types"
if len(ts.bools) > 0:
assert min(ts.bools) > 8, "Can't halve all boolean types"
return TypeVar.derived(self, self.HALFWIDTH)
@@ -399,12 +425,14 @@ class TypeVar(object):
"""
if not self.is_derived:
ts = self.type_set
if ts.max_int:
assert ts.max_int < MAX_BITS, "Can't double all integer types."
if ts.max_float:
assert ts.max_float < MAX_BITS, "Can't double all float types."
if ts.max_bool:
assert ts.max_bool < MAX_BITS, "Can't double all bool types."
if len(ts.ints) > 0:
assert max(ts.ints) < MAX_BITS,\
"Can't double all integer types."
if len(ts.floats) > 0:
assert max(ts.floats) < MAX_BITS,\
"Can't double all float types."
if len(ts.bools) > 0:
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
return TypeVar.derived(self, self.DOUBLEWIDTH)
@@ -416,7 +444,7 @@ class TypeVar(object):
"""
if not self.is_derived:
ts = self.type_set
assert ts.min_lanes > 1, "Can't halve a scalar type"
assert min(ts.lanes) > 1, "Can't halve a scalar type"
return TypeVar.derived(self, self.HALFVECTOR)
@@ -428,7 +456,7 @@ class TypeVar(object):
"""
if not self.is_derived:
ts = self.type_set
assert ts.max_lanes < 256, "Can't double 256 lanes."
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
return TypeVar.derived(self, self.DOUBLEVECTOR)