Add the BVType; Add suport for bitvectors in TypeVar and TypeSet.
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
15a7d50765
commit
bd2e9e5d0b
@@ -239,7 +239,7 @@ class Var(Expr):
|
||||
'typeof_{}'.format(self),
|
||||
'Type of the pattern variable `{}`'.format(self),
|
||||
ints=True, floats=True, bools=True,
|
||||
scalars=True, simd=True)
|
||||
scalars=True, simd=True, bitvecs=True)
|
||||
self.original_typevar = tv
|
||||
self.typevar = tv
|
||||
return self.typevar
|
||||
|
||||
@@ -84,7 +84,10 @@ class ScalarType(ValueType):
|
||||
self._vectors = dict() # type: Dict[int, VectorType]
|
||||
# Assign numbers starting from 1. (0 is VOID).
|
||||
ValueType.all_scalars.append(self)
|
||||
self.number = len(ValueType.all_scalars)
|
||||
# Numbers are only valid for Cretone types that get emitted to Rust.
|
||||
# This excludes BVTypes
|
||||
self.number = len([x for x in ValueType.all_scalars
|
||||
if not isinstance(x, BVType)])
|
||||
assert self.number < 16, 'Too many scalar types'
|
||||
|
||||
def __repr__(self):
|
||||
@@ -239,3 +242,34 @@ class BoolType(ScalarType):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
return self.bits
|
||||
|
||||
|
||||
class BVType(ScalarType):
|
||||
"""A flat bitvector type. Used for semantics description only."""
|
||||
|
||||
def __init__(self, bits):
|
||||
# type: (int) -> None
|
||||
assert bits > 0, 'Must have positive number of bits'
|
||||
super(BVType, self).__init__(
|
||||
name='bv{:d}'.format(bits),
|
||||
membytes=bits // 8,
|
||||
doc="A bitvector type with {} bits.".format(bits))
|
||||
self.bits = bits
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return 'BVType(bits={})'.format(self.bits)
|
||||
|
||||
@staticmethod
|
||||
def with_bits(bits):
|
||||
# type: (int) -> BVType
|
||||
typ = ValueType.by_name('bv{:d}'.format(bits))
|
||||
if TYPE_CHECKING:
|
||||
return cast(BVType, typ)
|
||||
else:
|
||||
return typ
|
||||
|
||||
def lane_bits(self):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
return self.bits
|
||||
|
||||
@@ -22,6 +22,7 @@ except ImportError:
|
||||
|
||||
MAX_LANES = 256
|
||||
MAX_BITS = 64
|
||||
MAX_BITVEC = MAX_BITS * MAX_LANES
|
||||
|
||||
|
||||
def int_log2(x):
|
||||
@@ -169,15 +170,20 @@ class TypeSet(object):
|
||||
point widths.
|
||||
:param bools: `(min, max)` inclusive range of permitted scalar boolean
|
||||
widths.
|
||||
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
|
||||
widths.
|
||||
"""
|
||||
|
||||
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
|
||||
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
||||
def __init__(self, lanes=None, ints=None, floats=None, bools=None,
|
||||
bitvecs=None):
|
||||
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
||||
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
||||
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
|
||||
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
|
||||
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
|
||||
self.bools = set(filter(legal_bool, self.bools))
|
||||
self.bitvecs = interval_to_set(decode_interval(bitvecs,
|
||||
(1, MAX_BITVEC)))
|
||||
|
||||
def copy(self):
|
||||
# type: (TypeSet) -> TypeSet
|
||||
@@ -188,12 +194,13 @@ class TypeSet(object):
|
||||
return deepcopy(self)
|
||||
|
||||
def typeset_key(self):
|
||||
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
|
||||
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple]
|
||||
"""Key tuple used for hashing and equality."""
|
||||
return (tuple(sorted(list(self.lanes))),
|
||||
tuple(sorted(list(self.ints))),
|
||||
tuple(sorted(list(self.floats))),
|
||||
tuple(sorted(list(self.bools))))
|
||||
tuple(sorted(list(self.bools))),
|
||||
tuple(sorted(list(self.bitvecs))))
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
@@ -222,11 +229,14 @@ class TypeSet(object):
|
||||
s += ', floats={}'.format(pp_set(self.floats))
|
||||
if len(self.bools) > 0:
|
||||
s += ', bools={}'.format(pp_set(self.bools))
|
||||
if len(self.bitvecs) > 0:
|
||||
s += ', bitvecs={}'.format(pp_set(self.bitvecs))
|
||||
return s + ')'
|
||||
|
||||
def emit_fields(self, fmt):
|
||||
# type: (Formatter) -> None
|
||||
"""Emit field initializers for this typeset."""
|
||||
assert len(self.bitvecs) == 0, "Bitvector types are not emitable."
|
||||
fmt.comment(repr(self))
|
||||
|
||||
fields = (('lanes', 16),
|
||||
@@ -262,6 +272,7 @@ class TypeSet(object):
|
||||
self.ints.intersection_update(other.ints)
|
||||
self.floats.intersection_update(other.floats)
|
||||
self.bools.intersection_update(other.bools)
|
||||
self.bitvecs.intersection_update(other.bitvecs)
|
||||
|
||||
return self
|
||||
|
||||
@@ -273,7 +284,8 @@ class TypeSet(object):
|
||||
return self.lanes.issubset(other.lanes) and \
|
||||
self.ints.issubset(other.ints) and \
|
||||
self.floats.issubset(other.floats) and \
|
||||
self.bools.issubset(other.bools)
|
||||
self.bools.issubset(other.bools) and \
|
||||
self.bitvecs.issubset(other.bitvecs)
|
||||
|
||||
def lane_of(self):
|
||||
# type: () -> TypeSet
|
||||
@@ -282,6 +294,7 @@ class TypeSet(object):
|
||||
"""
|
||||
new = self.copy()
|
||||
new.lanes = set([1])
|
||||
new.bitvecs = set()
|
||||
return new
|
||||
|
||||
def as_bool(self):
|
||||
@@ -292,6 +305,7 @@ class TypeSet(object):
|
||||
new = self.copy()
|
||||
new.ints = set()
|
||||
new.floats = set()
|
||||
new.bitvecs = set()
|
||||
|
||||
if len(self.lanes.difference(set([1]))) > 0:
|
||||
new.bools = self.ints.union(self.floats).union(self.bools)
|
||||
@@ -309,6 +323,7 @@ class TypeSet(object):
|
||||
new.ints = set([x//2 for x in self.ints if x > 8])
|
||||
new.floats = set([x//2 for x in self.floats if x > 32])
|
||||
new.bools = set([x//2 for x in self.bools if x > 8])
|
||||
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
|
||||
|
||||
return new
|
||||
|
||||
@@ -322,6 +337,7 @@ class TypeSet(object):
|
||||
new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
|
||||
new.bools = set(filter(legal_bool,
|
||||
set([x*2 for x in self.bools if x < MAX_BITS])))
|
||||
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
|
||||
|
||||
return new
|
||||
|
||||
@@ -331,6 +347,7 @@ class TypeSet(object):
|
||||
Return a TypeSet describing the image of self across halfvector
|
||||
"""
|
||||
new = self.copy()
|
||||
new.bitvecs = set()
|
||||
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
||||
|
||||
return new
|
||||
@@ -341,10 +358,29 @@ class TypeSet(object):
|
||||
Return a TypeSet describing the image of self across doublevector
|
||||
"""
|
||||
new = self.copy()
|
||||
new.bitvecs = set()
|
||||
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
||||
|
||||
return new
|
||||
|
||||
def to_bitvec(self):
|
||||
# type: () -> TypeSet
|
||||
"""
|
||||
Return a TypeSet describing the image of self across to_bitvec
|
||||
"""
|
||||
assert len(self.bitvecs) == 0
|
||||
all_scalars = self.ints.union(self.floats.union(self.bools))
|
||||
|
||||
new = self.copy()
|
||||
new.lanes = set([1])
|
||||
new.ints = set()
|
||||
new.bools = set()
|
||||
new.floats = set()
|
||||
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
|
||||
for nlanes in self.lanes])
|
||||
|
||||
return new
|
||||
|
||||
def image(self, func):
|
||||
# type: (str) -> TypeSet
|
||||
"""
|
||||
@@ -362,6 +398,8 @@ class TypeSet(object):
|
||||
return self.half_vector()
|
||||
elif (func == TypeVar.DOUBLEVECTOR):
|
||||
return self.double_vector()
|
||||
elif (func == TypeVar.TOBITVEC):
|
||||
return self.to_bitvec()
|
||||
else:
|
||||
assert False, "Unknown derived function: " + func
|
||||
|
||||
@@ -376,10 +414,12 @@ class TypeSet(object):
|
||||
|
||||
if (func == TypeVar.LANEOF):
|
||||
new = self.copy()
|
||||
new.bitvecs = set()
|
||||
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
|
||||
return new
|
||||
elif (func == TypeVar.ASBOOL):
|
||||
new = self.copy()
|
||||
new.bitvecs = set()
|
||||
|
||||
if 1 not in self.bools:
|
||||
new.ints = self.bools.difference(set([1]))
|
||||
@@ -400,6 +440,39 @@ class TypeSet(object):
|
||||
return self.double_vector()
|
||||
elif (func == TypeVar.DOUBLEVECTOR):
|
||||
return self.half_vector()
|
||||
elif (func == TypeVar.TOBITVEC):
|
||||
new = TypeSet()
|
||||
|
||||
# Start with all possible lanes/ints/floats/bools
|
||||
lanes = interval_to_set(decode_interval(True, (1, MAX_LANES), 1))
|
||||
ints = interval_to_set(decode_interval(True, (8, MAX_BITS)))
|
||||
floats = interval_to_set(decode_interval(True, (32, 64)))
|
||||
bools = interval_to_set(decode_interval(True, (1, MAX_BITS)))
|
||||
|
||||
# See which combinations have a size that appears in self.bitvecs
|
||||
has_t = set() # type: Set[Tuple[str, int, int]]
|
||||
for l in lanes:
|
||||
for i in ints:
|
||||
if i * l in self.bitvecs:
|
||||
has_t.add(('i', i, l))
|
||||
for i in bools:
|
||||
if i * l in self.bitvecs:
|
||||
has_t.add(('b', i, l))
|
||||
for i in floats:
|
||||
if i * l in self.bitvecs:
|
||||
has_t.add(('f', i, l))
|
||||
|
||||
for (t, width, lane) in has_t:
|
||||
new.lanes.add(lane)
|
||||
if (t == 'i'):
|
||||
new.ints.add(width)
|
||||
elif (t == 'b'):
|
||||
new.bools.add(width)
|
||||
else:
|
||||
assert t == 'f'
|
||||
new.floats.add(width)
|
||||
|
||||
return new
|
||||
else:
|
||||
assert False, "Unknown derived function: " + func
|
||||
|
||||
@@ -409,7 +482,7 @@ class TypeSet(object):
|
||||
Return the number of concrete types represented by this typeset
|
||||
"""
|
||||
return len(self.lanes) * (len(self.ints) + len(self.floats) +
|
||||
len(self.bools))
|
||||
len(self.bools) + len(self.bitvecs))
|
||||
|
||||
def concrete_types(self):
|
||||
# type: () -> Iterable[types.ValueType]
|
||||
@@ -427,6 +500,8 @@ class TypeSet(object):
|
||||
yield by(types.FloatType.with_bits(bits), nlanes)
|
||||
for bits in self.bools:
|
||||
yield by(types.BoolType.with_bits(bits), nlanes)
|
||||
for bits in self.bitvecs:
|
||||
yield by(types.BVType.with_bits(bits), nlanes)
|
||||
|
||||
def get_singleton(self):
|
||||
# type: () -> types.ValueType
|
||||
@@ -458,14 +533,15 @@ class TypeVar(object):
|
||||
:param scalars: Allow type variable to assume scalar types.
|
||||
:param simd: Allow type variable to assume vector types, or `(min, max)`
|
||||
lane count range.
|
||||
:param bitvecs: Allow all BitVec base types, or `(min, max)` bit-range.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name, doc,
|
||||
ints=False, floats=False, bools=False,
|
||||
scalars=True, simd=False,
|
||||
scalars=True, simd=False, bitvecs=False,
|
||||
base=None, derived_func=None):
|
||||
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa
|
||||
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, BoolInterval, TypeVar, str) -> None # noqa
|
||||
self.name = name
|
||||
self.__doc__ = doc
|
||||
self.is_derived = isinstance(base, TypeVar)
|
||||
@@ -482,7 +558,8 @@ class TypeVar(object):
|
||||
lanes=lanes,
|
||||
ints=ints,
|
||||
floats=floats,
|
||||
bools=bools)
|
||||
bools=bools,
|
||||
bitvecs=bitvecs)
|
||||
|
||||
@staticmethod
|
||||
def singleton(typ):
|
||||
@@ -498,6 +575,7 @@ class TypeVar(object):
|
||||
ints = None
|
||||
floats = None
|
||||
bools = None
|
||||
bitvecs = None
|
||||
|
||||
if isinstance(scalar, types.IntType):
|
||||
ints = (scalar.bits, scalar.bits)
|
||||
@@ -505,10 +583,13 @@ class TypeVar(object):
|
||||
floats = (scalar.bits, scalar.bits)
|
||||
elif isinstance(scalar, types.BoolType):
|
||||
bools = (scalar.bits, scalar.bits)
|
||||
elif isinstance(scalar, types.BVType):
|
||||
bitvecs = (scalar.bits, scalar.bits)
|
||||
|
||||
tv = TypeVar(
|
||||
typ.name, typ.__doc__,
|
||||
ints, floats, bools, simd=lanes)
|
||||
ints=ints, floats=floats, bools=bools,
|
||||
bitvecs=bitvecs, simd=lanes)
|
||||
return tv
|
||||
|
||||
def __str__(self):
|
||||
@@ -558,6 +639,7 @@ class TypeVar(object):
|
||||
DOUBLEWIDTH = 'double_width'
|
||||
HALFVECTOR = 'half_vector'
|
||||
DOUBLEVECTOR = 'double_vector'
|
||||
TOBITVEC = 'to_bitvec'
|
||||
|
||||
@staticmethod
|
||||
def is_bijection(func):
|
||||
@@ -668,6 +750,14 @@ class TypeVar(object):
|
||||
"""
|
||||
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
||||
|
||||
def to_bitvec(self):
|
||||
# type: () -> TypeVar
|
||||
"""
|
||||
Return a derived type variable that represent a flat bitvector with
|
||||
the same size as self
|
||||
"""
|
||||
return TypeVar.derived(self, self.TOBITVEC)
|
||||
|
||||
def singleton_type(self):
|
||||
# type: () -> ValueType
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user