From 6a9438d274fb01357cfa5d5a041aaf431348569a Mon Sep 17 00:00:00 2001 From: d1m0 Date: Fri, 23 Jun 2017 11:57:24 -0700 Subject: [PATCH] Add image computation of typesets; Remove TypeVar.singleton_type - instead derive singleton type from typeset; (#104) --- lib/cretonne/meta/base/types.py | 9 +- lib/cretonne/meta/cdsl/instructions.py | 2 +- lib/cretonne/meta/cdsl/test_typevar.py | 80 ++++++++- lib/cretonne/meta/cdsl/types.py | 28 +++- lib/cretonne/meta/cdsl/typevar.py | 218 +++++++++++++++++++++---- lib/cretonne/meta/cdsl/xform.py | 2 +- lib/cretonne/meta/gen_instr.py | 4 +- 7 files changed, 294 insertions(+), 49 deletions(-) diff --git a/lib/cretonne/meta/base/types.py b/lib/cretonne/meta/base/types.py index a2eb1054b9..7111626009 100644 --- a/lib/cretonne/meta/base/types.py +++ b/lib/cretonne/meta/base/types.py @@ -2,15 +2,10 @@ The base.types module predefines all the Cretonne scalar types. """ from __future__ import absolute_import -from cdsl.types import ScalarType, IntType, FloatType, BoolType +from cdsl.types import IntType, FloatType, BoolType #: Boolean. -b1 = ScalarType( - 'b1', 0, - """ - A boolean value that is either true or false. - """) - +b1 = BoolType(1) #: 1-bit bool. Type is abstract (can't be stored in mem) b8 = BoolType(8) #: 8-bit bool. b16 = BoolType(16) #: 16-bit bool. b32 = BoolType(32) #: 32-bit bool. diff --git a/lib/cretonne/meta/cdsl/instructions.py b/lib/cretonne/meta/cdsl/instructions.py index d8e9d24e5f..22c989bd65 100644 --- a/lib/cretonne/meta/cdsl/instructions.py +++ b/lib/cretonne/meta/cdsl/instructions.py @@ -186,7 +186,7 @@ class Instruction(object): try: opnum = self.value_opnums[self.format.typevar_operand] tv = self.ins[opnum].typevar - if tv is tv.free_typevar(): + if tv is tv.free_typevar() or tv.singleton_type() is not None: self.other_typevars = self._verify_ctrl_typevar(tv) self.ctrl_typevar = tv self.use_typevar_operand = True diff --git a/lib/cretonne/meta/cdsl/test_typevar.py b/lib/cretonne/meta/cdsl/test_typevar.py index 97793f71a5..5da081eabe 100644 --- a/lib/cretonne/meta/cdsl/test_typevar.py +++ b/lib/cretonne/meta/cdsl/test_typevar.py @@ -3,7 +3,7 @@ from unittest import TestCase from doctest import DocTestSuite from . import typevar from .typevar import TypeSet, TypeVar -from base.types import i32 +from base.types import i32, i16, b1, f64 def load_tests(loader, tests, ignore): @@ -45,6 +45,84 @@ class TestTypeSet(TestCase): with self.assertRaises(AssertionError): a in s + def test_forward_images(self): + a = TypeSet(lanes=(2, 8), ints=(8, 8), floats=(32, 32)) + b = TypeSet(lanes=(1, 8), ints=(8, 8), floats=(32, 32)) + self.assertEqual(a.lane_of(), TypeSet(ints=(8, 8), floats=(32, 32))) + + c = TypeSet(lanes=(2, 8)) + c.bools = set([8, 32]) + + # Test case with disjoint intervals + self.assertEqual(a.as_bool(), c) + + # For as_bool check b1 is present when 1 \in lanes + d = TypeSet(lanes=(1, 8)) + d.bools = set([1, 8, 32]) + self.assertEqual(b.as_bool(), d) + + self.assertEqual(TypeSet(lanes=(1, 32)).half_vector(), + TypeSet(lanes=(1, 16))) + + self.assertEqual(TypeSet(lanes=(1, 32)).double_vector(), + TypeSet(lanes=(2, 64))) + + self.assertEqual(TypeSet(lanes=(128, 256)).double_vector(), + TypeSet(lanes=(256, 256))) + + self.assertEqual(TypeSet(ints=(8, 32)).half_width(), + TypeSet(ints=(8, 16))) + + self.assertEqual(TypeSet(ints=(8, 32)).double_width(), + TypeSet(ints=(16, 64))) + + self.assertEqual(TypeSet(ints=(32, 64)).double_width(), + TypeSet(ints=(64, 64))) + + # Should produce an empty ts + self.assertEqual(TypeSet(floats=(32, 32)).half_width(), + TypeSet()) + + self.assertEqual(TypeSet(floats=(32, 64)).half_width(), + TypeSet(floats=(32, 32))) + + self.assertEqual(TypeSet(floats=(32, 32)).double_width(), + TypeSet(floats=(64, 64))) + + self.assertEqual(TypeSet(floats=(32, 64)).double_width(), + TypeSet(floats=(64, 64))) + + # Bools have trickier behavior around b1 (since b2, b4 don't exist) + self.assertEqual(TypeSet(bools=(1, 8)).half_width(), + TypeSet()) + + t = TypeSet() + t.bools = set([8, 16]) + self.assertEqual(TypeSet(bools=(1, 32)).half_width(), t) + + # double_width() of bools={1, 8, 16} must not include 2 or 8 + t.bools = set([16, 32]) + self.assertEqual(TypeSet(bools=(1, 16)).double_width(), t) + + self.assertEqual(TypeSet(bools=(32, 64)).double_width(), + TypeSet(bools=(64, 64))) + + def test_get_singleton(self): + # Raise error when calling get_singleton() on non-singleton TS + t = TypeSet(lanes=(1, 1), ints=(8, 8), floats=(32, 32)) + with self.assertRaises(AssertionError): + t.get_singleton() + t = TypeSet(lanes=(1, 2), floats=(32, 32)) + + with self.assertRaises(AssertionError): + t.get_singleton() + + self.assertEqual(TypeSet(ints=(16, 16)).get_singleton(), i16) + self.assertEqual(TypeSet(floats=(64, 64)).get_singleton(), f64) + self.assertEqual(TypeSet(bools=(1, 1)).get_singleton(), b1) + self.assertEqual(TypeSet(lanes=(4, 4), ints=(32, 32)).get_singleton(), + i32.by(4)) + class TestTypeVar(TestCase): def test_functions(self): diff --git a/lib/cretonne/meta/cdsl/types.py b/lib/cretonne/meta/cdsl/types.py index e42460af28..6feb2112d8 100644 --- a/lib/cretonne/meta/cdsl/types.py +++ b/lib/cretonne/meta/cdsl/types.py @@ -97,7 +97,7 @@ class VectorType(ValueType): # type: (ScalarType, int) -> None assert isinstance(base, ScalarType), 'SIMD lanes must be scalar types' super(VectorType, self).__init__( - name='{}x{}'.format(base.name, lanes), + name=VectorType.get_name(base, lanes), membytes=lanes*base.membytes, doc=""" A SIMD vector with {} lanes containing a `{}` each. @@ -111,6 +111,11 @@ class VectorType(ValueType): return ('VectorType(base={}, lanes={})' .format(self.base.name, self.lanes)) + @staticmethod + def get_name(base, lanes): + # type: (ValueType, int) -> str + return '{}x{}'.format(base.name, lanes) + class IntType(ScalarType): """A concrete scalar integer type.""" @@ -119,7 +124,7 @@ class IntType(ScalarType): # type: (int) -> None assert bits > 0, 'IntType must have positive number of bits' super(IntType, self).__init__( - name='i{:d}'.format(bits), + name=IntType.get_name(bits), membytes=bits // 8, doc="An integer type with {} bits.".format(bits)) self.bits = bits @@ -128,6 +133,11 @@ class IntType(ScalarType): # type: () -> str return 'IntType(bits={})'.format(self.bits) + @staticmethod + def get_name(bits): + # type: (int) -> str + return 'i{:d}'.format(bits) + class FloatType(ScalarType): """A concrete scalar floating point type.""" @@ -136,7 +146,7 @@ class FloatType(ScalarType): # type: (int, str) -> None assert bits > 0, 'FloatType must have positive number of bits' super(FloatType, self).__init__( - name='f{:d}'.format(bits), + name=FloatType.get_name(bits), membytes=bits // 8, doc=doc) self.bits = bits @@ -145,6 +155,11 @@ class FloatType(ScalarType): # type: () -> str return 'FloatType(bits={})'.format(self.bits) + @staticmethod + def get_name(bits): + # type: (int) -> str + return 'f{:d}'.format(bits) + class BoolType(ScalarType): """A concrete scalar boolean type.""" @@ -153,7 +168,7 @@ class BoolType(ScalarType): # type: (int) -> None assert bits > 0, 'BoolType must have positive number of bits' super(BoolType, self).__init__( - name='b{:d}'.format(bits), + name=BoolType.get_name(bits), membytes=bits // 8, doc="A boolean type with {} bits.".format(bits)) self.bits = bits @@ -161,3 +176,8 @@ class BoolType(ScalarType): def __repr__(self): # type: () -> str return 'BoolType(bits={})'.format(self.bits) + + @staticmethod + def get_name(bits): + # type: (int) -> str + return 'b{:d}'.format(bits) diff --git a/lib/cretonne/meta/cdsl/typevar.py b/lib/cretonne/meta/cdsl/typevar.py index 119c6bdf01..6ed3ffc86c 100644 --- a/lib/cretonne/meta/cdsl/typevar.py +++ b/lib/cretonne/meta/cdsl/typevar.py @@ -7,15 +7,19 @@ polymorphic by using type variables. from __future__ import absolute_import import math from . import types, is_power_of_two +from copy import deepcopy +from .types import ValueType, IntType, FloatType, BoolType try: from typing import Tuple, Union, Iterable, Any, Set, TYPE_CHECKING # noqa + from typing import cast if TYPE_CHECKING: from srcgen import Formatter # noqa Interval = Tuple[int, int] # An Interval where `True` means 'everything' BoolInterval = Union[bool, Interval] except ImportError: + TYPE_CHECKING = False pass MAX_LANES = 256 @@ -112,6 +116,16 @@ def interval_to_set(intv): return set([2**i for i in range(int_log2(lo), int_log2(hi)+1)]) +def legal_bool(bits): + # type: (int) -> bool + """ + True iff bits is a legal bit width for a bool type. + bits == 1 || bits \in { 8, 16, .. MAX_BITS } + """ + return bits == 1 or \ + (bits >= 8 and bits <= MAX_BITS and is_power_of_two(bits)) + + class TypeSet(object): """ A set of types. @@ -165,7 +179,15 @@ class TypeSet(object): self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS))) self.floats = interval_to_set(decode_interval(floats, (32, 64))) self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS))) - self.bools = set(filter(lambda x: x == 1 or x >= 8, self.bools)) + self.bools = set(filter(legal_bool, self.bools)) + + def copy(self): + # type: (TypeSet) -> TypeSet + """ + Return a copy of our self. deepcopy is sufficient and safe here, since + TypeSet contains only sets of numbers. + """ + return deepcopy(self) def typeset_key(self): # type: () -> Tuple[Tuple, Tuple, Tuple, Tuple] @@ -241,6 +263,109 @@ class TypeSet(object): return self + def lane_of(self): + # type: () -> TypeSet + """ + Return a TypeSet describing the image of self across lane_of + """ + new = self.copy() + new.lanes = set([1]) + return new + + def as_bool(self): + # type: () -> TypeSet + """ + Return a TypeSet describing the image of self across as_bool + """ + new = self.copy() + new.ints = set() + new.floats = set() + new.bools = self.ints.union(self.floats).union(self.bools) + + if 1 in self.lanes: + new.bools.add(1) + return new + + def half_width(self): + # type: () -> TypeSet + """ + Return a TypeSet describing the image of self across halfwidth + """ + new = self.copy() + new.ints = set([x/2 for x in self.ints if x > 8]) + new.floats = set([x/2 for x in self.floats if x > 32]) + new.bools = set([x/2 for x in self.bools if x > 8]) + + return new + + def double_width(self): + # type: () -> TypeSet + """ + Return a TypeSet describing the image of self across doublewidth + """ + new = self.copy() + new.ints = set([x*2 for x in self.ints if x < MAX_BITS]) + new.floats = set([x*2 for x in self.floats if x < MAX_BITS]) + new.bools = set(filter(legal_bool, + set([x*2 for x in self.bools if x < MAX_BITS]))) + + return new + + def half_vector(self): + # type: () -> TypeSet + """ + Return a TypeSet describing the image of self across halfvector + """ + new = self.copy() + new.lanes = set([x/2 for x in self.lanes if x > 1]) + + return new + + def double_vector(self): + # type: () -> TypeSet + """ + Return a TypeSet describing the image of self across doublevector + """ + new = self.copy() + new.lanes = set([x*2 for x in self.lanes if x < MAX_LANES]) + + return new + + def size(self): + # type: () -> int + """ + Return the number of concrete types represented by this typeset + """ + return len(self.lanes) * (len(self.ints) + len(self.floats) + + len(self.bools)) + + def get_singleton(self): + # type: () -> types.ValueType + """ + Return the singleton type represented by self. Can only call on + typesets containing 1 type. + """ + assert self.size() == 1 + if len(self.ints) > 0: + bits = tuple(self.ints)[0] + scalar_type = ValueType.by_name(IntType.get_name(bits)) + elif len(self.floats) > 0: + bits = tuple(self.floats)[0] + scalar_type = ValueType.by_name(FloatType.get_name(bits)) + else: + bits = tuple(self.bools)[0] + scalar_type = ValueType.by_name(BoolType.get_name(bits)) + + nlanes = tuple(self.lanes)[0] + + if nlanes == 1: + return scalar_type + else: + if TYPE_CHECKING: + return cast(types.ScalarType, scalar_type).by(nlanes) + else: + return scalar_type.by(nlanes) + class TypeVar(object): """ @@ -271,7 +396,6 @@ class TypeVar(object): # type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa self.name = name self.__doc__ = doc - self.singleton_type = None # type: types.ValueType self.is_derived = isinstance(base, TypeVar) if base: assert self.is_derived @@ -313,7 +437,6 @@ class TypeVar(object): tv = TypeVar( typ.name, typ.__doc__, ints, floats, bools, simd=lanes) - tv.singleton_type = typ return tv def __str__(self): @@ -406,14 +529,13 @@ class TypeVar(object): Return a derived type variable that has the same number of vector lanes as this one, but the lanes are half the width. """ - if not self.is_derived: - ts = self.type_set - if len(ts.ints) > 0: - assert min(ts.ints) > 8, "Can't halve all integer types" - if len(ts.floats) > 0: - assert min(ts.floats) > 32, "Can't halve all float types" - if len(ts.bools) > 0: - assert min(ts.bools) > 8, "Can't halve all boolean types" + ts = self.get_typeset() + if len(ts.ints) > 0: + assert min(ts.ints) > 8, "Can't halve all integer types" + if len(ts.floats) > 0: + assert min(ts.floats) > 32, "Can't halve all float types" + if len(ts.bools) > 0: + assert min(ts.bools) > 8, "Can't halve all boolean types" return TypeVar.derived(self, self.HALFWIDTH) @@ -423,16 +545,13 @@ class TypeVar(object): Return a derived type variable that has the same number of vector lanes as this one, but the lanes are double the width. """ - if not self.is_derived: - ts = self.type_set - if len(ts.ints) > 0: - assert max(ts.ints) < MAX_BITS,\ - "Can't double all integer types." - if len(ts.floats) > 0: - assert max(ts.floats) < MAX_BITS,\ - "Can't double all float types." - if len(ts.bools) > 0: - assert max(ts.bools) < MAX_BITS, "Can't double all bool types." + ts = self.get_typeset() + if len(ts.ints) > 0: + assert max(ts.ints) < MAX_BITS, "Can't double all integer types." + if len(ts.floats) > 0: + assert max(ts.floats) < MAX_BITS, "Can't double all float types." + if len(ts.bools) > 0: + assert max(ts.bools) < MAX_BITS, "Can't double all bool types." return TypeVar.derived(self, self.DOUBLEWIDTH) @@ -442,9 +561,8 @@ class TypeVar(object): Return a derived type variable that has half the number of vector lanes as this one, with the same lane type. """ - if not self.is_derived: - ts = self.type_set - assert min(ts.lanes) > 1, "Can't halve a scalar type" + ts = self.get_typeset() + assert min(ts.lanes) > 1, "Can't halve a scalar type" return TypeVar.derived(self, self.HALFVECTOR) @@ -454,12 +572,23 @@ class TypeVar(object): Return a derived type variable that has twice the number of vector lanes as this one, with the same lane type. """ - if not self.is_derived: - ts = self.type_set - assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes." + ts = self.get_typeset() + assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes." return TypeVar.derived(self, self.DOUBLEVECTOR) + def singleton_type(self): + # type: () -> ValueType + """ + If the associated typeset has a single type return it. Otherwise return + None + """ + ts = self.get_typeset() + if ts.size() != 1: + return None + + return ts.get_singleton() + def free_typevar(self): # type: () -> TypeVar """ @@ -467,7 +596,7 @@ class TypeVar(object): """ if self.is_derived: return self.base - elif self.singleton_type: + elif self.singleton_type() is not None: # A singleton type variable is not a proper free variable. return None else: @@ -481,8 +610,8 @@ class TypeVar(object): if self.is_derived: return '{}.{}()'.format( self.base.rust_expr(), self.derived_func) - elif self.singleton_type: - return self.singleton_type.rust_name() + elif self.singleton_type(): + return self.singleton_type().rust_name() else: return self.name @@ -501,9 +630,6 @@ class TypeVar(object): if not a.is_derived and not b.is_derived: a.type_set &= b.type_set - # TODO: What if a.type_set becomes empty? - if not a.singleton_type: - a.singleton_type = b.singleton_type return # TODO: Implement constraints for derived type variables. @@ -514,3 +640,29 @@ class TypeVar(object): # # For the fully general case, we would need to compute an image typeset # for `b` and propagate a `a.derived_func` pre-image to `a.base`. + + def get_typeset(self): + # type: () -> TypeSet + """ + Returns the typeset for this TV. If the TV is derived, computes it + recursively from the derived function and the base's typeset. + """ + if not self.is_derived: + return self.type_set + else: + if (self.derived_func == TypeVar.SAMEAS): + return self.base.get_typeset() + elif (self.derived_func == TypeVar.LANEOF): + return self.base.get_typeset().lane_of() + elif (self.derived_func == TypeVar.ASBOOL): + return self.base.get_typeset().as_bool() + elif (self.derived_func == TypeVar.HALFWIDTH): + return self.base.get_typeset().half_width() + elif (self.derived_func == TypeVar.DOUBLEWIDTH): + return self.base.get_typeset().double_width() + elif (self.derived_func == TypeVar.HALFVECTOR): + return self.base.get_typeset().half_vector() + elif (self.derived_func == TypeVar.DOUBLEVECTOR): + return self.base.get_typeset().double_vector() + else: + assert False, "Unknown derived function: " + self.derived_func diff --git a/lib/cretonne/meta/cdsl/xform.py b/lib/cretonne/meta/cdsl/xform.py index 9d284bc1e2..9ce93c9ed9 100644 --- a/lib/cretonne/meta/cdsl/xform.py +++ b/lib/cretonne/meta/cdsl/xform.py @@ -254,7 +254,7 @@ class XForm(object): # Some variables have a fixed type which appears as a type variable # with a singleton_type field set. That's allowed for temps too. for v in fvars: - if v.is_temp() and not v.typevar.singleton_type: + if v.is_temp() and not v.typevar.singleton_type(): raise AssertionError( "Cannot determine type of temp '{}' in xform:\n{}" .format(v, self)) diff --git a/lib/cretonne/meta/gen_instr.py b/lib/cretonne/meta/gen_instr.py index 548cd45a29..2d67311ae2 100644 --- a/lib/cretonne/meta/gen_instr.py +++ b/lib/cretonne/meta/gen_instr.py @@ -321,8 +321,8 @@ def get_constraint(op, ctrl_typevar, type_sets): tv = op.typevar # A concrete value type. - if tv.singleton_type: - return 'Concrete({})'.format(tv.singleton_type.rust_name()) + if tv.singleton_type(): + return 'Concrete({})'.format(tv.singleton_type().rust_name()) if tv.free_typevar() is not ctrl_typevar: assert not tv.is_derived