Add TypeVar.rust_expr().

Generate a Rust expression that computes the value of a derived type
variable.
This commit is contained in:
Jakob Stoklund Olesen
2016-11-10 14:46:42 -08:00
parent bf1568035f
commit 5c9a12f101
4 changed files with 55 additions and 22 deletions

View File

@@ -197,6 +197,16 @@ class Var(Expr):
return False return False
return self.typevar is self.original_typevar return self.typevar is self.original_typevar
def rust_type(self):
# type: () -> str
"""
Get a Rust expression that computes the type of this variable.
It is assumed that local variables exist corresponding to the free type
variables.
"""
return self.typevar.rust_expr()
def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var): def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var):
# type: (TypeVar, TypeVar, Var) -> None # type: (TypeVar, TypeVar, Var) -> None
""" """

View File

@@ -57,10 +57,14 @@ class TestTypeVar(TestCase):
x2 = TypeVar('x2', 'i16 and up', ints=(16, 64)) x2 = TypeVar('x2', 'i16 and up', ints=(16, 64))
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
x2.double_width() x2.double_width()
self.assertEqual(str(x2.half_width()), '`HalfWidth(x2)`') self.assertEqual(str(x2.half_width()), '`half_width(x2)`')
self.assertEqual(x2.half_width().rust_expr(), 'x2.half_width()')
self.assertEqual(
x2.half_width().double_width().rust_expr(),
'x2.half_width().double_width()')
x3 = TypeVar('x3', 'up to i32', ints=(8, 32)) x3 = TypeVar('x3', 'up to i32', ints=(8, 32))
self.assertEqual(str(x3.double_width()), '`DoubleWidth(x3)`') self.assertEqual(str(x3.double_width()), '`double_width(x3)`')
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
x3.half_width() x3.half_width()

View File

@@ -312,11 +312,14 @@ class TypeVar(object):
return self is other return self is other
# Supported functions for derived type variables. # Supported functions for derived type variables.
SAMEAS = 'SameAs' # The names here must match the method names on `ir::types::Type`.
LANEOF = 'LaneOf' # The camel_case of the names must match `enum OperandConstraint` in
ASBOOL = 'AsBool' # `instructions.rs`.
HALFWIDTH = 'HalfWidth' SAMEAS = 'same_as'
DOUBLEWIDTH = 'DoubleWidth' LANEOF = 'lane_of'
ASBOOL = 'as_bool'
HALFWIDTH = 'half_width'
DOUBLEWIDTH = 'double_width'
@staticmethod @staticmethod
def derived(base, derived_func): def derived(base, derived_func):
@@ -370,6 +373,7 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are half the width. as this one, but the lanes are half the width.
""" """
if not self.is_derived:
ts = self.type_set ts = self.type_set
if ts.min_int: if ts.min_int:
assert ts.min_int > 8, "Can't halve all integer types" assert ts.min_int > 8, "Can't halve all integer types"
@@ -386,13 +390,14 @@ class TypeVar(object):
Return a derived type variable that has the same number of vector lanes Return a derived type variable that has the same number of vector lanes
as this one, but the lanes are double the width. as this one, but the lanes are double the width.
""" """
if not self.is_derived:
ts = self.type_set ts = self.type_set
if ts.max_int: if ts.max_int:
assert ts.max_int < MAX_BITS, "Can't double all integer types." assert ts.max_int < MAX_BITS, "Can't double all integer types."
if ts.max_float: if ts.max_float:
assert ts.max_float < MAX_BITS, "Can't double all float types." assert ts.max_float < MAX_BITS, "Can't double all float types."
if ts.max_bool: if ts.max_bool:
assert ts.max_bool < MAX_BITS, "Can't double all boolean types." assert ts.max_bool < MAX_BITS, "Can't double all bool types."
return TypeVar.derived(self, self.DOUBLEWIDTH) return TypeVar.derived(self, self.DOUBLEWIDTH)
@@ -409,6 +414,19 @@ class TypeVar(object):
else: else:
return self return self
def rust_expr(self):
# type: () -> str
"""
Get a Rust expression that computes the type of this type variable.
"""
if self.is_derived:
return '{}.{}()'.format(
self.base.rust_expr(), self.derived_func)
elif self.singleton_type:
return self.singleton_type.rust_name()
else:
return self.name
def constrain_types(self, other): def constrain_types(self, other):
# type: (TypeVar) -> None # type: (TypeVar) -> None
""" """

View File

@@ -5,6 +5,7 @@ from __future__ import absolute_import
import srcgen import srcgen
import constant_hash import constant_hash
from unique_table import UniqueTable, UniqueSeqTable from unique_table import UniqueTable, UniqueSeqTable
from cdsl import camel_case
from cdsl.operands import ImmediateKind from cdsl.operands import ImmediateKind
import cdsl.types import cdsl.types
from cdsl.formats import InstructionFormat from cdsl.formats import InstructionFormat
@@ -330,7 +331,7 @@ def get_constraint(op, ctrl_typevar, type_sets):
if tv.is_derived: if tv.is_derived:
assert tv.base is ctrl_typevar, "Not derived from ctrl_typevar" assert tv.base is ctrl_typevar, "Not derived from ctrl_typevar"
return tv.derived_func return camel_case(tv.derived_func)
assert tv is ctrl_typevar assert tv is ctrl_typevar
return 'Same' return 'Same'