Add the BVType; Add suport for bitvectors in TypeVar and TypeSet.

This commit is contained in:
Dimo
2017-07-20 17:20:23 -07:00
committed by Jakob Stoklund Olesen
parent 605886a277
commit a12fa86e60
3 changed files with 136 additions and 12 deletions

View File

@@ -239,7 +239,7 @@ class Var(Expr):
'typeof_{}'.format(self), 'typeof_{}'.format(self),
'Type of the pattern variable `{}`'.format(self), 'Type of the pattern variable `{}`'.format(self),
ints=True, floats=True, bools=True, ints=True, floats=True, bools=True,
scalars=True, simd=True) scalars=True, simd=True, bitvecs=True)
self.original_typevar = tv self.original_typevar = tv
self.typevar = tv self.typevar = tv
return self.typevar return self.typevar

View File

@@ -84,7 +84,10 @@ class ScalarType(ValueType):
self._vectors = dict() # type: Dict[int, VectorType] self._vectors = dict() # type: Dict[int, VectorType]
# Assign numbers starting from 1. (0 is VOID). # Assign numbers starting from 1. (0 is VOID).
ValueType.all_scalars.append(self) ValueType.all_scalars.append(self)
self.number = len(ValueType.all_scalars) # Numbers are only valid for Cretone types that get emitted to Rust.
# This excludes BVTypes
self.number = len([x for x in ValueType.all_scalars
if not isinstance(x, BVType)])
assert self.number < 16, 'Too many scalar types' assert self.number < 16, 'Too many scalar types'
def __repr__(self): def __repr__(self):
@@ -239,3 +242,34 @@ class BoolType(ScalarType):
# type: () -> int # type: () -> int
"""Return the number of bits in a lane.""" """Return the number of bits in a lane."""
return self.bits return self.bits
class BVType(ScalarType):
"""A flat bitvector type. Used for semantics description only."""
def __init__(self, bits):
# type: (int) -> None
assert bits > 0, 'Must have positive number of bits'
super(BVType, self).__init__(
name='bv{:d}'.format(bits),
membytes=bits // 8,
doc="A bitvector type with {} bits.".format(bits))
self.bits = bits
def __repr__(self):
# type: () -> str
return 'BVType(bits={})'.format(self.bits)
@staticmethod
def with_bits(bits):
# type: (int) -> BVType
typ = ValueType.by_name('bv{:d}'.format(bits))
if TYPE_CHECKING:
return cast(BVType, typ)
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits

View File

@@ -22,6 +22,7 @@ except ImportError:
MAX_LANES = 256 MAX_LANES = 256
MAX_BITS = 64 MAX_BITS = 64
MAX_BITVEC = MAX_BITS * MAX_LANES
def int_log2(x): def int_log2(x):
@@ -169,15 +170,20 @@ class TypeSet(object):
point widths. point widths.
:param bools: `(min, max)` inclusive range of permitted scalar boolean :param bools: `(min, max)` inclusive range of permitted scalar boolean
widths. widths.
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
widths.
""" """
def __init__(self, lanes=None, ints=None, floats=None, bools=None): def __init__(self, lanes=None, ints=None, floats=None, bools=None,
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa bitvecs=None):
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1)) self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS))) self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
self.floats = interval_to_set(decode_interval(floats, (32, 64))) self.floats = interval_to_set(decode_interval(floats, (32, 64)))
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS))) self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
self.bools = set(filter(legal_bool, self.bools)) self.bools = set(filter(legal_bool, self.bools))
self.bitvecs = interval_to_set(decode_interval(bitvecs,
(1, MAX_BITVEC)))
def copy(self): def copy(self):
# type: (TypeSet) -> TypeSet # type: (TypeSet) -> TypeSet
@@ -188,12 +194,13 @@ class TypeSet(object):
return deepcopy(self) return deepcopy(self)
def typeset_key(self): def typeset_key(self):
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple] # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple]
"""Key tuple used for hashing and equality.""" """Key tuple used for hashing and equality."""
return (tuple(sorted(list(self.lanes))), return (tuple(sorted(list(self.lanes))),
tuple(sorted(list(self.ints))), tuple(sorted(list(self.ints))),
tuple(sorted(list(self.floats))), tuple(sorted(list(self.floats))),
tuple(sorted(list(self.bools)))) tuple(sorted(list(self.bools))),
tuple(sorted(list(self.bitvecs))))
def __hash__(self): def __hash__(self):
# type: () -> int # type: () -> int
@@ -222,11 +229,14 @@ class TypeSet(object):
s += ', floats={}'.format(pp_set(self.floats)) s += ', floats={}'.format(pp_set(self.floats))
if len(self.bools) > 0: if len(self.bools) > 0:
s += ', bools={}'.format(pp_set(self.bools)) s += ', bools={}'.format(pp_set(self.bools))
if len(self.bitvecs) > 0:
s += ', bitvecs={}'.format(pp_set(self.bitvecs))
return s + ')' return s + ')'
def emit_fields(self, fmt): def emit_fields(self, fmt):
# type: (Formatter) -> None # type: (Formatter) -> None
"""Emit field initializers for this typeset.""" """Emit field initializers for this typeset."""
assert len(self.bitvecs) == 0, "Bitvector types are not emitable."
fmt.comment(repr(self)) fmt.comment(repr(self))
fields = (('lanes', 16), fields = (('lanes', 16),
@@ -262,6 +272,7 @@ class TypeSet(object):
self.ints.intersection_update(other.ints) self.ints.intersection_update(other.ints)
self.floats.intersection_update(other.floats) self.floats.intersection_update(other.floats)
self.bools.intersection_update(other.bools) self.bools.intersection_update(other.bools)
self.bitvecs.intersection_update(other.bitvecs)
return self return self
@@ -273,7 +284,8 @@ class TypeSet(object):
return self.lanes.issubset(other.lanes) and \ return self.lanes.issubset(other.lanes) and \
self.ints.issubset(other.ints) and \ self.ints.issubset(other.ints) and \
self.floats.issubset(other.floats) and \ self.floats.issubset(other.floats) and \
self.bools.issubset(other.bools) self.bools.issubset(other.bools) and \
self.bitvecs.issubset(other.bitvecs)
def lane_of(self): def lane_of(self):
# type: () -> TypeSet # type: () -> TypeSet
@@ -282,6 +294,7 @@ class TypeSet(object):
""" """
new = self.copy() new = self.copy()
new.lanes = set([1]) new.lanes = set([1])
new.bitvecs = set()
return new return new
def as_bool(self): def as_bool(self):
@@ -292,6 +305,7 @@ class TypeSet(object):
new = self.copy() new = self.copy()
new.ints = set() new.ints = set()
new.floats = set() new.floats = set()
new.bitvecs = set()
if len(self.lanes.difference(set([1]))) > 0: if len(self.lanes.difference(set([1]))) > 0:
new.bools = self.ints.union(self.floats).union(self.bools) new.bools = self.ints.union(self.floats).union(self.bools)
@@ -309,6 +323,7 @@ class TypeSet(object):
new.ints = set([x//2 for x in self.ints if x > 8]) new.ints = set([x//2 for x in self.ints if x > 8])
new.floats = set([x//2 for x in self.floats if x > 32]) new.floats = set([x//2 for x in self.floats if x > 32])
new.bools = set([x//2 for x in self.bools if x > 8]) new.bools = set([x//2 for x in self.bools if x > 8])
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
return new return new
@@ -322,6 +337,7 @@ class TypeSet(object):
new.floats = set([x*2 for x in self.floats if x < MAX_BITS]) new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
new.bools = set(filter(legal_bool, new.bools = set(filter(legal_bool,
set([x*2 for x in self.bools if x < MAX_BITS]))) set([x*2 for x in self.bools if x < MAX_BITS])))
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
return new return new
@@ -331,6 +347,7 @@ class TypeSet(object):
Return a TypeSet describing the image of self across halfvector Return a TypeSet describing the image of self across halfvector
""" """
new = self.copy() new = self.copy()
new.bitvecs = set()
new.lanes = set([x//2 for x in self.lanes if x > 1]) new.lanes = set([x//2 for x in self.lanes if x > 1])
return new return new
@@ -341,10 +358,29 @@ class TypeSet(object):
Return a TypeSet describing the image of self across doublevector Return a TypeSet describing the image of self across doublevector
""" """
new = self.copy() new = self.copy()
new.bitvecs = set()
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES]) new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
return new return new
def to_bitvec(self):
# type: () -> TypeSet
"""
Return a TypeSet describing the image of self across to_bitvec
"""
assert len(self.bitvecs) == 0
all_scalars = self.ints.union(self.floats.union(self.bools))
new = self.copy()
new.lanes = set([1])
new.ints = set()
new.bools = set()
new.floats = set()
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
for nlanes in self.lanes])
return new
def image(self, func): def image(self, func):
# type: (str) -> TypeSet # type: (str) -> TypeSet
""" """
@@ -362,6 +398,8 @@ class TypeSet(object):
return self.half_vector() return self.half_vector()
elif (func == TypeVar.DOUBLEVECTOR): elif (func == TypeVar.DOUBLEVECTOR):
return self.double_vector() return self.double_vector()
elif (func == TypeVar.TOBITVEC):
return self.to_bitvec()
else: else:
assert False, "Unknown derived function: " + func assert False, "Unknown derived function: " + func
@@ -376,10 +414,12 @@ class TypeSet(object):
if (func == TypeVar.LANEOF): if (func == TypeVar.LANEOF):
new = self.copy() new = self.copy()
new.bitvecs = set()
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)]) new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
return new return new
elif (func == TypeVar.ASBOOL): elif (func == TypeVar.ASBOOL):
new = self.copy() new = self.copy()
new.bitvecs = set()
if 1 not in self.bools: if 1 not in self.bools:
new.ints = self.bools.difference(set([1])) new.ints = self.bools.difference(set([1]))
@@ -400,6 +440,39 @@ class TypeSet(object):
return self.double_vector() return self.double_vector()
elif (func == TypeVar.DOUBLEVECTOR): elif (func == TypeVar.DOUBLEVECTOR):
return self.half_vector() return self.half_vector()
elif (func == TypeVar.TOBITVEC):
new = TypeSet()
# Start with all possible lanes/ints/floats/bools
lanes = interval_to_set(decode_interval(True, (1, MAX_LANES), 1))
ints = interval_to_set(decode_interval(True, (8, MAX_BITS)))
floats = interval_to_set(decode_interval(True, (32, 64)))
bools = interval_to_set(decode_interval(True, (1, MAX_BITS)))
# See which combinations have a size that appears in self.bitvecs
has_t = set() # type: Set[Tuple[str, int, int]]
for l in lanes:
for i in ints:
if i * l in self.bitvecs:
has_t.add(('i', i, l))
for i in bools:
if i * l in self.bitvecs:
has_t.add(('b', i, l))
for i in floats:
if i * l in self.bitvecs:
has_t.add(('f', i, l))
for (t, width, lane) in has_t:
new.lanes.add(lane)
if (t == 'i'):
new.ints.add(width)
elif (t == 'b'):
new.bools.add(width)
else:
assert t == 'f'
new.floats.add(width)
return new
else: else:
assert False, "Unknown derived function: " + func assert False, "Unknown derived function: " + func
@@ -409,7 +482,7 @@ class TypeSet(object):
Return the number of concrete types represented by this typeset Return the number of concrete types represented by this typeset
""" """
return len(self.lanes) * (len(self.ints) + len(self.floats) + return len(self.lanes) * (len(self.ints) + len(self.floats) +
len(self.bools)) len(self.bools) + len(self.bitvecs))
def concrete_types(self): def concrete_types(self):
# type: () -> Iterable[types.ValueType] # type: () -> Iterable[types.ValueType]
@@ -427,6 +500,8 @@ class TypeSet(object):
yield by(types.FloatType.with_bits(bits), nlanes) yield by(types.FloatType.with_bits(bits), nlanes)
for bits in self.bools: for bits in self.bools:
yield by(types.BoolType.with_bits(bits), nlanes) yield by(types.BoolType.with_bits(bits), nlanes)
for bits in self.bitvecs:
yield by(types.BVType.with_bits(bits), nlanes)
def get_singleton(self): def get_singleton(self):
# type: () -> types.ValueType # type: () -> types.ValueType
@@ -458,14 +533,15 @@ class TypeVar(object):
:param scalars: Allow type variable to assume scalar types. :param scalars: Allow type variable to assume scalar types.
:param simd: Allow type variable to assume vector types, or `(min, max)` :param simd: Allow type variable to assume vector types, or `(min, max)`
lane count range. lane count range.
:param bitvecs: Allow all BitVec base types, or `(min, max)` bit-range.
""" """
def __init__( def __init__(
self, name, doc, self, name, doc,
ints=False, floats=False, bools=False, ints=False, floats=False, bools=False,
scalars=True, simd=False, scalars=True, simd=False, bitvecs=False,
base=None, derived_func=None): base=None, derived_func=None):
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa # type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, BoolInterval, TypeVar, str) -> None # noqa
self.name = name self.name = name
self.__doc__ = doc self.__doc__ = doc
self.is_derived = isinstance(base, TypeVar) self.is_derived = isinstance(base, TypeVar)
@@ -482,7 +558,8 @@ class TypeVar(object):
lanes=lanes, lanes=lanes,
ints=ints, ints=ints,
floats=floats, floats=floats,
bools=bools) bools=bools,
bitvecs=bitvecs)
@staticmethod @staticmethod
def singleton(typ): def singleton(typ):
@@ -498,6 +575,7 @@ class TypeVar(object):
ints = None ints = None
floats = None floats = None
bools = None bools = None
bitvecs = None
if isinstance(scalar, types.IntType): if isinstance(scalar, types.IntType):
ints = (scalar.bits, scalar.bits) ints = (scalar.bits, scalar.bits)
@@ -505,10 +583,13 @@ class TypeVar(object):
floats = (scalar.bits, scalar.bits) floats = (scalar.bits, scalar.bits)
elif isinstance(scalar, types.BoolType): elif isinstance(scalar, types.BoolType):
bools = (scalar.bits, scalar.bits) bools = (scalar.bits, scalar.bits)
elif isinstance(scalar, types.BVType):
bitvecs = (scalar.bits, scalar.bits)
tv = TypeVar( tv = TypeVar(
typ.name, typ.__doc__, typ.name, typ.__doc__,
ints, floats, bools, simd=lanes) ints=ints, floats=floats, bools=bools,
bitvecs=bitvecs, simd=lanes)
return tv return tv
def __str__(self): def __str__(self):
@@ -558,6 +639,7 @@ class TypeVar(object):
DOUBLEWIDTH = 'double_width' DOUBLEWIDTH = 'double_width'
HALFVECTOR = 'half_vector' HALFVECTOR = 'half_vector'
DOUBLEVECTOR = 'double_vector' DOUBLEVECTOR = 'double_vector'
TOBITVEC = 'to_bitvec'
@staticmethod @staticmethod
def is_bijection(func): def is_bijection(func):
@@ -668,6 +750,14 @@ class TypeVar(object):
""" """
return TypeVar.derived(self, self.DOUBLEVECTOR) return TypeVar.derived(self, self.DOUBLEVECTOR)
def to_bitvec(self):
# type: () -> TypeVar
"""
Return a derived type variable that represent a flat bitvector with
the same size as self
"""
return TypeVar.derived(self, self.TOBITVEC)
def singleton_type(self): def singleton_type(self):
# type: () -> ValueType # type: () -> ValueType
""" """