Add TypeVar.rust_expr().
Generate a Rust expression that computes the value of a derived type variable.
This commit is contained in:
@@ -197,6 +197,16 @@ class Var(Expr):
|
|||||||
return False
|
return False
|
||||||
return self.typevar is self.original_typevar
|
return self.typevar is self.original_typevar
|
||||||
|
|
||||||
|
def rust_type(self):
|
||||||
|
# type: () -> str
|
||||||
|
"""
|
||||||
|
Get a Rust expression that computes the type of this variable.
|
||||||
|
|
||||||
|
It is assumed that local variables exist corresponding to the free type
|
||||||
|
variables.
|
||||||
|
"""
|
||||||
|
return self.typevar.rust_expr()
|
||||||
|
|
||||||
def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var):
|
def constrain_typevar(self, sym_typevar, sym_ctrl, ctrl_var):
|
||||||
# type: (TypeVar, TypeVar, Var) -> None
|
# type: (TypeVar, TypeVar, Var) -> None
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -57,10 +57,14 @@ class TestTypeVar(TestCase):
|
|||||||
x2 = TypeVar('x2', 'i16 and up', ints=(16, 64))
|
x2 = TypeVar('x2', 'i16 and up', ints=(16, 64))
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
x2.double_width()
|
x2.double_width()
|
||||||
self.assertEqual(str(x2.half_width()), '`HalfWidth(x2)`')
|
self.assertEqual(str(x2.half_width()), '`half_width(x2)`')
|
||||||
|
self.assertEqual(x2.half_width().rust_expr(), 'x2.half_width()')
|
||||||
|
self.assertEqual(
|
||||||
|
x2.half_width().double_width().rust_expr(),
|
||||||
|
'x2.half_width().double_width()')
|
||||||
|
|
||||||
x3 = TypeVar('x3', 'up to i32', ints=(8, 32))
|
x3 = TypeVar('x3', 'up to i32', ints=(8, 32))
|
||||||
self.assertEqual(str(x3.double_width()), '`DoubleWidth(x3)`')
|
self.assertEqual(str(x3.double_width()), '`double_width(x3)`')
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
x3.half_width()
|
x3.half_width()
|
||||||
|
|
||||||
|
|||||||
@@ -312,11 +312,14 @@ class TypeVar(object):
|
|||||||
return self is other
|
return self is other
|
||||||
|
|
||||||
# Supported functions for derived type variables.
|
# Supported functions for derived type variables.
|
||||||
SAMEAS = 'SameAs'
|
# The names here must match the method names on `ir::types::Type`.
|
||||||
LANEOF = 'LaneOf'
|
# The camel_case of the names must match `enum OperandConstraint` in
|
||||||
ASBOOL = 'AsBool'
|
# `instructions.rs`.
|
||||||
HALFWIDTH = 'HalfWidth'
|
SAMEAS = 'same_as'
|
||||||
DOUBLEWIDTH = 'DoubleWidth'
|
LANEOF = 'lane_of'
|
||||||
|
ASBOOL = 'as_bool'
|
||||||
|
HALFWIDTH = 'half_width'
|
||||||
|
DOUBLEWIDTH = 'double_width'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def derived(base, derived_func):
|
def derived(base, derived_func):
|
||||||
@@ -370,13 +373,14 @@ class TypeVar(object):
|
|||||||
Return a derived type variable that has the same number of vector lanes
|
Return a derived type variable that has the same number of vector lanes
|
||||||
as this one, but the lanes are half the width.
|
as this one, but the lanes are half the width.
|
||||||
"""
|
"""
|
||||||
ts = self.type_set
|
if not self.is_derived:
|
||||||
if ts.min_int:
|
ts = self.type_set
|
||||||
assert ts.min_int > 8, "Can't halve all integer types"
|
if ts.min_int:
|
||||||
if ts.min_float:
|
assert ts.min_int > 8, "Can't halve all integer types"
|
||||||
assert ts.min_float > 32, "Can't halve all float types"
|
if ts.min_float:
|
||||||
if ts.min_bool:
|
assert ts.min_float > 32, "Can't halve all float types"
|
||||||
assert ts.min_bool > 8, "Can't halve all boolean types"
|
if ts.min_bool:
|
||||||
|
assert ts.min_bool > 8, "Can't halve all boolean types"
|
||||||
|
|
||||||
return TypeVar.derived(self, self.HALFWIDTH)
|
return TypeVar.derived(self, self.HALFWIDTH)
|
||||||
|
|
||||||
@@ -386,13 +390,14 @@ class TypeVar(object):
|
|||||||
Return a derived type variable that has the same number of vector lanes
|
Return a derived type variable that has the same number of vector lanes
|
||||||
as this one, but the lanes are double the width.
|
as this one, but the lanes are double the width.
|
||||||
"""
|
"""
|
||||||
ts = self.type_set
|
if not self.is_derived:
|
||||||
if ts.max_int:
|
ts = self.type_set
|
||||||
assert ts.max_int < MAX_BITS, "Can't double all integer types."
|
if ts.max_int:
|
||||||
if ts.max_float:
|
assert ts.max_int < MAX_BITS, "Can't double all integer types."
|
||||||
assert ts.max_float < MAX_BITS, "Can't double all float types."
|
if ts.max_float:
|
||||||
if ts.max_bool:
|
assert ts.max_float < MAX_BITS, "Can't double all float types."
|
||||||
assert ts.max_bool < MAX_BITS, "Can't double all boolean types."
|
if ts.max_bool:
|
||||||
|
assert ts.max_bool < MAX_BITS, "Can't double all bool types."
|
||||||
|
|
||||||
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
return TypeVar.derived(self, self.DOUBLEWIDTH)
|
||||||
|
|
||||||
@@ -409,6 +414,19 @@ class TypeVar(object):
|
|||||||
else:
|
else:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def rust_expr(self):
|
||||||
|
# type: () -> str
|
||||||
|
"""
|
||||||
|
Get a Rust expression that computes the type of this type variable.
|
||||||
|
"""
|
||||||
|
if self.is_derived:
|
||||||
|
return '{}.{}()'.format(
|
||||||
|
self.base.rust_expr(), self.derived_func)
|
||||||
|
elif self.singleton_type:
|
||||||
|
return self.singleton_type.rust_name()
|
||||||
|
else:
|
||||||
|
return self.name
|
||||||
|
|
||||||
def constrain_types(self, other):
|
def constrain_types(self, other):
|
||||||
# type: (TypeVar) -> None
|
# type: (TypeVar) -> None
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import absolute_import
|
|||||||
import srcgen
|
import srcgen
|
||||||
import constant_hash
|
import constant_hash
|
||||||
from unique_table import UniqueTable, UniqueSeqTable
|
from unique_table import UniqueTable, UniqueSeqTable
|
||||||
|
from cdsl import camel_case
|
||||||
from cdsl.operands import ImmediateKind
|
from cdsl.operands import ImmediateKind
|
||||||
import cdsl.types
|
import cdsl.types
|
||||||
from cdsl.formats import InstructionFormat
|
from cdsl.formats import InstructionFormat
|
||||||
@@ -330,7 +331,7 @@ def get_constraint(op, ctrl_typevar, type_sets):
|
|||||||
|
|
||||||
if tv.is_derived:
|
if tv.is_derived:
|
||||||
assert tv.base is ctrl_typevar, "Not derived from ctrl_typevar"
|
assert tv.base is ctrl_typevar, "Not derived from ctrl_typevar"
|
||||||
return tv.derived_func
|
return camel_case(tv.derived_func)
|
||||||
|
|
||||||
assert tv is ctrl_typevar
|
assert tv is ctrl_typevar
|
||||||
return 'Same'
|
return 'Same'
|
||||||
|
|||||||
Reference in New Issue
Block a user