Add the BVType; Add suport for bitvectors in TypeVar and TypeSet.
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
605886a277
commit
a12fa86e60
@@ -239,7 +239,7 @@ class Var(Expr):
|
|||||||
'typeof_{}'.format(self),
|
'typeof_{}'.format(self),
|
||||||
'Type of the pattern variable `{}`'.format(self),
|
'Type of the pattern variable `{}`'.format(self),
|
||||||
ints=True, floats=True, bools=True,
|
ints=True, floats=True, bools=True,
|
||||||
scalars=True, simd=True)
|
scalars=True, simd=True, bitvecs=True)
|
||||||
self.original_typevar = tv
|
self.original_typevar = tv
|
||||||
self.typevar = tv
|
self.typevar = tv
|
||||||
return self.typevar
|
return self.typevar
|
||||||
|
|||||||
@@ -84,7 +84,10 @@ class ScalarType(ValueType):
|
|||||||
self._vectors = dict() # type: Dict[int, VectorType]
|
self._vectors = dict() # type: Dict[int, VectorType]
|
||||||
# Assign numbers starting from 1. (0 is VOID).
|
# Assign numbers starting from 1. (0 is VOID).
|
||||||
ValueType.all_scalars.append(self)
|
ValueType.all_scalars.append(self)
|
||||||
self.number = len(ValueType.all_scalars)
|
# Numbers are only valid for Cretone types that get emitted to Rust.
|
||||||
|
# This excludes BVTypes
|
||||||
|
self.number = len([x for x in ValueType.all_scalars
|
||||||
|
if not isinstance(x, BVType)])
|
||||||
assert self.number < 16, 'Too many scalar types'
|
assert self.number < 16, 'Too many scalar types'
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -239,3 +242,34 @@ class BoolType(ScalarType):
|
|||||||
# type: () -> int
|
# type: () -> int
|
||||||
"""Return the number of bits in a lane."""
|
"""Return the number of bits in a lane."""
|
||||||
return self.bits
|
return self.bits
|
||||||
|
|
||||||
|
|
||||||
|
class BVType(ScalarType):
|
||||||
|
"""A flat bitvector type. Used for semantics description only."""
|
||||||
|
|
||||||
|
def __init__(self, bits):
|
||||||
|
# type: (int) -> None
|
||||||
|
assert bits > 0, 'Must have positive number of bits'
|
||||||
|
super(BVType, self).__init__(
|
||||||
|
name='bv{:d}'.format(bits),
|
||||||
|
membytes=bits // 8,
|
||||||
|
doc="A bitvector type with {} bits.".format(bits))
|
||||||
|
self.bits = bits
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# type: () -> str
|
||||||
|
return 'BVType(bits={})'.format(self.bits)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def with_bits(bits):
|
||||||
|
# type: (int) -> BVType
|
||||||
|
typ = ValueType.by_name('bv{:d}'.format(bits))
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
return cast(BVType, typ)
|
||||||
|
else:
|
||||||
|
return typ
|
||||||
|
|
||||||
|
def lane_bits(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of bits in a lane."""
|
||||||
|
return self.bits
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ except ImportError:
|
|||||||
|
|
||||||
MAX_LANES = 256
|
MAX_LANES = 256
|
||||||
MAX_BITS = 64
|
MAX_BITS = 64
|
||||||
|
MAX_BITVEC = MAX_BITS * MAX_LANES
|
||||||
|
|
||||||
|
|
||||||
def int_log2(x):
|
def int_log2(x):
|
||||||
@@ -169,15 +170,20 @@ class TypeSet(object):
|
|||||||
point widths.
|
point widths.
|
||||||
:param bools: `(min, max)` inclusive range of permitted scalar boolean
|
:param bools: `(min, max)` inclusive range of permitted scalar boolean
|
||||||
widths.
|
widths.
|
||||||
|
:param bitvecs : `(min, max)` inclusive range of permitted bitvector
|
||||||
|
widths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lanes=None, ints=None, floats=None, bools=None):
|
def __init__(self, lanes=None, ints=None, floats=None, bools=None,
|
||||||
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
bitvecs=None):
|
||||||
|
# type: (BoolInterval, BoolInterval, BoolInterval, BoolInterval, BoolInterval) -> None # noqa
|
||||||
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
|
||||||
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
|
self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
|
||||||
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
|
self.floats = interval_to_set(decode_interval(floats, (32, 64)))
|
||||||
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
|
self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
|
||||||
self.bools = set(filter(legal_bool, self.bools))
|
self.bools = set(filter(legal_bool, self.bools))
|
||||||
|
self.bitvecs = interval_to_set(decode_interval(bitvecs,
|
||||||
|
(1, MAX_BITVEC)))
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
# type: (TypeSet) -> TypeSet
|
# type: (TypeSet) -> TypeSet
|
||||||
@@ -188,12 +194,13 @@ class TypeSet(object):
|
|||||||
return deepcopy(self)
|
return deepcopy(self)
|
||||||
|
|
||||||
def typeset_key(self):
|
def typeset_key(self):
|
||||||
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple]
|
# type: () -> Tuple[Tuple, Tuple, Tuple, Tuple, Tuple]
|
||||||
"""Key tuple used for hashing and equality."""
|
"""Key tuple used for hashing and equality."""
|
||||||
return (tuple(sorted(list(self.lanes))),
|
return (tuple(sorted(list(self.lanes))),
|
||||||
tuple(sorted(list(self.ints))),
|
tuple(sorted(list(self.ints))),
|
||||||
tuple(sorted(list(self.floats))),
|
tuple(sorted(list(self.floats))),
|
||||||
tuple(sorted(list(self.bools))))
|
tuple(sorted(list(self.bools))),
|
||||||
|
tuple(sorted(list(self.bitvecs))))
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
# type: () -> int
|
# type: () -> int
|
||||||
@@ -222,11 +229,14 @@ class TypeSet(object):
|
|||||||
s += ', floats={}'.format(pp_set(self.floats))
|
s += ', floats={}'.format(pp_set(self.floats))
|
||||||
if len(self.bools) > 0:
|
if len(self.bools) > 0:
|
||||||
s += ', bools={}'.format(pp_set(self.bools))
|
s += ', bools={}'.format(pp_set(self.bools))
|
||||||
|
if len(self.bitvecs) > 0:
|
||||||
|
s += ', bitvecs={}'.format(pp_set(self.bitvecs))
|
||||||
return s + ')'
|
return s + ')'
|
||||||
|
|
||||||
def emit_fields(self, fmt):
|
def emit_fields(self, fmt):
|
||||||
# type: (Formatter) -> None
|
# type: (Formatter) -> None
|
||||||
"""Emit field initializers for this typeset."""
|
"""Emit field initializers for this typeset."""
|
||||||
|
assert len(self.bitvecs) == 0, "Bitvector types are not emitable."
|
||||||
fmt.comment(repr(self))
|
fmt.comment(repr(self))
|
||||||
|
|
||||||
fields = (('lanes', 16),
|
fields = (('lanes', 16),
|
||||||
@@ -262,6 +272,7 @@ class TypeSet(object):
|
|||||||
self.ints.intersection_update(other.ints)
|
self.ints.intersection_update(other.ints)
|
||||||
self.floats.intersection_update(other.floats)
|
self.floats.intersection_update(other.floats)
|
||||||
self.bools.intersection_update(other.bools)
|
self.bools.intersection_update(other.bools)
|
||||||
|
self.bitvecs.intersection_update(other.bitvecs)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -273,7 +284,8 @@ class TypeSet(object):
|
|||||||
return self.lanes.issubset(other.lanes) and \
|
return self.lanes.issubset(other.lanes) and \
|
||||||
self.ints.issubset(other.ints) and \
|
self.ints.issubset(other.ints) and \
|
||||||
self.floats.issubset(other.floats) and \
|
self.floats.issubset(other.floats) and \
|
||||||
self.bools.issubset(other.bools)
|
self.bools.issubset(other.bools) and \
|
||||||
|
self.bitvecs.issubset(other.bitvecs)
|
||||||
|
|
||||||
def lane_of(self):
|
def lane_of(self):
|
||||||
# type: () -> TypeSet
|
# type: () -> TypeSet
|
||||||
@@ -282,6 +294,7 @@ class TypeSet(object):
|
|||||||
"""
|
"""
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
new.lanes = set([1])
|
new.lanes = set([1])
|
||||||
|
new.bitvecs = set()
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def as_bool(self):
|
def as_bool(self):
|
||||||
@@ -292,6 +305,7 @@ class TypeSet(object):
|
|||||||
new = self.copy()
|
new = self.copy()
|
||||||
new.ints = set()
|
new.ints = set()
|
||||||
new.floats = set()
|
new.floats = set()
|
||||||
|
new.bitvecs = set()
|
||||||
|
|
||||||
if len(self.lanes.difference(set([1]))) > 0:
|
if len(self.lanes.difference(set([1]))) > 0:
|
||||||
new.bools = self.ints.union(self.floats).union(self.bools)
|
new.bools = self.ints.union(self.floats).union(self.bools)
|
||||||
@@ -309,6 +323,7 @@ class TypeSet(object):
|
|||||||
new.ints = set([x//2 for x in self.ints if x > 8])
|
new.ints = set([x//2 for x in self.ints if x > 8])
|
||||||
new.floats = set([x//2 for x in self.floats if x > 32])
|
new.floats = set([x//2 for x in self.floats if x > 32])
|
||||||
new.bools = set([x//2 for x in self.bools if x > 8])
|
new.bools = set([x//2 for x in self.bools if x > 8])
|
||||||
|
new.bitvecs = set([x//2 for x in self.bitvecs if x > 1])
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -322,6 +337,7 @@ class TypeSet(object):
|
|||||||
new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
|
new.floats = set([x*2 for x in self.floats if x < MAX_BITS])
|
||||||
new.bools = set(filter(legal_bool,
|
new.bools = set(filter(legal_bool,
|
||||||
set([x*2 for x in self.bools if x < MAX_BITS])))
|
set([x*2 for x in self.bools if x < MAX_BITS])))
|
||||||
|
new.bitvecs = set([x*2 for x in self.bitvecs if x < MAX_BITVEC])
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@@ -331,6 +347,7 @@ class TypeSet(object):
|
|||||||
Return a TypeSet describing the image of self across halfvector
|
Return a TypeSet describing the image of self across halfvector
|
||||||
"""
|
"""
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
|
new.bitvecs = set()
|
||||||
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
new.lanes = set([x//2 for x in self.lanes if x > 1])
|
||||||
|
|
||||||
return new
|
return new
|
||||||
@@ -341,10 +358,29 @@ class TypeSet(object):
|
|||||||
Return a TypeSet describing the image of self across doublevector
|
Return a TypeSet describing the image of self across doublevector
|
||||||
"""
|
"""
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
|
new.bitvecs = set()
|
||||||
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES])
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
def to_bitvec(self):
|
||||||
|
# type: () -> TypeSet
|
||||||
|
"""
|
||||||
|
Return a TypeSet describing the image of self across to_bitvec
|
||||||
|
"""
|
||||||
|
assert len(self.bitvecs) == 0
|
||||||
|
all_scalars = self.ints.union(self.floats.union(self.bools))
|
||||||
|
|
||||||
|
new = self.copy()
|
||||||
|
new.lanes = set([1])
|
||||||
|
new.ints = set()
|
||||||
|
new.bools = set()
|
||||||
|
new.floats = set()
|
||||||
|
new.bitvecs = set([lane_w * nlanes for lane_w in all_scalars
|
||||||
|
for nlanes in self.lanes])
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
def image(self, func):
|
def image(self, func):
|
||||||
# type: (str) -> TypeSet
|
# type: (str) -> TypeSet
|
||||||
"""
|
"""
|
||||||
@@ -362,6 +398,8 @@ class TypeSet(object):
|
|||||||
return self.half_vector()
|
return self.half_vector()
|
||||||
elif (func == TypeVar.DOUBLEVECTOR):
|
elif (func == TypeVar.DOUBLEVECTOR):
|
||||||
return self.double_vector()
|
return self.double_vector()
|
||||||
|
elif (func == TypeVar.TOBITVEC):
|
||||||
|
return self.to_bitvec()
|
||||||
else:
|
else:
|
||||||
assert False, "Unknown derived function: " + func
|
assert False, "Unknown derived function: " + func
|
||||||
|
|
||||||
@@ -376,10 +414,12 @@ class TypeSet(object):
|
|||||||
|
|
||||||
if (func == TypeVar.LANEOF):
|
if (func == TypeVar.LANEOF):
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
|
new.bitvecs = set()
|
||||||
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
|
new.lanes = set([2**i for i in range(0, int_log2(MAX_LANES)+1)])
|
||||||
return new
|
return new
|
||||||
elif (func == TypeVar.ASBOOL):
|
elif (func == TypeVar.ASBOOL):
|
||||||
new = self.copy()
|
new = self.copy()
|
||||||
|
new.bitvecs = set()
|
||||||
|
|
||||||
if 1 not in self.bools:
|
if 1 not in self.bools:
|
||||||
new.ints = self.bools.difference(set([1]))
|
new.ints = self.bools.difference(set([1]))
|
||||||
@@ -400,6 +440,39 @@ class TypeSet(object):
|
|||||||
return self.double_vector()
|
return self.double_vector()
|
||||||
elif (func == TypeVar.DOUBLEVECTOR):
|
elif (func == TypeVar.DOUBLEVECTOR):
|
||||||
return self.half_vector()
|
return self.half_vector()
|
||||||
|
elif (func == TypeVar.TOBITVEC):
|
||||||
|
new = TypeSet()
|
||||||
|
|
||||||
|
# Start with all possible lanes/ints/floats/bools
|
||||||
|
lanes = interval_to_set(decode_interval(True, (1, MAX_LANES), 1))
|
||||||
|
ints = interval_to_set(decode_interval(True, (8, MAX_BITS)))
|
||||||
|
floats = interval_to_set(decode_interval(True, (32, 64)))
|
||||||
|
bools = interval_to_set(decode_interval(True, (1, MAX_BITS)))
|
||||||
|
|
||||||
|
# See which combinations have a size that appears in self.bitvecs
|
||||||
|
has_t = set() # type: Set[Tuple[str, int, int]]
|
||||||
|
for l in lanes:
|
||||||
|
for i in ints:
|
||||||
|
if i * l in self.bitvecs:
|
||||||
|
has_t.add(('i', i, l))
|
||||||
|
for i in bools:
|
||||||
|
if i * l in self.bitvecs:
|
||||||
|
has_t.add(('b', i, l))
|
||||||
|
for i in floats:
|
||||||
|
if i * l in self.bitvecs:
|
||||||
|
has_t.add(('f', i, l))
|
||||||
|
|
||||||
|
for (t, width, lane) in has_t:
|
||||||
|
new.lanes.add(lane)
|
||||||
|
if (t == 'i'):
|
||||||
|
new.ints.add(width)
|
||||||
|
elif (t == 'b'):
|
||||||
|
new.bools.add(width)
|
||||||
|
else:
|
||||||
|
assert t == 'f'
|
||||||
|
new.floats.add(width)
|
||||||
|
|
||||||
|
return new
|
||||||
else:
|
else:
|
||||||
assert False, "Unknown derived function: " + func
|
assert False, "Unknown derived function: " + func
|
||||||
|
|
||||||
@@ -409,7 +482,7 @@ class TypeSet(object):
|
|||||||
Return the number of concrete types represented by this typeset
|
Return the number of concrete types represented by this typeset
|
||||||
"""
|
"""
|
||||||
return len(self.lanes) * (len(self.ints) + len(self.floats) +
|
return len(self.lanes) * (len(self.ints) + len(self.floats) +
|
||||||
len(self.bools))
|
len(self.bools) + len(self.bitvecs))
|
||||||
|
|
||||||
def concrete_types(self):
|
def concrete_types(self):
|
||||||
# type: () -> Iterable[types.ValueType]
|
# type: () -> Iterable[types.ValueType]
|
||||||
@@ -427,6 +500,8 @@ class TypeSet(object):
|
|||||||
yield by(types.FloatType.with_bits(bits), nlanes)
|
yield by(types.FloatType.with_bits(bits), nlanes)
|
||||||
for bits in self.bools:
|
for bits in self.bools:
|
||||||
yield by(types.BoolType.with_bits(bits), nlanes)
|
yield by(types.BoolType.with_bits(bits), nlanes)
|
||||||
|
for bits in self.bitvecs:
|
||||||
|
yield by(types.BVType.with_bits(bits), nlanes)
|
||||||
|
|
||||||
def get_singleton(self):
|
def get_singleton(self):
|
||||||
# type: () -> types.ValueType
|
# type: () -> types.ValueType
|
||||||
@@ -458,14 +533,15 @@ class TypeVar(object):
|
|||||||
:param scalars: Allow type variable to assume scalar types.
|
:param scalars: Allow type variable to assume scalar types.
|
||||||
:param simd: Allow type variable to assume vector types, or `(min, max)`
|
:param simd: Allow type variable to assume vector types, or `(min, max)`
|
||||||
lane count range.
|
lane count range.
|
||||||
|
:param bitvecs: Allow all BitVec base types, or `(min, max)` bit-range.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name, doc,
|
self, name, doc,
|
||||||
ints=False, floats=False, bools=False,
|
ints=False, floats=False, bools=False,
|
||||||
scalars=True, simd=False,
|
scalars=True, simd=False, bitvecs=False,
|
||||||
base=None, derived_func=None):
|
base=None, derived_func=None):
|
||||||
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa
|
# type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, BoolInterval, TypeVar, str) -> None # noqa
|
||||||
self.name = name
|
self.name = name
|
||||||
self.__doc__ = doc
|
self.__doc__ = doc
|
||||||
self.is_derived = isinstance(base, TypeVar)
|
self.is_derived = isinstance(base, TypeVar)
|
||||||
@@ -482,7 +558,8 @@ class TypeVar(object):
|
|||||||
lanes=lanes,
|
lanes=lanes,
|
||||||
ints=ints,
|
ints=ints,
|
||||||
floats=floats,
|
floats=floats,
|
||||||
bools=bools)
|
bools=bools,
|
||||||
|
bitvecs=bitvecs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def singleton(typ):
|
def singleton(typ):
|
||||||
@@ -498,6 +575,7 @@ class TypeVar(object):
|
|||||||
ints = None
|
ints = None
|
||||||
floats = None
|
floats = None
|
||||||
bools = None
|
bools = None
|
||||||
|
bitvecs = None
|
||||||
|
|
||||||
if isinstance(scalar, types.IntType):
|
if isinstance(scalar, types.IntType):
|
||||||
ints = (scalar.bits, scalar.bits)
|
ints = (scalar.bits, scalar.bits)
|
||||||
@@ -505,10 +583,13 @@ class TypeVar(object):
|
|||||||
floats = (scalar.bits, scalar.bits)
|
floats = (scalar.bits, scalar.bits)
|
||||||
elif isinstance(scalar, types.BoolType):
|
elif isinstance(scalar, types.BoolType):
|
||||||
bools = (scalar.bits, scalar.bits)
|
bools = (scalar.bits, scalar.bits)
|
||||||
|
elif isinstance(scalar, types.BVType):
|
||||||
|
bitvecs = (scalar.bits, scalar.bits)
|
||||||
|
|
||||||
tv = TypeVar(
|
tv = TypeVar(
|
||||||
typ.name, typ.__doc__,
|
typ.name, typ.__doc__,
|
||||||
ints, floats, bools, simd=lanes)
|
ints=ints, floats=floats, bools=bools,
|
||||||
|
bitvecs=bitvecs, simd=lanes)
|
||||||
return tv
|
return tv
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -558,6 +639,7 @@ class TypeVar(object):
|
|||||||
DOUBLEWIDTH = 'double_width'
|
DOUBLEWIDTH = 'double_width'
|
||||||
HALFVECTOR = 'half_vector'
|
HALFVECTOR = 'half_vector'
|
||||||
DOUBLEVECTOR = 'double_vector'
|
DOUBLEVECTOR = 'double_vector'
|
||||||
|
TOBITVEC = 'to_bitvec'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_bijection(func):
|
def is_bijection(func):
|
||||||
@@ -668,6 +750,14 @@ class TypeVar(object):
|
|||||||
"""
|
"""
|
||||||
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
return TypeVar.derived(self, self.DOUBLEVECTOR)
|
||||||
|
|
||||||
|
def to_bitvec(self):
|
||||||
|
# type: () -> TypeVar
|
||||||
|
"""
|
||||||
|
Return a derived type variable that represent a flat bitvector with
|
||||||
|
the same size as self
|
||||||
|
"""
|
||||||
|
return TypeVar.derived(self, self.TOBITVEC)
|
||||||
|
|
||||||
def singleton_type(self):
|
def singleton_type(self):
|
||||||
# type: () -> ValueType
|
# type: () -> ValueType
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user