Convert interval sets inside TypeSet/ValueTypeSet in general sets (#102)
* Convert TypeSet fields to sets; Add BitSet<T> type to rust; Encode ValueTypeSets using BitSet; (still need mypy cleanup) * nits * cleanup nits * forgot mypy type annotations * rustfmt fixes * Round 1 comments: filer b2, b4; doc comments in python; move bitset in its own toplevel module; Use Into<u32> * fixes * Revert comment to appease rustfmt
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
cf967642a3
commit
4ebc0e8587
@@ -40,7 +40,7 @@ class TestTypeSet(TestCase):
|
|||||||
a = TypeSet(lanes=True, ints=True, floats=True)
|
a = TypeSet(lanes=True, ints=True, floats=True)
|
||||||
s = set()
|
s = set()
|
||||||
s.add(a)
|
s.add(a)
|
||||||
a.max_int = 32
|
a.ints.remove(64)
|
||||||
# Can't rehash after modification.
|
# Can't rehash after modification.
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
a in s
|
a in s
|
||||||
@@ -71,14 +71,18 @@ class TestTypeVar(TestCase):
|
|||||||
def test_singleton(self):
|
def test_singleton(self):
|
||||||
x = TypeVar.singleton(i32)
|
x = TypeVar.singleton(i32)
|
||||||
self.assertEqual(str(x), '`i32`')
|
self.assertEqual(str(x), '`i32`')
|
||||||
self.assertEqual(x.type_set.min_int, 32)
|
self.assertEqual(min(x.type_set.ints), 32)
|
||||||
self.assertEqual(x.type_set.max_int, 32)
|
self.assertEqual(max(x.type_set.ints), 32)
|
||||||
self.assertEqual(x.type_set.min_lanes, 1)
|
self.assertEqual(min(x.type_set.lanes), 1)
|
||||||
self.assertEqual(x.type_set.max_lanes, 1)
|
self.assertEqual(max(x.type_set.lanes), 1)
|
||||||
|
self.assertEqual(len(x.type_set.floats), 0)
|
||||||
|
self.assertEqual(len(x.type_set.bools), 0)
|
||||||
|
|
||||||
x = TypeVar.singleton(i32.by(4))
|
x = TypeVar.singleton(i32.by(4))
|
||||||
self.assertEqual(str(x), '`i32x4`')
|
self.assertEqual(str(x), '`i32x4`')
|
||||||
self.assertEqual(x.type_set.min_int, 32)
|
self.assertEqual(min(x.type_set.ints), 32)
|
||||||
self.assertEqual(x.type_set.max_int, 32)
|
self.assertEqual(max(x.type_set.ints), 32)
|
||||||
self.assertEqual(x.type_set.min_lanes, 4)
|
self.assertEqual(min(x.type_set.lanes), 4)
|
||||||
self.assertEqual(x.type_set.max_lanes, 4)
|
self.assertEqual(max(x.type_set.lanes), 4)
|
||||||
|
self.assertEqual(len(x.type_set.floats), 0)
|
||||||
|
self.assertEqual(len(x.type_set.bools), 0)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import math
|
|||||||
from . import types, is_power_of_two
|
from . import types, is_power_of_two
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Tuple, Union, TYPE_CHECKING # noqa
|
from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from srcgen import Formatter # noqa
|
from srcgen import Formatter # noqa
|
||||||
Interval = Tuple[int, int]
|
Interval = Tuple[int, int]
|
||||||
@@ -46,6 +46,32 @@ def intersect(a, b):
|
|||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def is_empty(intv):
|
||||||
|
# type: (Interval) -> bool
|
||||||
|
return intv is None or intv is False or intv == (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_bitset(vals, size):
|
||||||
|
# type: (Iterable[int], int) -> int
|
||||||
|
"""
|
||||||
|
Encode a set of values (each between 0 and size) as a bitset of width size.
|
||||||
|
"""
|
||||||
|
res = 0
|
||||||
|
assert is_power_of_two(size) and size <= 64
|
||||||
|
for v in vals:
|
||||||
|
assert 0 <= v and v < size
|
||||||
|
res |= 1 << v
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def pp_set(s):
|
||||||
|
# type: (Iterable[Any]) -> str
|
||||||
|
"""
|
||||||
|
Return a consistent string representation of a set (ordering is fixed)
|
||||||
|
"""
|
||||||
|
return '{' + ', '.join([repr(x) for x in sorted(s)]) + '}'
|
||||||
|
|
||||||
|
|
||||||
def decode_interval(intv, full_range, default=None):
|
def decode_interval(intv, full_range, default=None):
|
||||||
# type: (BoolInterval, Interval, int) -> Interval
|
# type: (BoolInterval, Interval, int) -> Interval
|
||||||
"""
|
"""
|
||||||
@@ -74,6 +100,18 @@ def decode_interval(intv, full_range, default=None):
|
|||||||
return (default, default)
|
return (default, default)
|
||||||
|
|
||||||
|
|
||||||
|
def interval_to_set(intv):
|
||||||
|
# type: (Interval) -> Set
|
||||||
|
if is_empty(intv):
|
||||||
|
return set()
|
||||||
|
|
||||||
|
(lo, hi) = intv
|
||||||
|
assert is_power_of_two(lo)
|
||||||
|
assert is_power_of_two(hi)
|
||||||
|
assert lo <= hi
|
||||||
|
return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)])
|
||||||
|
|
||||||
|
|
||||||
class TypeSet(object):
|
class TypeSet(object):
|
||||||
"""
|
"""
|
||||||
A set of types.
|
A set of types.
|
||||||
@@ -95,22 +133,22 @@ class TypeSet(object):
|
|||||||
A typeset representing scalar integer types `i8` through `i32`:
|
A typeset representing scalar integer types `i8` through `i32`:
|
||||||
|
|
||||||
>>> TypeSet(ints=(8, 32))
|
>>> TypeSet(ints=(8, 32))
|
||||||
TypeSet(lanes=(1, 1), ints=(8, 32))
|
TypeSet(lanes={1}, ints={8, 16, 32})
|
||||||
|
|
||||||
Passing `True` instead of a range selects all available scalar types:
|
Passing `True` instead of a range selects all available scalar types:
|
||||||
|
|
||||||
>>> TypeSet(ints=True)
|
>>> TypeSet(ints=True)
|
||||||
TypeSet(lanes=(1, 1), ints=(8, 64))
|
TypeSet(lanes={1}, ints={8, 16, 32, 64})
|
||||||
>>> TypeSet(floats=True)
|
>>> TypeSet(floats=True)
|
||||||
TypeSet(lanes=(1, 1), floats=(32, 64))
|
TypeSet(lanes={1}, floats={32, 64})
|
||||||
>>> TypeSet(bools=True)
|
>>> TypeSet(bools=True)
|
||||||
TypeSet(lanes=(1, 1), bools=(1, 64))
|
TypeSet(lanes={1}, bools={1, 8, 16, 32, 64})
|
||||||
|
|
||||||
Similarly, passing `True` for the lanes selects all possible scalar and
|
Similarly, passing `True` for the lanes selects all possible scalar and
|
||||||
vector types:
|
vector types:
|
||||||
|
|
||||||
>>> TypeSet(lanes=True, ints=True)
|
>>> TypeSet(lanes=True, ints=True)
|
||||||
TypeSet(lanes=(1, 256), ints=(8, 64))
|
TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={8, 16, 32, 64})
|
||||||
|
|
||||||
:param lanes: `(min, max)` inclusive range of permitted vector lane counts.
|
:param lanes: `(min, max)` inclusive range of permitted vector lane counts.
|
||||||
:param ints: `(min, max)` inclusive range of permitted scalar integer
|
:param ints: `(min, max)` inclusive range of permitted scalar integer
|
||||||
@@ -123,19 +161,19 @@ class TypeSet(object):
|
|||||||
|
|
||||||
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
|
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
|
||||||
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
||||||
self.min_lanes, self.max_lanes = decode_interval(
|
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
||||||
lanes, (1, MAX_LANES), 1)
|
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
|
||||||
self.min_int, self.max_int = decode_interval(ints, (8, MAX_BITS))
|
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
|
||||||
self.min_float, self.max_float = decode_interval(floats, (32, 64))
|
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
|
||||||
self.min_bool, self.max_bool = decode_interval(bools, (1, MAX_BITS))
|
self.bools = set(filter(lambda x: x == 1 or x >= 8, self.bools))
|
||||||
|
|
||||||
def typeset_key(self):
|
def typeset_key(self):
|
||||||
# type: () -> Tuple[int, int, int, int, int, int, int, int]
|
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
|
||||||
"""Key tuple used for hashing and equality."""
|
"""Key tuple used for hashing and equality."""
|
||||||
return (self.min_lanes, self.max_lanes,
|
return (tuple(sorted(list(self.lanes))),
|
||||||
self.min_int, self.max_int,
|
tuple(sorted(list(self.ints))),
|
||||||
self.min_float, self.max_float,
|
tuple(sorted(list(self.floats))),
|
||||||
self.min_bool, self.max_bool)
|
tuple(sorted(list(self.bools))))
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
# type: () -> int
|
# type: () -> int
|
||||||
@@ -153,31 +191,29 @@ class TypeSet(object):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# type: () -> str
|
# type: () -> str
|
||||||
s = 'TypeSet(lanes=({}, {})'.format(self.min_lanes, self.max_lanes)
|
s = 'TypeSet(lanes={}'.format(pp_set(self.lanes))
|
||||||
if self.min_int is not None:
|
if len(self.ints) > 0:
|
||||||
s += ', ints=({}, {})'.format(self.min_int, self.max_int)
|
s += ', ints={}'.format(pp_set(self.ints))
|
||||||
if self.min_float is not None:
|
if len(self.floats) > 0:
|
||||||
s += ', floats=({}, {})'.format(self.min_float, self.max_float)
|
s += ', floats={}'.format(pp_set(self.floats))
|
||||||
if self.min_bool is not None:
|
if len(self.bools) > 0:
|
||||||
s += ', bools=({}, {})'.format(self.min_bool, self.max_bool)
|
s += ', bools={}'.format(pp_set(self.bools))
|
||||||
return s + ')'
|
return s + ')'
|
||||||
|
|
||||||
def emit_fields(self, fmt):
|
def emit_fields(self, fmt):
|
||||||
# type: (Formatter) -> None
|
# type: (Formatter) -> None
|
||||||
"""Emit field initializers for this typeset."""
|
"""Emit field initializers for this typeset."""
|
||||||
fmt.comment(repr(self))
|
fmt.comment(repr(self))
|
||||||
fields = ('lanes', 'int', 'float', 'bool')
|
|
||||||
for field in fields:
|
fields = (('lanes', 16),
|
||||||
min_val = getattr(self, 'min_' + field)
|
('ints', 8),
|
||||||
max_val = getattr(self, 'max_' + field)
|
('floats', 8),
|
||||||
if min_val is None:
|
('bools', 8))
|
||||||
fmt.line('min_{}: 0,'.format(field))
|
|
||||||
fmt.line('max_{}: 0,'.format(field))
|
for (field, bits) in fields:
|
||||||
else:
|
vals = [int_log2(x) for x in getattr(self, field)]
|
||||||
fmt.line('min_{}: {},'.format(
|
fmt.line('{}: BitSet::<u{}>({}),'
|
||||||
field, int_log2(min_val)))
|
.format(field, bits, encode_bitset(vals, bits)))
|
||||||
fmt.line('max_{}: {},'.format(
|
|
||||||
field, int_log2(max_val) + 1))
|
|
||||||
|
|
||||||
def __iand__(self, other):
|
def __iand__(self, other):
|
||||||
# type: (TypeSet) -> TypeSet
|
# type: (TypeSet) -> TypeSet
|
||||||
@@ -186,32 +222,22 @@ class TypeSet(object):
|
|||||||
|
|
||||||
>>> a = TypeSet(lanes=True, ints=(16, 32))
|
>>> a = TypeSet(lanes=True, ints=(16, 32))
|
||||||
>>> a
|
>>> a
|
||||||
TypeSet(lanes=(1, 256), ints=(16, 32))
|
TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={16, 32})
|
||||||
>>> b = TypeSet(lanes=(4, 16), ints=True)
|
>>> b = TypeSet(lanes=(4, 16), ints=True)
|
||||||
>>> a &= b
|
>>> a &= b
|
||||||
>>> a
|
>>> a
|
||||||
TypeSet(lanes=(4, 16), ints=(16, 32))
|
TypeSet(lanes={4, 8, 16}, ints={16, 32})
|
||||||
|
|
||||||
>>> a = TypeSet(lanes=True, bools=(1, 8))
|
>>> a = TypeSet(lanes=True, bools=(1, 8))
|
||||||
>>> b = TypeSet(lanes=True, bools=(16, 32))
|
>>> b = TypeSet(lanes=True, bools=(16, 32))
|
||||||
>>> a &= b
|
>>> a &= b
|
||||||
>>> a
|
>>> a
|
||||||
TypeSet(lanes=(1, 256))
|
TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256})
|
||||||
"""
|
"""
|
||||||
self.min_lanes = max(self.min_lanes, other.min_lanes)
|
self.lanes.intersection_update(other.lanes)
|
||||||
self.max_lanes = min(self.max_lanes, other.max_lanes)
|
self.ints.intersection_update(other.ints)
|
||||||
|
self.floats.intersection_update(other.floats)
|
||||||
self.min_int, self.max_int = intersect(
|
self.bools.intersection_update(other.bools)
|
||||||
(self.min_int, self.max_int),
|
|
||||||
(other.min_int, other.max_int))
|
|
||||||
|
|
||||||
self.min_float, self.max_float = intersect(
|
|
||||||
(self.min_float, self.max_float),
|
|
||||||
(other.min_float, other.max_float))
|
|
||||||
|
|
||||||
self.min_bool, self.max_bool = intersect(
|
|
||||||
(self.min_bool, self.max_bool),
|
|
||||||
(other.min_bool, other.max_bool))
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -382,12 +408,12 @@ class TypeVar(object):
|
|||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
if not self.is_derived:
|
||||||
ts = self.type_set
|
ts = self.type_set
|
||||||
if ts.min_int:
|
if len(ts.ints) > 0:
|
||||||
assert ts.min_int > 8, "Can't halve all integer types"
|
assert min(ts.ints) > 8, "Can't halve all integer types"
|
||||||
if ts.min_float:
|
if len(ts.floats) > 0:
|
||||||
assert ts.min_float > 32, "Can't halve all float types"
|
assert min(ts.floats) > 32, "Can't halve all float types"
|
||||||
if ts.min_bool:
|
if len(ts.bools) > 0:
|
||||||
assert ts.min_bool > 8, "Can't halve all boolean types"
|
assert min(ts.bools) > 8, "Can't halve all boolean types"
|
||||||
|
|
||||||
return TypeVar.derived(self, self.HALFWIDTH)
|
return TypeVar.derived(self, self.HALFWIDTH)
|
||||||
|
|
||||||
@@ -399,12 +425,14 @@ class TypeVar(object):
|
|||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
if not self.is_derived:
|
||||||
ts = self.type_set
|
ts = self.type_set
|
||||||
if ts.max_int:
|
if len(ts.ints) > 0:
|
||||||
assert ts.max_int < MAX_BITS, "Can't double all integer types."
|
assert max(ts.ints) < MAX_BITS,\
|
||||||
if ts.max_float:
|
"Can't double all integer types."
|
||||||
assert ts.max_float < MAX_BITS, "Can't double all float types."
|
if len(ts.floats) > 0:
|
||||||
if ts.max_bool:
|
assert max(ts.floats) < MAX_BITS,\
|
||||||
assert ts.max_bool < MAX_BITS, "Can't double all bool types."
|
"Can't double all float types."
|
||||||
|
if len(ts.bools) > 0:
|
||||||
|
assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
|
||||||
|
|
||||||
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
||||||
|
|
||||||
@@ -416,7 +444,7 @@ class TypeVar(object):
|
|||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
if not self.is_derived:
|
||||||
ts = self.type_set
|
ts = self.type_set
|
||||||
assert ts.min_lanes > 1, "Can't halve a scalar type"
|
assert min(ts.lanes) > 1, "Can't halve a scalar type"
|
||||||
|
|
||||||
return TypeVar.derived(self, self.HALFVECTOR)
|
return TypeVar.derived(self, self.HALFVECTOR)
|
||||||
|
|
||||||
@@ -428,7 +456,7 @@ class TypeVar(object):
|
|||||||
"""
|
"""
|
||||||
if not self.is_derived:
|
if not self.is_derived:
|
||||||
ts = self.type_set
|
ts = self.type_set
|
||||||
assert ts.max_lanes < 256, "Can't double 256 lanes."
|
assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."
|
||||||
|
|
||||||
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
||||||
|
|
||||||
|
|||||||
147
lib/cretonne/src/bitset.rs
Normal file
147
lib/cretonne/src/bitset.rs
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
//! Small Bitset
|
||||||
|
//!
|
||||||
|
//! This module defines a struct BitSet<T> encapsulating a bitset built over the type T.
|
||||||
|
//! T is intended to be a primitive unsigned type. Currently it can be any type between u8 and u32
|
||||||
|
//!
|
||||||
|
//! If you would like to add support for larger bitsets in the future, you need to change the trait
|
||||||
|
//! bound Into<u32> and the u32 in the implementation of max_bits()
|
||||||
|
use std::mem::size_of;
|
||||||
|
use std::ops::{Shl, BitOr, Sub, Add};
|
||||||
|
use std::convert::{Into, From};
|
||||||
|
|
||||||
|
/// A small bitset built on a single primitive integer type
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
pub struct BitSet<T>(pub T);
|
||||||
|
|
||||||
|
impl<T> BitSet<T>
|
||||||
|
where T: Into<u32> + From<u8> + BitOr<T, Output = T> + Shl<u8, Output = T> + Sub<T, Output=T> +
|
||||||
|
Add<T, Output=T> + PartialEq + Copy
|
||||||
|
{
|
||||||
|
/// Maximum number of bits supported by this BitSet instance
|
||||||
|
pub fn bits() -> usize {
|
||||||
|
size_of::<T>() * 8
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Maximum number of bits supported by any bitset instance atm.
|
||||||
|
pub fn max_bits() -> usize {
|
||||||
|
size_of::<u32>() * 8
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this BitSet contains the number num
|
||||||
|
pub fn contains(&self, num: u8) -> bool {
|
||||||
|
assert!((num as usize) < Self::bits());
|
||||||
|
assert!((num as usize) < Self::max_bits());
|
||||||
|
return self.0.into() & (1 << num) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the smallest number contained in the bitset or None if empty
|
||||||
|
pub fn min(&self) -> Option<u8> {
|
||||||
|
if self.0.into() == 0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.0.into().trailing_zeros() as u8)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the largest number contained in the bitset or None if empty
|
||||||
|
pub fn max(&self) -> Option<u8> {
|
||||||
|
if self.0.into() == 0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let leading_zeroes = self.0.into().leading_zeros() as usize;
|
||||||
|
Some((Self::max_bits() - leading_zeroes - 1) as u8)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a BitSet with the half-open range [lo,hi) filled in
|
||||||
|
pub fn from_range(lo: u8, hi: u8) -> BitSet<T> {
|
||||||
|
assert!(lo <= hi);
|
||||||
|
assert!((hi as usize) <= Self::bits());
|
||||||
|
let one : T = T::from(1);
|
||||||
|
// I can't just do (one << hi) - one here as the shift may overflow
|
||||||
|
let hi_rng = if hi >= 1 {
|
||||||
|
(one << (hi-1)) + ((one << (hi-1)) - one)
|
||||||
|
} else {
|
||||||
|
T::from(0)
|
||||||
|
};
|
||||||
|
|
||||||
|
let lo_rng = (one << lo) - one;
|
||||||
|
|
||||||
|
BitSet(hi_rng - lo_rng)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn contains() {
|
||||||
|
let s = BitSet::<u8>(255);
|
||||||
|
for i in 0..7 {
|
||||||
|
assert!(s.contains(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
let s1 = BitSet::<u8>(0);
|
||||||
|
for i in 0..7 {
|
||||||
|
assert!(!s1.contains(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
let s2 = BitSet::<u8>(127);
|
||||||
|
for i in 0..6 {
|
||||||
|
assert!(s2.contains(i));
|
||||||
|
}
|
||||||
|
assert!(!s2.contains(7));
|
||||||
|
|
||||||
|
let s3 = BitSet::<u8>(2 | 4 | 64);
|
||||||
|
assert!(!s3.contains(0) && !s3.contains(3) && !s3.contains(4) && !s3.contains(5) &&
|
||||||
|
!s3.contains(7));
|
||||||
|
assert!(s3.contains(1) && s3.contains(2) && s3.contains(6));
|
||||||
|
|
||||||
|
let s4 = BitSet::<u16>(4 | 8 | 256 | 1024);
|
||||||
|
assert!(!s4.contains(0) && !s4.contains(1) && !s4.contains(4) && !s4.contains(5) &&
|
||||||
|
!s4.contains(6) && !s4.contains(7) &&
|
||||||
|
!s4.contains(9) && !s4.contains(11));
|
||||||
|
assert!(s4.contains(2) && s4.contains(3) && s4.contains(8) && s4.contains(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn minmax() {
|
||||||
|
let s = BitSet::<u8>(255);
|
||||||
|
assert_eq!(s.min(), Some(0));
|
||||||
|
assert_eq!(s.max(), Some(7));
|
||||||
|
assert!(s.min() == Some(0) && s.max() == Some(7));
|
||||||
|
let s1 = BitSet::<u8>(0);
|
||||||
|
assert!(s1.min() == None && s1.max() == None);
|
||||||
|
let s2 = BitSet::<u8>(127);
|
||||||
|
assert!(s2.min() == Some(0) && s2.max() == Some(6));
|
||||||
|
let s3 = BitSet::<u8>(2 | 4 | 64);
|
||||||
|
assert!(s3.min() == Some(1) && s3.max() == Some(6));
|
||||||
|
let s4 = BitSet::<u16>(4 | 8 | 256 | 1024);
|
||||||
|
assert!(s4.min() == Some(2) && s4.max() == Some(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_range() {
|
||||||
|
let s = BitSet::<u8>::from_range(5, 5);
|
||||||
|
assert!(s.0 == 0);
|
||||||
|
|
||||||
|
let s = BitSet::<u8>::from_range(0, 8);
|
||||||
|
assert!(s.0 == 255);
|
||||||
|
|
||||||
|
let s = BitSet::<u16>::from_range(0, 8);
|
||||||
|
assert!(s.0 == 255u16);
|
||||||
|
|
||||||
|
let s = BitSet::<u16>::from_range(0, 16);
|
||||||
|
assert!(s.0 == 65535u16);
|
||||||
|
|
||||||
|
let s = BitSet::<u8>::from_range(5, 6);
|
||||||
|
assert!(s.0 == 32u8);
|
||||||
|
|
||||||
|
let s = BitSet::<u8>::from_range(3, 7);
|
||||||
|
assert!(s.0 == 8 | 16 | 32 | 64);
|
||||||
|
|
||||||
|
let s = BitSet::<u16>::from_range(5, 11);
|
||||||
|
assert!(s.0 == 32 | 64 | 128 | 256 | 512 | 1024);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ use ir::types;
|
|||||||
use isa::RegUnit;
|
use isa::RegUnit;
|
||||||
|
|
||||||
use entity_list;
|
use entity_list;
|
||||||
|
use bitset::BitSet;
|
||||||
use ref_slice::{ref_slice, ref_slice_mut};
|
use ref_slice::{ref_slice, ref_slice_mut};
|
||||||
|
|
||||||
/// Some instructions use an external list of argument values because there is not enough space in
|
/// Some instructions use an external list of argument values because there is not enough space in
|
||||||
@@ -499,17 +500,16 @@ impl OpcodeConstraints {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BitSet8 = BitSet<u8>;
|
||||||
|
type BitSet16 = BitSet<u16>;
|
||||||
|
|
||||||
/// A value type set describes the permitted set of types for a type variable.
|
/// A value type set describes the permitted set of types for a type variable.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
pub struct ValueTypeSet {
|
pub struct ValueTypeSet {
|
||||||
min_lanes: u8,
|
lanes: BitSet16,
|
||||||
max_lanes: u8,
|
ints: BitSet8,
|
||||||
min_int: u8,
|
floats: BitSet8,
|
||||||
max_int: u8,
|
bools: BitSet8,
|
||||||
min_float: u8,
|
|
||||||
max_float: u8,
|
|
||||||
min_bool: u8,
|
|
||||||
max_bool: u8,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ValueTypeSet {
|
impl ValueTypeSet {
|
||||||
@@ -519,11 +519,11 @@ impl ValueTypeSet {
|
|||||||
fn is_base_type(&self, scalar: Type) -> bool {
|
fn is_base_type(&self, scalar: Type) -> bool {
|
||||||
let l2b = scalar.log2_lane_bits();
|
let l2b = scalar.log2_lane_bits();
|
||||||
if scalar.is_int() {
|
if scalar.is_int() {
|
||||||
self.min_int <= l2b && l2b < self.max_int
|
self.ints.contains(l2b)
|
||||||
} else if scalar.is_float() {
|
} else if scalar.is_float() {
|
||||||
self.min_float <= l2b && l2b < self.max_float
|
self.floats.contains(l2b)
|
||||||
} else if scalar.is_bool() {
|
} else if scalar.is_bool() {
|
||||||
self.min_bool <= l2b && l2b < self.max_bool
|
self.bools.contains(l2b)
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@@ -532,23 +532,23 @@ impl ValueTypeSet {
|
|||||||
/// Does `typ` belong to this set?
|
/// Does `typ` belong to this set?
|
||||||
pub fn contains(&self, typ: Type) -> bool {
|
pub fn contains(&self, typ: Type) -> bool {
|
||||||
let l2l = typ.log2_lane_count();
|
let l2l = typ.log2_lane_count();
|
||||||
self.min_lanes <= l2l && l2l < self.max_lanes && self.is_base_type(typ.lane_type())
|
self.lanes.contains(l2l) && self.is_base_type(typ.lane_type())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get an example member of this type set.
|
/// Get an example member of this type set.
|
||||||
///
|
///
|
||||||
/// This is used for error messages to avoid suggesting invalid types.
|
/// This is used for error messages to avoid suggesting invalid types.
|
||||||
pub fn example(&self) -> Type {
|
pub fn example(&self) -> Type {
|
||||||
let t = if self.max_int > 5 {
|
let t = if self.ints.max().unwrap_or(0) > 5 {
|
||||||
types::I32
|
types::I32
|
||||||
} else if self.max_float > 5 {
|
} else if self.floats.max().unwrap_or(0) > 5 {
|
||||||
types::F32
|
types::F32
|
||||||
} else if self.max_bool > 5 {
|
} else if self.bools.max().unwrap_or(0) > 5 {
|
||||||
types::B32
|
types::B32
|
||||||
} else {
|
} else {
|
||||||
types::B1
|
types::B1
|
||||||
};
|
};
|
||||||
t.by(1 << self.min_lanes).unwrap()
|
t.by(1 << self.lanes.min().unwrap()).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -709,15 +709,12 @@ mod tests {
|
|||||||
use ir::types::*;
|
use ir::types::*;
|
||||||
|
|
||||||
let vts = ValueTypeSet {
|
let vts = ValueTypeSet {
|
||||||
min_lanes: 0,
|
lanes: BitSet16::from_range(0, 8),
|
||||||
max_lanes: 8,
|
ints: BitSet8::from_range(4, 7),
|
||||||
min_int: 3,
|
floats: BitSet8::from_range(0, 0),
|
||||||
max_int: 7,
|
bools: BitSet8::from_range(3, 7),
|
||||||
min_float: 0,
|
|
||||||
max_float: 0,
|
|
||||||
min_bool: 3,
|
|
||||||
max_bool: 7,
|
|
||||||
};
|
};
|
||||||
|
assert!(!vts.contains(I8));
|
||||||
assert!(vts.contains(I32));
|
assert!(vts.contains(I32));
|
||||||
assert!(vts.contains(I64));
|
assert!(vts.contains(I64));
|
||||||
assert!(vts.contains(I32X4));
|
assert!(vts.contains(I32X4));
|
||||||
@@ -728,38 +725,26 @@ mod tests {
|
|||||||
assert_eq!(vts.example().to_string(), "i32");
|
assert_eq!(vts.example().to_string(), "i32");
|
||||||
|
|
||||||
let vts = ValueTypeSet {
|
let vts = ValueTypeSet {
|
||||||
min_lanes: 0,
|
lanes: BitSet16::from_range(0, 8),
|
||||||
max_lanes: 8,
|
ints: BitSet8::from_range(0, 0),
|
||||||
min_int: 0,
|
floats: BitSet8::from_range(5, 7),
|
||||||
max_int: 0,
|
bools: BitSet8::from_range(3, 7),
|
||||||
min_float: 5,
|
|
||||||
max_float: 7,
|
|
||||||
min_bool: 3,
|
|
||||||
max_bool: 7,
|
|
||||||
};
|
};
|
||||||
assert_eq!(vts.example().to_string(), "f32");
|
assert_eq!(vts.example().to_string(), "f32");
|
||||||
|
|
||||||
let vts = ValueTypeSet {
|
let vts = ValueTypeSet {
|
||||||
min_lanes: 1,
|
lanes: BitSet16::from_range(1, 8),
|
||||||
max_lanes: 8,
|
ints: BitSet8::from_range(0, 0),
|
||||||
min_int: 0,
|
floats: BitSet8::from_range(5, 7),
|
||||||
max_int: 0,
|
bools: BitSet8::from_range(3, 7),
|
||||||
min_float: 5,
|
|
||||||
max_float: 7,
|
|
||||||
min_bool: 3,
|
|
||||||
max_bool: 7,
|
|
||||||
};
|
};
|
||||||
assert_eq!(vts.example().to_string(), "f32x2");
|
assert_eq!(vts.example().to_string(), "f32x2");
|
||||||
|
|
||||||
let vts = ValueTypeSet {
|
let vts = ValueTypeSet {
|
||||||
min_lanes: 2,
|
lanes: BitSet16::from_range(2, 8),
|
||||||
max_lanes: 8,
|
ints: BitSet8::from_range(0, 0),
|
||||||
min_int: 0,
|
floats: BitSet8::from_range(0, 0),
|
||||||
max_int: 0,
|
bools: BitSet8::from_range(3, 7),
|
||||||
min_float: 0,
|
|
||||||
max_float: 0,
|
|
||||||
min_bool: 3,
|
|
||||||
max_bool: 7,
|
|
||||||
};
|
};
|
||||||
assert!(!vts.contains(B32X2));
|
assert!(!vts.contains(B32X2));
|
||||||
assert!(vts.contains(B32X4));
|
assert!(vts.contains(B32X4));
|
||||||
@@ -767,14 +752,10 @@ mod tests {
|
|||||||
|
|
||||||
let vts = ValueTypeSet {
|
let vts = ValueTypeSet {
|
||||||
// TypeSet(lanes=(1, 256), ints=(8, 64))
|
// TypeSet(lanes=(1, 256), ints=(8, 64))
|
||||||
min_lanes: 0,
|
lanes: BitSet16::from_range(0, 9),
|
||||||
max_lanes: 9,
|
ints: BitSet8::from_range(3, 7),
|
||||||
min_int: 3,
|
floats: BitSet8::from_range(0, 0),
|
||||||
max_int: 7,
|
bools: BitSet8::from_range(0, 0),
|
||||||
min_float: 0,
|
|
||||||
max_float: 0,
|
|
||||||
min_bool: 0,
|
|
||||||
max_bool: 0,
|
|
||||||
};
|
};
|
||||||
assert!(vts.contains(I32));
|
assert!(vts.contains(I32));
|
||||||
assert!(vts.contains(I32X4));
|
assert!(vts.contains(I32X4));
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ pub mod dbg;
|
|||||||
pub mod entity_ref;
|
pub mod entity_ref;
|
||||||
|
|
||||||
pub mod binemit;
|
pub mod binemit;
|
||||||
|
pub mod bitset;
|
||||||
pub mod dominator_tree;
|
pub mod dominator_tree;
|
||||||
pub mod entity_list;
|
pub mod entity_list;
|
||||||
pub mod entity_map;
|
pub mod entity_map;
|
||||||
|
|||||||
Reference in New Issue
Block a user