From 4ebc0e85873bebf5762496fb436be391b04a39da Mon Sep 17 00:00:00 2001 From: d1m0 Date: Thu, 22 Jun 2017 16:47:14 -0700 Subject: [PATCH] Convert interval sets inside TypeSet/ValueTypeSet in general sets (#102) * Convert TypeSet fields to sets; Add BitSet type to rust; Encode ValueTypeSets using BitSet; (still need mypy cleanup) * nits * cleanup nits * forgot mypy type annotations * rustfmt fixes * Round 1 comments: filter b2, b4; doc comments in python; move bitset in its own toplevel module; Use Into<u32> * fixes * Revert comment to appease rustfmt --- lib/cretonne/meta/cdsl/test_typevar.py | 22 ++-- lib/cretonne/meta/cdsl/typevar.py | 160 +++++++++++++++---------- lib/cretonne/src/bitset.rs | 147 +++++++++++++++++++++++ lib/cretonne/src/ir/instructions.rs | 93 ++++++-------- lib/cretonne/src/lib.rs | 1 + 5 files changed, 292 insertions(+), 131 deletions(-) create mode 100644 lib/cretonne/src/bitset.rs diff --git a/lib/cretonne/meta/cdsl/test_typevar.py b/lib/cretonne/meta/cdsl/test_typevar.py index 29db26f583..97793f71a5 100644 --- a/lib/cretonne/meta/cdsl/test_typevar.py +++ b/lib/cretonne/meta/cdsl/test_typevar.py @@ -40,7 +40,7 @@ class TestTypeSet(TestCase): a = TypeSet(lanes=True, ints=True, floats=True) s = set() s.add(a) - a.max_int = 32 + a.ints.remove(64) # Can't rehash after modification. 
with self.assertRaises(AssertionError): a in s @@ -71,14 +71,18 @@ class TestTypeVar(TestCase): def test_singleton(self): x = TypeVar.singleton(i32) self.assertEqual(str(x), '`i32`') - self.assertEqual(x.type_set.min_int, 32) - self.assertEqual(x.type_set.max_int, 32) - self.assertEqual(x.type_set.min_lanes, 1) - self.assertEqual(x.type_set.max_lanes, 1) + self.assertEqual(min(x.type_set.ints), 32) + self.assertEqual(max(x.type_set.ints), 32) + self.assertEqual(min(x.type_set.lanes), 1) + self.assertEqual(max(x.type_set.lanes), 1) + self.assertEqual(len(x.type_set.floats), 0) + self.assertEqual(len(x.type_set.bools), 0) x = TypeVar.singleton(i32.by(4)) self.assertEqual(str(x), '`i32x4`') - self.assertEqual(x.type_set.min_int, 32) - self.assertEqual(x.type_set.max_int, 32) - self.assertEqual(x.type_set.min_lanes, 4) - self.assertEqual(x.type_set.max_lanes, 4) + self.assertEqual(min(x.type_set.ints), 32) + self.assertEqual(max(x.type_set.ints), 32) + self.assertEqual(min(x.type_set.lanes), 4) + self.assertEqual(max(x.type_set.lanes), 4) + self.assertEqual(len(x.type_set.floats), 0) + self.assertEqual(len(x.type_set.bools), 0) diff --git a/lib/cretonne/meta/cdsl/typevar.py b/lib/cretonne/meta/cdsl/typevar.py index 1dc8630f5f..119c6bdf01 100644 --- a/lib/cretonne/meta/cdsl/typevar.py +++ b/lib/cretonne/meta/cdsl/typevar.py @@ -9,7 +9,7 @@ import math from . import types, is_power_of_two try: - from typing import Tuple, Union, TYPE_CHECKING # noqa + from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa if TYPE_CHECKING: from srcgen import Formatter # noqa Interval = Tuple[int, int] @@ -46,6 +46,32 @@ def intersect(a, b): return (None, None) +def is_empty(intv): + # type: (Interval) -> bool + return intv is None or intv is False or intv == (None, None) + + +def encode_bitset(vals, size): + # type: (Iterable[int], int) -> int + """ + Encode a set of values (each between 0 and size) as a bitset of width size. 
+ """ + res = 0 + assert is_power_of_two(size) and size <= 64 + for v in vals: + assert 0 <= v and v < size + res |= 1 << v + return res + + +def pp_set(s): + # type: (Iterable[Any]) -> str + """ + Return a consistent string representation of a set (ordering is fixed) + """ + return '{' + ', '.join([repr(x) for x in sorted(s)]) + '}' + + def decode_interval(intv, full_range, default=None): # type: (BoolInterval, Interval, int) -> Interval """ @@ -74,6 +100,18 @@ def decode_interval(intv, full_range, default=None): return (default, default) +def interval_to_set(intv): + # type: (Interval) -> Set + if is_empty(intv): + return set() + + (lo, hi) = intv + assert is_power_of_two(lo) + assert is_power_of_two(hi) + assert lo <= hi + return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)]) + + class TypeSet(object): """ A set of types. @@ -95,22 +133,22 @@ class TypeSet(object): A typeset representing scalar integer types `i8` through `i32`: >>> TypeSet(ints=(8, 32)) - TypeSet(lanes=(1, 1), ints=(8, 32)) + TypeSet(lanes={1}, ints={8, 16, 32}) Passing `True` instead of a range selects all available scalar types: >>> TypeSet(ints=True) - TypeSet(lanes=(1, 1), ints=(8, 64)) + TypeSet(lanes={1}, ints={8, 16, 32, 64}) >>> TypeSet(floats=True) - TypeSet(lanes=(1, 1), floats=(32, 64)) + TypeSet(lanes={1}, floats={32, 64}) >>> TypeSet(bools=True) - TypeSet(lanes=(1, 1), bools=(1, 64)) + TypeSet(lanes={1}, bools={1, 8, 16, 32, 64}) Similarly, passing `True` for the lanes selects all possible scalar and vector types: >>> TypeSet(lanes=True, ints=True) - TypeSet(lanes=(1, 256), ints=(8, 64)) + TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={8, 16, 32, 64}) :param lanes: `(min, max)` inclusive range of permitted vector lane counts. 
:param ints: `(min, max)` inclusive range of permitted scalar integer @@ -123,19 +161,19 @@ class TypeSet(object): def __init__(self, lanes=None, ints=None, floats=None, bools=None): # type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa - self.min_lanes, self.max_lanes = decode_interval( - lanes, (1, MAX_LANES), 1) - self.min_int, self.max_int = decode_interval(ints, (8, MAX_BITS)) - self.min_float, self.max_float = decode_interval(floats, (32, 64)) - self.min_bool, self.max_bool = decode_interval(bools, (1, MAX_BITS)) + self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1)) + self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS))) + self.floats = interval_to_set(decode_interval(floats, (32, 64))) + self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS))) + self.bools = set(filter(lambda x: x == 1 or x >= 8, self.bools)) def typeset_key(self): - # type: () -> Tuple[int, int, int, int, int, int, int, int] + # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple] """Key tuple used for hashing and equality.""" - return (self.min_lanes, self.max_lanes, - self.min_int, self.max_int, - self.min_float, self.max_float, - self.min_bool, self.max_bool) + return (tuple(sorted(list(self.lanes))), + tuple(sorted(list(self.ints))), + tuple(sorted(list(self.floats))), + tuple(sorted(list(self.bools)))) def __hash__(self): # type: () -> int @@ -153,31 +191,29 @@ class TypeSet(object): def __repr__(self): # type: () -> str - s = 'TypeSet(lanes=({}, {})'.format(self.min_lanes, self.max_lanes) - if self.min_int is not None: - s += ', ints=({}, {})'.format(self.min_int, self.max_int) - if self.min_float is not None: - s += ', floats=({}, {})'.format(self.min_float, self.max_float) - if self.min_bool is not None: - s += ', bools=({}, {})'.format(self.min_bool, self.max_bool) + s = 'TypeSet(lanes={}'.format(pp_set(self.lanes)) + if len(self.ints) > 0: + s += ', ints={}'.format(pp_set(self.ints)) + if len(self.floats) > 0: + s += ', 
floats={}'.format(pp_set(self.floats)) + if len(self.bools) > 0: + s += ', bools={}'.format(pp_set(self.bools)) return s + ')' def emit_fields(self, fmt): # type: (Formatter) -> None """Emit field initializers for this typeset.""" fmt.comment(repr(self)) - fields = ('lanes', 'int', 'float', 'bool') - for field in fields: - min_val = getattr(self, 'min_' + field) - max_val = getattr(self, 'max_' + field) - if min_val is None: - fmt.line('min_{}: 0,'.format(field)) - fmt.line('max_{}: 0,'.format(field)) - else: - fmt.line('min_{}: {},'.format( - field, int_log2(min_val))) - fmt.line('max_{}: {},'.format( - field, int_log2(max_val) + 1)) + + fields = (('lanes', 16), + ('ints', 8), + ('floats', 8), + ('bools', 8)) + + for (field, bits) in fields: + vals = [int_log2(x) for x in getattr(self, field)] + fmt.line('{}: BitSet::<u{}>({}),' + .format(field, bits, encode_bitset(vals, bits))) def __iand__(self, other): # type: (TypeSet) -> TypeSet @@ -186,32 +222,22 @@ class TypeSet(object): >>> a = TypeSet(lanes=True, ints=(16, 32)) >>> a - TypeSet(lanes=(1, 256), ints=(16, 32)) + TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}, ints={16, 32}) >>> b = TypeSet(lanes=(4, 16), ints=True) >>> a &= b >>> a - TypeSet(lanes=(4, 16), ints=(16, 32)) + TypeSet(lanes={4, 8, 16}, ints={16, 32}) >>> a = TypeSet(lanes=True, bools=(1, 8)) >>> b = TypeSet(lanes=True, bools=(16, 32)) >>> a &= b >>> a - TypeSet(lanes=(1, 256)) + TypeSet(lanes={1, 2, 4, 8, 16, 32, 64, 128, 256}) """ - self.min_lanes = max(self.min_lanes, other.min_lanes) - self.max_lanes = min(self.max_lanes, other.max_lanes) - - self.min_int, self.max_int = intersect( - (self.min_int, self.max_int), - (other.min_int, other.max_int)) - - self.min_float, self.max_float = intersect( - (self.min_float, self.max_float), - (other.min_float, other.max_float)) - - self.min_bool, self.max_bool = intersect( - (self.min_bool, self.max_bool), - (other.min_bool, other.max_bool)) + self.lanes.intersection_update(other.lanes) + 
self.ints.intersection_update(other.ints) + self.floats.intersection_update(other.floats) + self.bools.intersection_update(other.bools) return self @@ -382,12 +408,12 @@ class TypeVar(object): """ if not self.is_derived: ts = self.type_set - if ts.min_int: - assert ts.min_int > 8, "Can't halve all integer types" - if ts.min_float: - assert ts.min_float > 32, "Can't halve all float types" - if ts.min_bool: - assert ts.min_bool > 8, "Can't halve all boolean types" + if len(ts.ints) > 0: + assert min(ts.ints) > 8, "Can't halve all integer types" + if len(ts.floats) > 0: + assert min(ts.floats) > 32, "Can't halve all float types" + if len(ts.bools) > 0: + assert min(ts.bools) > 8, "Can't halve all boolean types" return TypeVar.derived(self, self.HALFWIDTH) @@ -399,12 +425,14 @@ class TypeVar(object): """ if not self.is_derived: ts = self.type_set - if ts.max_int: - assert ts.max_int < MAX_BITS, "Can't double all integer types." - if ts.max_float: - assert ts.max_float < MAX_BITS, "Can't double all float types." - if ts.max_bool: - assert ts.max_bool < MAX_BITS, "Can't double all bool types." + if len(ts.ints) > 0: + assert max(ts.ints) < MAX_BITS,\ + "Can't double all integer types." + if len(ts.floats) > 0: + assert max(ts.floats) < MAX_BITS,\ + "Can't double all float types." + if len(ts.bools) > 0: + assert max(ts.bools) < MAX_BITS, "Can't double all bool types." return TypeVar.derived(self, self.DOUBLEWIDTH) @@ -416,7 +444,7 @@ class TypeVar(object): """ if not self.is_derived: ts = self.type_set - assert ts.min_lanes > 1, "Can't halve a scalar type" + assert min(ts.lanes) > 1, "Can't halve a scalar type" return TypeVar.derived(self, self.HALFVECTOR) @@ -428,7 +456,7 @@ class TypeVar(object): """ if not self.is_derived: ts = self.type_set - assert ts.max_lanes < 256, "Can't double 256 lanes." + assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes." 
return TypeVar.derived(self, self.DOUBLEVECTOR) diff --git a/lib/cretonne/src/bitset.rs b/lib/cretonne/src/bitset.rs new file mode 100644 index 0000000000..cfca8371fb --- /dev/null +++ b/lib/cretonne/src/bitset.rs @@ -0,0 +1,147 @@ +//! Small Bitset +//! +//! This module defines a struct BitSet<T> encapsulating a bitset built over the type T. +//! T is intended to be a primitive unsigned type. Currently it can be any type between u8 and u32 +//! +//! If you would like to add support for larger bitsets in the future, you need to change the trait +//! bound Into<u32> and the u32 in the implementation of max_bits() +use std::mem::size_of; +use std::ops::{Shl, BitOr, Sub, Add}; +use std::convert::{Into, From}; + +/// A small bitset built on a single primitive integer type +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct BitSet<T>(pub T); + +impl<T> BitSet<T> + where T: Into<u32> + From<u8> + BitOr<T, Output = T> + Shl<u8, Output = T> + Sub<T, Output = T> + + Add<T, Output = T> + PartialEq + Copy +{ +/// Maximum number of bits supported by this BitSet instance + pub fn bits() -> usize { + size_of::<T>() * 8 + } + +/// Maximum number of bits supported by any bitset instance atm. 
+ pub fn max_bits() -> usize { + size_of::<u32>() * 8 + } + +/// Check if this BitSet contains the number num + pub fn contains(&self, num: u8) -> bool { + assert!((num as usize) < Self::bits()); + assert!((num as usize) < Self::max_bits()); + return self.0.into() & (1 << num) != 0; + } + +/// Return the smallest number contained in the bitset or None if empty + pub fn min(&self) -> Option<u8> { + if self.0.into() == 0 { + None + } else { + Some(self.0.into().trailing_zeros() as u8) + } + } + +/// Return the largest number contained in the bitset or None if empty + pub fn max(&self) -> Option<u8> { + if self.0.into() == 0 { + None + } else { + let leading_zeroes = self.0.into().leading_zeros() as usize; + Some((Self::max_bits() - leading_zeroes - 1) as u8) + } + } + +/// Construct a BitSet with the half-open range [lo,hi) filled in + pub fn from_range(lo: u8, hi: u8) -> BitSet<T> { + assert!(lo <= hi); + assert!((hi as usize) <= Self::bits()); + let one : T = T::from(1); +// I can't just do (one << hi) - one here as the shift may overflow + let hi_rng = if hi >= 1 { + (one << (hi-1)) + ((one << (hi-1)) - one) + } else { + T::from(0) + }; + + let lo_rng = (one << lo) - one; + + BitSet(hi_rng - lo_rng) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn contains() { + let s = BitSet::<u8>(255); + for i in 0..7 { + assert!(s.contains(i)); + } + + let s1 = BitSet::<u8>(0); + for i in 0..7 { + assert!(!s1.contains(i)); + } + + let s2 = BitSet::<u8>(127); + for i in 0..6 { + assert!(s2.contains(i)); + } + assert!(!s2.contains(7)); + + let s3 = BitSet::<u8>(2 | 4 | 64); + assert!(!s3.contains(0) && !s3.contains(3) && !s3.contains(4) && !s3.contains(5) && + !s3.contains(7)); + assert!(s3.contains(1) && s3.contains(2) && s3.contains(6)); + + let s4 = BitSet::<u16>(4 | 8 | 256 | 1024); + assert!(!s4.contains(0) && !s4.contains(1) && !s4.contains(4) && !s4.contains(5) && + !s4.contains(6) && !s4.contains(7) && + !s4.contains(9) && !s4.contains(11)); + assert!(s4.contains(2) && s4.contains(3) && 
s4.contains(8) && s4.contains(10)); + } + + #[test] + fn minmax() { + let s = BitSet::<u8>(255); + assert_eq!(s.min(), Some(0)); + assert_eq!(s.max(), Some(7)); + assert!(s.min() == Some(0) && s.max() == Some(7)); + let s1 = BitSet::<u8>(0); + assert!(s1.min() == None && s1.max() == None); + let s2 = BitSet::<u8>(127); + assert!(s2.min() == Some(0) && s2.max() == Some(6)); + let s3 = BitSet::<u8>(2 | 4 | 64); + assert!(s3.min() == Some(1) && s3.max() == Some(6)); + let s4 = BitSet::<u16>(4 | 8 | 256 | 1024); + assert!(s4.min() == Some(2) && s4.max() == Some(10)); + } + + #[test] + fn from_range() { + let s = BitSet::<u8>::from_range(5, 5); + assert!(s.0 == 0); + + let s = BitSet::<u8>::from_range(0, 8); + assert!(s.0 == 255); + + let s = BitSet::<u16>::from_range(0, 8); + assert!(s.0 == 255u16); + + let s = BitSet::<u16>::from_range(0, 16); + assert!(s.0 == 65535u16); + + let s = BitSet::<u8>::from_range(5, 6); + assert!(s.0 == 32u8); + + let s = BitSet::<u8>::from_range(3, 7); + assert!(s.0 == 8 | 16 | 32 | 64); + + let s = BitSet::<u16>::from_range(5, 11); + assert!(s.0 == 32 | 64 | 128 | 256 | 512 | 1024); + } +} diff --git a/lib/cretonne/src/ir/instructions.rs b/lib/cretonne/src/ir/instructions.rs index ad4dd6fa21..35a200de08 100644 --- a/lib/cretonne/src/ir/instructions.rs +++ b/lib/cretonne/src/ir/instructions.rs @@ -17,6 +17,7 @@ use ir::types; use isa::RegUnit; use entity_list; +use bitset::BitSet; use ref_slice::{ref_slice, ref_slice_mut}; /// Some instructions use an external list of argument values because there is not enough space in @@ -499,17 +500,16 @@ impl OpcodeConstraints { } } +type BitSet8 = BitSet<u8>; +type BitSet16 = BitSet<u16>; + /// A value type set describes the permitted set of types for a type variable. 
#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct ValueTypeSet { - min_lanes: u8, - max_lanes: u8, - min_int: u8, - max_int: u8, - min_float: u8, - max_float: u8, - min_bool: u8, - max_bool: u8, + lanes: BitSet16, + ints: BitSet8, + floats: BitSet8, + bools: BitSet8, } impl ValueTypeSet { @@ -519,11 +519,11 @@ impl ValueTypeSet { fn is_base_type(&self, scalar: Type) -> bool { let l2b = scalar.log2_lane_bits(); if scalar.is_int() { - self.min_int <= l2b && l2b < self.max_int + self.ints.contains(l2b) } else if scalar.is_float() { - self.min_float <= l2b && l2b < self.max_float + self.floats.contains(l2b) } else if scalar.is_bool() { - self.min_bool <= l2b && l2b < self.max_bool + self.bools.contains(l2b) } else { false } @@ -532,23 +532,23 @@ impl ValueTypeSet { /// Does `typ` belong to this set? pub fn contains(&self, typ: Type) -> bool { let l2l = typ.log2_lane_count(); - self.min_lanes <= l2l && l2l < self.max_lanes && self.is_base_type(typ.lane_type()) + self.lanes.contains(l2l) && self.is_base_type(typ.lane_type()) } /// Get an example member of this type set. /// /// This is used for error messages to avoid suggesting invalid types. 
pub fn example(&self) -> Type { - let t = if self.max_int > 5 { + let t = if self.ints.max().unwrap_or(0) > 5 { types::I32 - } else if self.max_float > 5 { + } else if self.floats.max().unwrap_or(0) > 5 { types::F32 - } else if self.max_bool > 5 { + } else if self.bools.max().unwrap_or(0) > 5 { types::B32 } else { types::B1 }; - t.by(1 << self.min_lanes).unwrap() + t.by(1 << self.lanes.min().unwrap()).unwrap() } } @@ -709,15 +709,12 @@ mod tests { use ir::types::*; let vts = ValueTypeSet { - min_lanes: 0, - max_lanes: 8, - min_int: 3, - max_int: 7, - min_float: 0, - max_float: 0, - min_bool: 3, - max_bool: 7, + lanes: BitSet16::from_range(0, 8), + ints: BitSet8::from_range(4, 7), + floats: BitSet8::from_range(0, 0), + bools: BitSet8::from_range(3, 7), }; + assert!(!vts.contains(I8)); assert!(vts.contains(I32)); assert!(vts.contains(I64)); assert!(vts.contains(I32X4)); @@ -728,38 +725,26 @@ mod tests { assert_eq!(vts.example().to_string(), "i32"); let vts = ValueTypeSet { - min_lanes: 0, - max_lanes: 8, - min_int: 0, - max_int: 0, - min_float: 5, - max_float: 7, - min_bool: 3, - max_bool: 7, + lanes: BitSet16::from_range(0, 8), + ints: BitSet8::from_range(0, 0), + floats: BitSet8::from_range(5, 7), + bools: BitSet8::from_range(3, 7), }; assert_eq!(vts.example().to_string(), "f32"); let vts = ValueTypeSet { - min_lanes: 1, - max_lanes: 8, - min_int: 0, - max_int: 0, - min_float: 5, - max_float: 7, - min_bool: 3, - max_bool: 7, + lanes: BitSet16::from_range(1, 8), + ints: BitSet8::from_range(0, 0), + floats: BitSet8::from_range(5, 7), + bools: BitSet8::from_range(3, 7), }; assert_eq!(vts.example().to_string(), "f32x2"); let vts = ValueTypeSet { - min_lanes: 2, - max_lanes: 8, - min_int: 0, - max_int: 0, - min_float: 0, - max_float: 0, - min_bool: 3, - max_bool: 7, + lanes: BitSet16::from_range(2, 8), + ints: BitSet8::from_range(0, 0), + floats: BitSet8::from_range(0, 0), + bools: BitSet8::from_range(3, 7), }; assert!(!vts.contains(B32X2)); 
assert!(vts.contains(B32X4)); @@ -767,14 +752,10 @@ mod tests { let vts = ValueTypeSet { // TypeSet(lanes=(1, 256), ints=(8, 64)) - min_lanes: 0, - max_lanes: 9, - min_int: 3, - max_int: 7, - min_float: 0, - max_float: 0, - min_bool: 0, - max_bool: 0, + lanes: BitSet16::from_range(0, 9), + ints: BitSet8::from_range(3, 7), + floats: BitSet8::from_range(0, 0), + bools: BitSet8::from_range(0, 0), }; assert!(vts.contains(I32)); assert!(vts.contains(I32X4)); diff --git a/lib/cretonne/src/lib.rs b/lib/cretonne/src/lib.rs index 09bdfc9302..08c1c4354b 100644 --- a/lib/cretonne/src/lib.rs +++ b/lib/cretonne/src/lib.rs @@ -16,6 +16,7 @@ pub mod dbg; pub mod entity_ref; pub mod binemit; +pub mod bitset; pub mod dominator_tree; pub mod entity_list; pub mod entity_map;