diff --git a/lib/cretonne/meta/cdsl/instructions.py b/lib/cretonne/meta/cdsl/instructions.py index 8132f144f6..ee6b9c4a59 100644 --- a/lib/cretonne/meta/cdsl/instructions.py +++ b/lib/cretonne/meta/cdsl/instructions.py @@ -130,10 +130,10 @@ class Instruction(object): """ poly_ins = [ i for i in self.format.value_operands - if self.ins[i].typ.free_typevar()] + if self.ins[i].typevar.free_typevar()] poly_outs = [ i for i, o in enumerate(self.outs) - if o.typ.free_typevar()] + if o.is_value() and o.typevar.free_typevar()] self.is_polymorphic = len(poly_ins) > 0 or len(poly_outs) > 0 if not self.is_polymorphic: return @@ -143,7 +143,7 @@ class Instruction(object): typevar_error = None if self.format.typevar_operand is not None: try: - tv = self.ins[self.format.typevar_operand].typ + tv = self.ins[self.format.typevar_operand].typevar if tv is tv.free_typevar(): self.other_typevars = self._verify_ctrl_typevar(tv) self.ctrl_typevar = tv @@ -160,7 +160,7 @@ class Instruction(object): else: raise RuntimeError( "typevar_operand must be a free type variable") - tv = self.outs[0].typ + tv = self.outs[0].typevar if tv is not tv.free_typevar(): raise RuntimeError("first result must be a free type variable") self.other_typevars = self._verify_ctrl_typevar(tv) @@ -181,7 +181,7 @@ class Instruction(object): other_tvs = [] # Check value inputs. for opidx in self.format.value_operands: - typ = self.ins[opidx].typ + typ = self.ins[opidx].typevar tv = typ.free_typevar() # Non-polymorphic or derived form ctrl_typevar is OK. if tv is None or tv is ctrl_typevar: @@ -200,7 +200,9 @@ class Instruction(object): # Check outputs. for result in self.outs: - typ = result.typ + if not result.is_value(): + continue + typ = result.typevar tv = typ.free_typevar() # Non-polymorphic or derived from ctrl_typevar is OK. if tv is None or tv is ctrl_typevar: diff --git a/lib/cretonne/meta/cdsl/operands.py b/lib/cretonne/meta/cdsl/operands.py index f379357611..2fcc22b0c0 100644 --- a/lib/cretonne/meta/cdsl/operands.py +++ b/lib/cretonne/meta/cdsl/operands.py @@ -40,10 +40,6 @@ class OperandKind(object): # type: () -> str return 'OperandKind({})'.format(self.name) - def free_typevar(self): - # Return the free typevariable controlling the type of this operand. - return None - #: An SSA value operand. This is a value defined by another instruction. VALUE = OperandKind( 'value', """ @@ -129,11 +125,15 @@ class Operand(object): # type: (str, OperandSpec, str) -> None self.name = name self.__doc__ = doc - self.typ = typ + + # Decode the operand spec and set self.kind. + # Only VALUE operands have a typevar member. if isinstance(typ, ValueType): self.kind = VALUE + self.typevar = TypeVar.singleton(typ) elif isinstance(typ, TypeVar): self.kind = VALUE + self.typevar = typ else: assert isinstance(typ, OperandKind) self.kind = typ @@ -142,8 +142,9 @@ class Operand(object): # type: () -> str if self.__doc__: return self.__doc__ - else: - return self.typ.__doc__ + if self.kind is VALUE: + return self.typevar.__doc__ + return self.kind.__doc__ def __str__(self): # type: () -> str diff --git a/lib/cretonne/meta/cdsl/test_typevar.py b/lib/cretonne/meta/cdsl/test_typevar.py index a841acfdb9..7dae18221a 100644 --- a/lib/cretonne/meta/cdsl/test_typevar.py +++ b/lib/cretonne/meta/cdsl/test_typevar.py @@ -3,6 +3,7 @@ from unittest import TestCase from doctest import DocTestSuite from . import typevar from .typevar import TypeSet, TypeVar +from base.types import i32 def load_tests(loader, tests, ignore): @@ -62,3 +63,18 @@ class TestTypeVar(TestCase): self.assertEqual(str(x3.double_width()), '`DoubleWidth(x3)`') with self.assertRaises(AssertionError): x3.half_width() + + def test_singleton(self): + x = TypeVar.singleton(i32) + self.assertEqual(str(x), '`i32`') + self.assertEqual(x.type_set.min_int, 32) + self.assertEqual(x.type_set.max_int, 32) + self.assertEqual(x.type_set.min_lanes, 1) + self.assertEqual(x.type_set.max_lanes, 1) + + x = TypeVar.singleton(i32.by(4)) + self.assertEqual(str(x), '`i32x4`') + self.assertEqual(x.type_set.min_int, 32) + self.assertEqual(x.type_set.max_int, 32) + self.assertEqual(x.type_set.min_lanes, 4) + self.assertEqual(x.type_set.max_lanes, 4) diff --git a/lib/cretonne/meta/cdsl/types.py b/lib/cretonne/meta/cdsl/types.py index 3d5a5dcd38..62334b7dd7 100644 --- a/lib/cretonne/meta/cdsl/types.py +++ b/lib/cretonne/meta/cdsl/types.py @@ -30,8 +30,9 @@ class ValueType(object): # type: () -> str return self.name - def free_typevar(self): - return None + def rust_name(self): + # type: () -> str + return 'types::' + self.name.upper() @staticmethod def by_name(name): @@ -63,10 +64,6 @@ class ScalarType(ValueType): # type: () -> str return 'ScalarType({})'.format(self.name) - def rust_name(self): - # type: () -> str - return 'types::' + self.name.upper() - def by(self, lanes): # type: (int) -> VectorType """ diff --git a/lib/cretonne/meta/cdsl/typevar.py b/lib/cretonne/meta/cdsl/typevar.py index 2c10338bbf..9803f1afc6 100644 --- a/lib/cretonne/meta/cdsl/typevar.py +++ b/lib/cretonne/meta/cdsl/typevar.py @@ -6,6 +6,7 @@ polymorphic by using type variables. """ from __future__ import absolute_import import math +from . import types try: from typing import Tuple, Union # noqa @@ -242,6 +243,7 @@ class TypeVar(object): # type: (str, str, BoolInterval, BoolInterval, BoolInterval, bool, BoolInterval, TypeVar, str) -> None # noqa self.name = name self.__doc__ = doc + self.singleton_type = None # type: types.ValueType self.is_derived = isinstance(base, TypeVar) if base: assert self.is_derived @@ -258,6 +260,34 @@ class TypeVar(object): floats=floats, bools=bools) + @staticmethod + def singleton(typ): + # type: (types.ValueType) -> TypeVar + """Create a type variable that can only assume a single type.""" + if isinstance(typ, types.VectorType): + scalar = typ.base + lanes = (typ.lanes, typ.lanes) + elif isinstance(typ, types.ScalarType): + scalar = typ + lanes = (1, 1) + + ints = None + floats = None + bools = None + + if isinstance(scalar, types.IntType): + ints = (scalar.bits, scalar.bits) + elif isinstance(scalar, types.FloatType): + floats = (scalar.bits, scalar.bits) + elif isinstance(scalar, types.BoolType): + bools = (scalar.bits, scalar.bits) + + tv = TypeVar( + typ.name, 'typeof({})'.format(typ), + ints, floats, bools, simd=lanes) + tv.singleton_type = typ + return tv + def __str__(self): # type: () -> str return "`{}`".format(self.name) @@ -317,5 +347,8 @@ class TypeVar(object): # type: () -> TypeVar if self.is_derived: return self.base + elif self.singleton_type: + # A singleton type variable is not a proper free variable. + return None else: return self diff --git a/lib/cretonne/meta/gen_instr.py b/lib/cretonne/meta/gen_instr.py index 56846c90ab..4f240ab6b1 100644 --- a/lib/cretonne/meta/gen_instr.py +++ b/lib/cretonne/meta/gen_instr.py @@ -5,10 +5,14 @@ from __future__ import absolute_import import srcgen import constant_hash from unique_table import UniqueTable, UniqueSeqTable +from cdsl.operands import ImmediateKind import cdsl.types -import cdsl.operands from cdsl.formats import InstructionFormat +from cdsl.instructions import Instruction # noqa +from cdsl.operands import Operand # noqa +from cdsl.typevar import TypeVar # noqa + def gen_formats(fmt): # type: (srcgen.Formatter) -> None @@ -302,6 +306,7 @@ def gen_opcodes(groups, fmt): def get_constraint(op, ctrl_typevar, type_sets): + # type: (Operand, TypeVar, UniqueTable) -> str """ Get the value type constraint for an SSA value operand, where `ctrl_typevar` is the controlling type variable. @@ -312,22 +317,22 @@ def get_constraint(op, ctrl_typevar, type_sets): - `Free(idx)` where `idx` is an index into `type_sets`. - `Same`, `Lane`, `AsBool` for controlling typevar-derived constraints. """ - assert op.kind is cdsl.operands.VALUE - t = op.typ + assert op.is_value() + tv = op.typevar # A concrete value type. - if isinstance(t, cdsl.types.ValueType): - return 'Concrete({})'.format(t.rust_name()) + if tv.singleton_type: + return 'Concrete({})'.format(tv.singleton_type.rust_name()) - if t.free_typevar() is not ctrl_typevar: - assert not t.is_derived - return 'Free({})'.format(type_sets.add(t.type_set)) + if tv.free_typevar() is not ctrl_typevar: + assert not tv.is_derived + return 'Free({})'.format(type_sets.add(tv.type_set)) - if t.is_derived: - assert t.base is ctrl_typevar, "Not derived directly from ctrl_typevar" - return t.derived_func + if tv.is_derived: + assert tv.base is ctrl_typevar, "Not derived from ctrl_typevar" + return tv.derived_func - assert t is ctrl_typevar + assert tv is ctrl_typevar return 'Same' @@ -486,6 +491,7 @@ def gen_member_inits(iform, fmt): def gen_inst_builder(inst, fmt): + # type: (Instruction, srcgen.Formatter) -> None """ Emit a method for generating the instruction `inst`. @@ -502,10 +508,10 @@ def gen_inst_builder(inst, fmt): if inst.is_polymorphic and not inst.use_typevar_operand: args.append('{}: Type'.format(inst.ctrl_typevar.name)) - tmpl_types = list() - into_args = list() + tmpl_types = list() # type: List[str] + into_args = list() # type: List[str] for op in inst.ins: - if isinstance(op.kind, cdsl.operands.ImmediateKind): + if isinstance(op.kind, ImmediateKind): t = 'T{}{}'.format(1 + len(tmpl_types), op.kind.name) tmpl_types.append('{}: Into<{}>'.format(t, op.kind.rust_type)) into_args.append(op.name) @@ -553,7 +559,7 @@ def gen_inst_builder(inst, fmt): # The format constructor will resolve the result types from the # type var. args.append('ctrl_typevar') - elif inst.outs[inst.value_results[0]].typ == inst.ctrl_typevar: + elif inst.outs[inst.value_results[0]].typevar == inst.ctrl_typevar: # The format constructor expects a simple result type. # No type transformation needed from the controlling type # variable. @@ -567,13 +573,12 @@ def gen_inst_builder(inst, fmt): else: # This non-polymorphic instruction has a fixed result type. args.append( - 'types::' + - inst.outs[inst.value_results[0]].typ.name.upper()) + inst.outs[inst.value_results[0]] + .typevar.singleton_type.rust_name()) args.extend(op.name for op in inst.ins) - args = ', '.join(args) # Call to the format constructor, - fcall = 'self.{}({})'.format(inst.format.name, args) + fcall = 'self.{}({})'.format(inst.format.name, ', '.join(args)) if len(inst.value_results) == 0: fmt.line(fcall + '.0')