* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow
* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
de5501bc47
commit
a9147ebd30
@@ -12,6 +12,7 @@ from base.types import i8, f32, f64, b1
|
|||||||
from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32
|
from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32
|
||||||
from base.immediates import intcc, floatcc, memflags, regunit
|
from base.immediates import intcc, floatcc, memflags, regunit
|
||||||
from base import entities
|
from base import entities
|
||||||
|
from cdsl.ti import WiderOrEq
|
||||||
import base.formats # noqa
|
import base.formats # noqa
|
||||||
|
|
||||||
GROUP = InstructionGroup("base", "Shared base instruction set")
|
GROUP = InstructionGroup("base", "Shared base instruction set")
|
||||||
@@ -1405,7 +1406,7 @@ ireduce = Instruction(
|
|||||||
and each lane must not have more bits that the input lanes. If the
|
and each lane must not have more bits that the input lanes. If the
|
||||||
input and output types are the same, this is a no-op.
|
input and output types are the same, this is a no-op.
|
||||||
""",
|
""",
|
||||||
ins=x, outs=a)
|
ins=x, outs=a, constraints=WiderOrEq(Int, IntTo))
|
||||||
|
|
||||||
|
|
||||||
IntTo = TypeVar(
|
IntTo = TypeVar(
|
||||||
@@ -1427,7 +1428,7 @@ uextend = Instruction(
|
|||||||
and each lane must not have fewer bits that the input lanes. If the
|
and each lane must not have fewer bits that the input lanes. If the
|
||||||
input and output types are the same, this is a no-op.
|
input and output types are the same, this is a no-op.
|
||||||
""",
|
""",
|
||||||
ins=x, outs=a)
|
ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
|
||||||
|
|
||||||
sextend = Instruction(
|
sextend = Instruction(
|
||||||
'sextend', r"""
|
'sextend', r"""
|
||||||
@@ -1441,7 +1442,7 @@ sextend = Instruction(
|
|||||||
and each lane must not have fewer bits that the input lanes. If the
|
and each lane must not have fewer bits that the input lanes. If the
|
||||||
input and output types are the same, this is a no-op.
|
input and output types are the same, this is a no-op.
|
||||||
""",
|
""",
|
||||||
ins=x, outs=a)
|
ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
|
||||||
|
|
||||||
FloatTo = TypeVar(
|
FloatTo = TypeVar(
|
||||||
'FloatTo', 'A scalar or vector floating point number',
|
'FloatTo', 'A scalar or vector floating point number',
|
||||||
@@ -1457,14 +1458,14 @@ fpromote = Instruction(
|
|||||||
Each lane in `x` is converted to the destination floating point format.
|
Each lane in `x` is converted to the destination floating point format.
|
||||||
This is an exact operation.
|
This is an exact operation.
|
||||||
|
|
||||||
Since Cretonne currently only supports two floating point formats, this
|
Cretonne currently only supports two floating point formats
|
||||||
instruction always converts :type:`f32` to :type:`f64`. This may change
|
- :type:`f32` and :type:`f64`. This may change in the future.
|
||||||
in the future.
|
|
||||||
|
|
||||||
The result type must have the same number of vector lanes as the input,
|
The result type must have the same number of vector lanes as the input,
|
||||||
and the result lanes must be larger than the input lanes.
|
and the result lanes must not have fewer bits than the input lanes. If
|
||||||
|
the input and output types are the same, this is a no-op.
|
||||||
""",
|
""",
|
||||||
ins=x, outs=a)
|
ins=x, outs=a, constraints=WiderOrEq(FloatTo, Float))
|
||||||
|
|
||||||
fdemote = Instruction(
|
fdemote = Instruction(
|
||||||
'fdemote', r"""
|
'fdemote', r"""
|
||||||
@@ -1473,14 +1474,14 @@ fdemote = Instruction(
|
|||||||
Each lane in `x` is converted to the destination floating point format
|
Each lane in `x` is converted to the destination floating point format
|
||||||
by rounding to nearest, ties to even.
|
by rounding to nearest, ties to even.
|
||||||
|
|
||||||
Since Cretonne currently only supports two floating point formats, this
|
Cretonne currently only supports two floating point formats
|
||||||
instruction always converts :type:`f64` to :type:`f32`. This may change
|
- :type:`f32` and :type:`f64`. This may change in the future.
|
||||||
in the future.
|
|
||||||
|
|
||||||
The result type must have the same number of vector lanes as the input,
|
The result type must have the same number of vector lanes as the input,
|
||||||
and the result lanes must be smaller than the input lanes.
|
and the result lanes must not have more bits than the input lanes. If
|
||||||
|
the input and output types are the same, this is a no-op.
|
||||||
""",
|
""",
|
||||||
ins=x, outs=a)
|
ins=x, outs=a, constraints=WiderOrEq(Float, FloatTo))
|
||||||
|
|
||||||
x = Operand('x', Float)
|
x = Operand('x', Float)
|
||||||
a = Operand('a', IntTo)
|
a = Operand('a', IntTo)
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ try:
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .ast import Expr, Apply # noqa
|
from .ast import Expr, Apply # noqa
|
||||||
from .typevar import TypeVar # noqa
|
from .typevar import TypeVar # noqa
|
||||||
|
from .ti import TypeConstraint # noqa
|
||||||
# List of operands for ins/outs:
|
# List of operands for ins/outs:
|
||||||
OpList = Union[Sequence[Operand], Operand]
|
OpList = Union[Sequence[Operand], Operand]
|
||||||
|
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
|
||||||
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
@@ -80,6 +82,7 @@ class Instruction(object):
|
|||||||
operands and other operand kinds.
|
operands and other operand kinds.
|
||||||
:param outs: Tuple of output operands. The output operands must be SSA
|
:param outs: Tuple of output operands. The output operands must be SSA
|
||||||
values or `variable_args`.
|
values or `variable_args`.
|
||||||
|
:param constraints: Tuple of instruction-specific TypeConstraints.
|
||||||
:param is_terminator: This is a terminator instruction.
|
:param is_terminator: This is a terminator instruction.
|
||||||
:param is_branch: This is a branch instruction.
|
:param is_branch: This is a branch instruction.
|
||||||
:param is_call: This is a call instruction.
|
:param is_call: This is a call instruction.
|
||||||
@@ -102,13 +105,14 @@ class Instruction(object):
|
|||||||
'can_trap': 'Can this instruction cause a trap?',
|
'can_trap': 'Can this instruction cause a trap?',
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, name, doc, ins=(), outs=(), **kwargs):
|
def __init__(self, name, doc, ins=(), outs=(), constraints=(), **kwargs):
|
||||||
# type: (str, str, OpList, OpList, **Any) -> None # noqa
|
# type: (str, str, OpList, OpList, ConstrList, **Any) -> None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.camel_name = camel_case(name)
|
self.camel_name = camel_case(name)
|
||||||
self.__doc__ = doc
|
self.__doc__ = doc
|
||||||
self.ins = self._to_operand_tuple(ins)
|
self.ins = self._to_operand_tuple(ins)
|
||||||
self.outs = self._to_operand_tuple(outs)
|
self.outs = self._to_operand_tuple(outs)
|
||||||
|
self.constraints = self._to_constraint_tuple(constraints)
|
||||||
self.format = InstructionFormat.lookup(self.ins, self.outs)
|
self.format = InstructionFormat.lookup(self.ins, self.outs)
|
||||||
|
|
||||||
# Opcode number, assigned by gen_instr.py.
|
# Opcode number, assigned by gen_instr.py.
|
||||||
@@ -268,6 +272,23 @@ class Instruction(object):
|
|||||||
assert isinstance(op, Operand)
|
assert isinstance(op, Operand)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_constraint_tuple(x):
|
||||||
|
# type: (ConstrList) -> Tuple[TypeConstraint, ...]
|
||||||
|
"""
|
||||||
|
Allow a single TypeConstraint instance instead of the awkward singleton
|
||||||
|
tuple syntax.
|
||||||
|
"""
|
||||||
|
# import placed here to avoid circular dependency
|
||||||
|
from .ti import TypeConstraint # noqa
|
||||||
|
if isinstance(x, TypeConstraint):
|
||||||
|
x = (x,)
|
||||||
|
else:
|
||||||
|
x = tuple(x)
|
||||||
|
for op in x:
|
||||||
|
assert isinstance(op, TypeConstraint)
|
||||||
|
return x
|
||||||
|
|
||||||
def bind(self, *args):
|
def bind(self, *args):
|
||||||
# type: (*ValueType) -> BoundInstruction
|
# type: (*ValueType) -> BoundInstruction
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
|
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
|
||||||
b1, icmp, iadd_cout, iadd_cin, uextend, ireduce
|
b1, icmp, iadd_cout, iadd_cin, uextend, sextend, ireduce, fpromote, \
|
||||||
|
fdemote
|
||||||
from base.legalize import narrow, expand
|
from base.legalize import narrow, expand
|
||||||
from base.immediates import intcc
|
from base.immediates import intcc
|
||||||
from base.types import i32, i8
|
from base.types import i32, i8
|
||||||
from .typevar import TypeVar
|
from .typevar import TypeVar
|
||||||
from .ast import Var, Def
|
from .ast import Var, Def
|
||||||
from .xform import Rtl, XForm
|
from .xform import Rtl, XForm
|
||||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
|
from .ti import ti_rtl, subst, TypeEnv, get_type_env, TypesEqual, WiderOrEq
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
@@ -52,9 +53,10 @@ def agree(me, other):
|
|||||||
|
|
||||||
# Translate our constraints using m, and sort
|
# Translate our constraints using m, and sort
|
||||||
me_equiv_constr = sorted([constr.translate(m)
|
me_equiv_constr = sorted([constr.translate(m)
|
||||||
for constr in me.constraints])
|
for constr in me.constraints], key=repr)
|
||||||
# Sort other's constraints
|
# Sort other's constraints
|
||||||
other_equiv_constr = sorted(other.constraints)
|
other_equiv_constr = sorted([constr.translate(other)
|
||||||
|
for constr in other.constraints], key=repr)
|
||||||
return me_equiv_constr == other_equiv_constr
|
return me_equiv_constr == other_equiv_constr
|
||||||
|
|
||||||
|
|
||||||
@@ -78,7 +80,7 @@ def check_typing(got_or_err, expected, symtab=None):
|
|||||||
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
|
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
|
||||||
# Rewrite the TVs in the input constraints to their XForm internal
|
# Rewrite the TVs in the input constraints to their XForm internal
|
||||||
# versions
|
# versions
|
||||||
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c]
|
c = [constr.translate(subst_m) for constr in c]
|
||||||
else:
|
else:
|
||||||
# If no symtab, just convert m from Var->TypeVar map to a
|
# If no symtab, just convert m from Var->TypeVar map to a
|
||||||
# TypeVar->TypeVar map
|
# TypeVar->TypeVar map
|
||||||
@@ -209,7 +211,7 @@ class TestRTL(TypeCheckingBaseTest):
|
|||||||
self.v3: txn,
|
self.v3: txn,
|
||||||
self.v4: txn,
|
self.v4: txn,
|
||||||
self.v5: txn,
|
self.v5: txn,
|
||||||
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
|
}, [TypesEqual(ixn.as_bool(), txn.as_bool())]))
|
||||||
|
|
||||||
def test_vselect_vsplits(self):
|
def test_vselect_vsplits(self):
|
||||||
# type: () -> None
|
# type: () -> None
|
||||||
@@ -319,6 +321,90 @@ class TestRTL(TypeCheckingBaseTest):
|
|||||||
"Error: empty type created when unifying " +
|
"Error: empty type created when unifying " +
|
||||||
"`typeof_v4` and `typeof_v5`")
|
"`typeof_v4` and `typeof_v5`")
|
||||||
|
|
||||||
|
def test_extend_reduce(self):
|
||||||
|
# type: () -> None
|
||||||
|
r = Rtl(
|
||||||
|
self.v1 << uextend(self.v0),
|
||||||
|
self.v2 << ireduce(self.v1),
|
||||||
|
self.v3 << sextend(self.v2),
|
||||||
|
)
|
||||||
|
ti = TypeEnv()
|
||||||
|
typing = ti_rtl(r, ti)
|
||||||
|
typing = typing.extract()
|
||||||
|
|
||||||
|
itype0 = TypeVar("t", "", ints=True, simd=(1, 256))
|
||||||
|
itype1 = TypeVar("t1", "", ints=True, simd=(1, 256))
|
||||||
|
itype2 = TypeVar("t2", "", ints=True, simd=(1, 256))
|
||||||
|
itype3 = TypeVar("t3", "", ints=True, simd=(1, 256))
|
||||||
|
|
||||||
|
check_typing(typing, ({
|
||||||
|
self.v0: itype0,
|
||||||
|
self.v1: itype1,
|
||||||
|
self.v2: itype2,
|
||||||
|
self.v3: itype3,
|
||||||
|
}, [WiderOrEq(itype1, itype0),
|
||||||
|
WiderOrEq(itype1, itype2),
|
||||||
|
WiderOrEq(itype3, itype2)]))
|
||||||
|
|
||||||
|
def test_extend_reduce_enumeration(self):
|
||||||
|
# type: () -> None
|
||||||
|
for op in (uextend, sextend, ireduce):
|
||||||
|
r = Rtl(
|
||||||
|
self.v1 << op(self.v0),
|
||||||
|
)
|
||||||
|
ti = TypeEnv()
|
||||||
|
typing = ti_rtl(r, ti).extract()
|
||||||
|
|
||||||
|
# The number of possible typings is 9 * (3+ 2*2 + 3) = 90
|
||||||
|
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
|
||||||
|
assert (len(l) == len(set(l)) and len(l) == 90)
|
||||||
|
for (tv0, tv1) in l:
|
||||||
|
typ0, typ1 = (tv0.singleton_type(), tv1.singleton_type())
|
||||||
|
if (op == ireduce):
|
||||||
|
assert typ0.wider_or_equal(typ1)
|
||||||
|
else:
|
||||||
|
assert typ1.wider_or_equal(typ0)
|
||||||
|
|
||||||
|
def test_fpromote_fdemote(self):
|
||||||
|
# type: () -> None
|
||||||
|
r = Rtl(
|
||||||
|
self.v1 << fpromote(self.v0),
|
||||||
|
self.v2 << fdemote(self.v1),
|
||||||
|
)
|
||||||
|
ti = TypeEnv()
|
||||||
|
typing = ti_rtl(r, ti)
|
||||||
|
typing = typing.extract()
|
||||||
|
|
||||||
|
ftype0 = TypeVar("t", "", floats=True, simd=(1, 256))
|
||||||
|
ftype1 = TypeVar("t1", "", floats=True, simd=(1, 256))
|
||||||
|
ftype2 = TypeVar("t2", "", floats=True, simd=(1, 256))
|
||||||
|
|
||||||
|
check_typing(typing, ({
|
||||||
|
self.v0: ftype0,
|
||||||
|
self.v1: ftype1,
|
||||||
|
self.v2: ftype2,
|
||||||
|
}, [WiderOrEq(ftype1, ftype0),
|
||||||
|
WiderOrEq(ftype1, ftype2)]))
|
||||||
|
|
||||||
|
def test_fpromote_fdemote_enumeration(self):
|
||||||
|
# type: () -> None
|
||||||
|
for op in (fpromote, fdemote):
|
||||||
|
r = Rtl(
|
||||||
|
self.v1 << op(self.v0),
|
||||||
|
)
|
||||||
|
ti = TypeEnv()
|
||||||
|
typing = ti_rtl(r, ti).extract()
|
||||||
|
|
||||||
|
# The number of possible typings is 9*(2 + 1) = 27
|
||||||
|
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
|
||||||
|
assert (len(l) == len(set(l)) and len(l) == 27)
|
||||||
|
for (tv0, tv1) in l:
|
||||||
|
(typ0, typ1) = (tv0.singleton_type(), tv1.singleton_type())
|
||||||
|
if (op == fdemote):
|
||||||
|
assert typ0.wider_or_equal(typ1)
|
||||||
|
else:
|
||||||
|
assert typ1.wider_or_equal(typ0)
|
||||||
|
|
||||||
|
|
||||||
class TestXForm(TypeCheckingBaseTest):
|
class TestXForm(TypeCheckingBaseTest):
|
||||||
def test_iadd_cout(self):
|
def test_iadd_cout(self):
|
||||||
@@ -453,7 +539,7 @@ class TestXForm(TypeCheckingBaseTest):
|
|||||||
self.v3: i32t,
|
self.v3: i32t,
|
||||||
self.v4: i32t,
|
self.v4: i32t,
|
||||||
self.v5: i32t,
|
self.v5: i32t,
|
||||||
}, []), x.symtab)
|
}, [WiderOrEq(i32t, itype)]), x.symtab)
|
||||||
|
|
||||||
def test_bound_inst_inference1(self):
|
def test_bound_inst_inference1(self):
|
||||||
# Second example taken from issue #26
|
# Second example taken from issue #26
|
||||||
@@ -477,7 +563,7 @@ class TestXForm(TypeCheckingBaseTest):
|
|||||||
self.v3: i32t,
|
self.v3: i32t,
|
||||||
self.v4: i32t,
|
self.v4: i32t,
|
||||||
self.v5: i32t,
|
self.v5: i32t,
|
||||||
}, []), x.symtab)
|
}, [WiderOrEq(i32t, itype)]), x.symtab)
|
||||||
|
|
||||||
def test_fully_bound_inst_inference(self):
|
def test_fully_bound_inst_inference(self):
|
||||||
# Second example taken from issue #26 with complete bounds
|
# Second example taken from issue #26 with complete bounds
|
||||||
@@ -494,6 +580,7 @@ class TestXForm(TypeCheckingBaseTest):
|
|||||||
i8t = TypeVar.singleton(i8)
|
i8t = TypeVar.singleton(i8)
|
||||||
i32t = TypeVar.singleton(i32)
|
i32t = TypeVar.singleton(i32)
|
||||||
|
|
||||||
|
# Note no constraints here since they are all trivial
|
||||||
check_typing(x.ti, ({
|
check_typing(x.ti, ({
|
||||||
self.v0: i8t,
|
self.v0: i8t,
|
||||||
self.v1: i8t,
|
self.v1: i8t,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from itertools import product
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||||
from typing import Iterable, List # noqa
|
from typing import Iterable, List, Any # noqa
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from .xform import Rtl, XForm # noqa
|
from .xform import Rtl, XForm # noqa
|
||||||
from .ast import Expr # noqa
|
from .ast import Expr # noqa
|
||||||
@@ -25,9 +25,72 @@ class TypeConstraint(object):
|
|||||||
"""
|
"""
|
||||||
Base class for all runtime-emittable type constraints.
|
Base class for all runtime-emittable type constraints.
|
||||||
"""
|
"""
|
||||||
|
def translate(self, m):
|
||||||
|
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
|
||||||
|
"""
|
||||||
|
Translate any TypeVars in the constraint according to the map or
|
||||||
|
TypeEnv m
|
||||||
|
"""
|
||||||
|
def translate_one(a):
|
||||||
|
# type: (Any) -> Any
|
||||||
|
if (isinstance(a, TypeVar)):
|
||||||
|
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
|
||||||
|
return a
|
||||||
|
|
||||||
|
res = None # type: TypeConstraint
|
||||||
|
res = self.__class__(*tuple(map(translate_one, self._args())))
|
||||||
|
return res
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
# type: (object) -> bool
|
||||||
|
if (not isinstance(other, self.__class__)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert isinstance(other, TypeConstraint) # help MyPy figure out other
|
||||||
|
return self._args() == other._args()
|
||||||
|
|
||||||
|
def is_concrete(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Return true iff all typevars in the constraint are singletons.
|
||||||
|
"""
|
||||||
|
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
|
||||||
|
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# type: () -> int
|
||||||
|
return hash(self._args())
|
||||||
|
|
||||||
|
def _args(self):
|
||||||
|
# type: () -> Tuple[Any,...]
|
||||||
|
"""
|
||||||
|
Return a tuple with the exact arguments passed to __init__ to create
|
||||||
|
this object.
|
||||||
|
"""
|
||||||
|
assert False, "Abstract"
|
||||||
|
|
||||||
|
def is_trivial(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Return true if this constrain is statically decidable.
|
||||||
|
"""
|
||||||
|
assert False, "Abstract"
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Evaluate this constraint. Should only be called when the constraint has
|
||||||
|
been translated to concrete types.
|
||||||
|
"""
|
||||||
|
assert False, "Abstract"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# type: () -> str
|
||||||
|
return (self.__class__.__name__ + '(' +
|
||||||
|
', '.join(map(str, self._args())) + ')')
|
||||||
|
|
||||||
|
|
||||||
class ConstrainTVsEqual(TypeConstraint):
|
class TypesEqual(TypeConstraint):
|
||||||
"""
|
"""
|
||||||
Constraint specifying that two derived type vars must have the same runtime
|
Constraint specifying that two derived type vars must have the same runtime
|
||||||
type.
|
type.
|
||||||
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
|
|||||||
assert tv1.is_derived and tv2.is_derived
|
assert tv1.is_derived and tv2.is_derived
|
||||||
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||||
|
|
||||||
|
def _args(self):
|
||||||
|
# type: () -> Tuple[Any,...]
|
||||||
|
""" See TypeConstraint._args() """
|
||||||
|
return (self.tv1, self.tv2)
|
||||||
|
|
||||||
def is_trivial(self):
|
def is_trivial(self):
|
||||||
# type: () -> bool
|
# type: () -> bool
|
||||||
"""
|
""" See TypeConstraint.is_trivial() """
|
||||||
Return true if this constrain is statically decidable.
|
return self.tv1 == self.tv2 or self.is_concrete()
|
||||||
"""
|
|
||||||
return self.tv1 == self.tv2 or \
|
|
||||||
(self.tv1.singleton_type() is not None and
|
|
||||||
self.tv2.singleton_type() is not None)
|
|
||||||
|
|
||||||
def translate(self, m):
|
|
||||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
|
|
||||||
"""
|
|
||||||
Translate any TypeVars in the constraint according to the map m
|
|
||||||
"""
|
|
||||||
if isinstance(m, TypeEnv):
|
|
||||||
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
|
|
||||||
else:
|
|
||||||
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
# type: (object) -> bool
|
|
||||||
if (not isinstance(other, ConstrainTVsEqual)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
# type: () -> int
|
|
||||||
return hash((self.tv1, self.tv2))
|
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
# type: () -> bool
|
# type: () -> bool
|
||||||
"""
|
""" See TypeConstraint.eval() """
|
||||||
Evaluate this constraint. Should only be called when the constraint has
|
assert self.is_concrete()
|
||||||
been translated to concrete types.
|
|
||||||
"""
|
|
||||||
assert self.tv1.singleton_type() is not None and \
|
|
||||||
self.tv2.singleton_type() is not None
|
|
||||||
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||||
|
|
||||||
|
|
||||||
class ConstrainTVInTypeset(TypeConstraint):
|
class InTypeset(TypeConstraint):
|
||||||
"""
|
"""
|
||||||
Constraint specifying that a type var must belong to some typeset.
|
Constraint specifying that a type var must belong to some typeset.
|
||||||
"""
|
"""
|
||||||
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
|
|||||||
self.tv = tv
|
self.tv = tv
|
||||||
self.ts = ts
|
self.ts = ts
|
||||||
|
|
||||||
|
def _args(self):
|
||||||
|
# type: () -> Tuple[Any,...]
|
||||||
|
""" See TypeConstraint._args() """
|
||||||
|
return (self.tv, self.ts)
|
||||||
|
|
||||||
def is_trivial(self):
|
def is_trivial(self):
|
||||||
# type: () -> bool
|
# type: () -> bool
|
||||||
"""
|
""" See TypeConstraint.is_trivial() """
|
||||||
Return true if this constrain is statically decidable.
|
|
||||||
"""
|
|
||||||
tv_ts = self.tv.get_typeset().copy()
|
tv_ts = self.tv.get_typeset().copy()
|
||||||
|
|
||||||
# Trivially True
|
# Trivially True
|
||||||
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
|
|||||||
if (tv_ts.size() == 0):
|
if (tv_ts.size() == 0):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return self.is_concrete()
|
||||||
|
|
||||||
def translate(self, m):
|
|
||||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
|
|
||||||
"""
|
|
||||||
Translate any TypeVars in the constraint according to the map m
|
|
||||||
"""
|
|
||||||
if isinstance(m, TypeEnv):
|
|
||||||
return ConstrainTVInTypeset(m[self.tv], self.ts)
|
|
||||||
else:
|
|
||||||
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
# type: (object) -> bool
|
|
||||||
if (not isinstance(other, ConstrainTVInTypeset)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return (self.tv, self.ts) == (other.tv, other.ts)
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
# type: () -> int
|
|
||||||
return hash((self.tv, self.ts))
|
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
# type: () -> bool
|
# type: () -> bool
|
||||||
"""
|
""" See TypeConstraint.eval() """
|
||||||
Evaluate this constraint. Should only be called when the constraint has
|
assert self.is_concrete()
|
||||||
been translated to concrete types.
|
|
||||||
"""
|
|
||||||
assert self.tv.singleton_type() is not None
|
|
||||||
return self.tv.get_typeset().issubset(self.ts)
|
return self.tv.get_typeset().issubset(self.ts)
|
||||||
|
|
||||||
|
|
||||||
|
class WiderOrEq(TypeConstraint):
|
||||||
|
"""
|
||||||
|
Constraint specifying that a type var tv1 must be wider than or equal to
|
||||||
|
type var tv2 at runtime. This requires that:
|
||||||
|
1) They have the same number of lanes
|
||||||
|
2) In a lane tv1 has at least as many bits as tv2.
|
||||||
|
"""
|
||||||
|
def __init__(self, tv1, tv2):
|
||||||
|
# type: (TypeVar, TypeVar) -> None
|
||||||
|
self.tv1 = tv1
|
||||||
|
self.tv2 = tv2
|
||||||
|
|
||||||
|
def _args(self):
|
||||||
|
# type: () -> Tuple[Any,...]
|
||||||
|
""" See TypeConstraint._args() """
|
||||||
|
return (self.tv1, self.tv2)
|
||||||
|
|
||||||
|
def is_trivial(self):
|
||||||
|
# type: () -> bool
|
||||||
|
""" See TypeConstraint.is_trivial() """
|
||||||
|
# Trivially true
|
||||||
|
if (self.tv1 == self.tv2):
|
||||||
|
return True
|
||||||
|
|
||||||
|
ts1 = self.tv1.get_typeset()
|
||||||
|
ts2 = self.tv2.get_typeset()
|
||||||
|
|
||||||
|
def set_wider_or_equal(s1, s2):
|
||||||
|
# type: (Set[int], Set[int]) -> bool
|
||||||
|
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
|
||||||
|
|
||||||
|
# Trivially True
|
||||||
|
if set_wider_or_equal(ts1.ints, ts2.ints) and\
|
||||||
|
set_wider_or_equal(ts1.floats, ts2.floats) and\
|
||||||
|
set_wider_or_equal(ts1.bools, ts2.bools):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def set_narrower(s1, s2):
|
||||||
|
# type: (Set[int], Set[int]) -> bool
|
||||||
|
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
|
||||||
|
|
||||||
|
# Trivially False
|
||||||
|
if set_narrower(ts1.ints, ts2.ints) and\
|
||||||
|
set_narrower(ts1.floats, ts2.floats) and\
|
||||||
|
set_narrower(ts1.bools, ts2.bools):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Trivially False
|
||||||
|
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.is_concrete()
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
# type: () -> bool
|
||||||
|
""" See TypeConstraint.eval() """
|
||||||
|
assert self.is_concrete()
|
||||||
|
typ1 = self.tv1.singleton_type()
|
||||||
|
typ2 = self.tv2.singleton_type()
|
||||||
|
|
||||||
|
return typ1.wider_or_equal(typ2)
|
||||||
|
|
||||||
|
|
||||||
class TypeEnv(object):
|
class TypeEnv(object):
|
||||||
"""
|
"""
|
||||||
Class encapsulating the neccessary book keeping for type inference.
|
Class encapsulating the neccessary book keeping for type inference.
|
||||||
@@ -204,12 +285,11 @@ class TypeEnv(object):
|
|||||||
|
|
||||||
self.type_map[tv1] = tv2
|
self.type_map[tv1] = tv2
|
||||||
|
|
||||||
def add_constraint(self, tv1, tv2):
|
def add_constraint(self, constr):
|
||||||
# type: (TypeVar, TypeVar) -> None
|
# type: (TypeConstraint) -> None
|
||||||
"""
|
"""
|
||||||
Add a new equivalence constraint between tv1 and tv2
|
Add a new constraint
|
||||||
"""
|
"""
|
||||||
constr = ConstrainTVsEqual(tv1, tv2)
|
|
||||||
if (constr not in self.constraints):
|
if (constr not in self.constraints):
|
||||||
self.constraints.append(constr)
|
self.constraints.append(constr)
|
||||||
|
|
||||||
@@ -261,6 +341,7 @@ class TypeEnv(object):
|
|||||||
Get the free typevars in the current type env.
|
Get the free typevars in the current type env.
|
||||||
"""
|
"""
|
||||||
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
||||||
|
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
|
||||||
# Filter out None here due to singleton type vars
|
# Filter out None here due to singleton type vars
|
||||||
return sorted(filter(lambda x: x is not None, tvs),
|
return sorted(filter(lambda x: x is not None, tvs),
|
||||||
key=lambda x: x.name)
|
key=lambda x: x.name)
|
||||||
@@ -326,17 +407,18 @@ class TypeEnv(object):
|
|||||||
|
|
||||||
new_constraints = [] # type: List[TypeConstraint]
|
new_constraints = [] # type: List[TypeConstraint]
|
||||||
for constr in self.constraints:
|
for constr in self.constraints:
|
||||||
# Currently typeinference only generates ConstrainTVsEqual
|
|
||||||
# constraints
|
|
||||||
assert isinstance(constr, ConstrainTVsEqual)
|
|
||||||
constr = constr.translate(self)
|
constr = constr.translate(self)
|
||||||
|
|
||||||
if constr.is_trivial() or constr in new_constraints:
|
if constr.is_trivial() or constr in new_constraints:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Sanity: translated constraints should refer to only real vars
|
# Sanity: translated constraints should refer to only real vars
|
||||||
assert constr.tv1.free_typevar() in vars_tvs and\
|
for arg in constr._args():
|
||||||
constr.tv2.free_typevar() in vars_tvs
|
if (not isinstance(arg, TypeVar)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
arg_free_tv = arg.free_typevar()
|
||||||
|
assert arg_free_tv is None or arg_free_tv in vars_tvs
|
||||||
|
|
||||||
new_constraints.append(constr)
|
new_constraints.append(constr)
|
||||||
|
|
||||||
@@ -372,9 +454,6 @@ class TypeEnv(object):
|
|||||||
# Check if constraints are satisfied for this typing
|
# Check if constraints are satisfied for this typing
|
||||||
failed = None
|
failed = None
|
||||||
for constr in self.constraints:
|
for constr in self.constraints:
|
||||||
# Currently typeinference only generates ConstrainTVsEqual
|
|
||||||
# constraints
|
|
||||||
assert isinstance(constr, ConstrainTVsEqual)
|
|
||||||
concrete_constr = constr.translate(m)
|
concrete_constr = constr.translate(m)
|
||||||
if not concrete_constr.eval():
|
if not concrete_constr.eval():
|
||||||
failed = concrete_constr
|
failed = concrete_constr
|
||||||
@@ -401,22 +480,27 @@ class TypeEnv(object):
|
|||||||
# Add all registered TVs (as some of them may be singleton nodes not
|
# Add all registered TVs (as some of them may be singleton nodes not
|
||||||
# appearing in the graph
|
# appearing in the graph
|
||||||
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
|
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
|
||||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
|
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
|
||||||
|
|
||||||
for (k, v) in self.type_map.items():
|
for (k, v) in self.type_map.items():
|
||||||
# Add all intermediate TVs appearing in edges
|
# Add all intermediate TVs appearing in edges
|
||||||
nodes.add(k)
|
nodes.add(k)
|
||||||
nodes.add(v)
|
nodes.add(v)
|
||||||
edges.add((k, v, "dotted", None))
|
edges.add((k, v, "dotted", "forward", None))
|
||||||
while (v.is_derived):
|
while (v.is_derived):
|
||||||
nodes.add(v.base)
|
nodes.add(v.base)
|
||||||
edges.add((v, v.base, "solid", v.derived_func))
|
edges.add((v, v.base, "solid", "forward", v.derived_func))
|
||||||
v = v.base
|
v = v.base
|
||||||
|
|
||||||
for constr in self.constraints:
|
for constr in self.constraints:
|
||||||
assert isinstance(constr, ConstrainTVsEqual)
|
if isinstance(constr, TypesEqual):
|
||||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||||
edges.add((constr.tv1, constr.tv2, "dashed", None))
|
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
|
||||||
|
elif isinstance(constr, WiderOrEq):
|
||||||
|
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||||
|
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
|
||||||
|
else:
|
||||||
|
assert False, "Can't display constraint {}".format(constr)
|
||||||
|
|
||||||
root_nodes = set([x for x in nodes
|
root_nodes = set([x for x in nodes
|
||||||
if x not in self.type_map and not x.is_derived])
|
if x not in self.type_map and not x.is_derived])
|
||||||
@@ -428,17 +512,12 @@ class TypeEnv(object):
|
|||||||
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
|
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
|
||||||
r += ";\n"
|
r += ";\n"
|
||||||
|
|
||||||
for (n1, n2, style, elabel) in edges:
|
for (n1, n2, style, direction, elabel) in edges:
|
||||||
e = label(n1)
|
e = label(n1) + "->" + label(n2)
|
||||||
if style == "dashed":
|
e += "[style={},dir={}".format(style, direction)
|
||||||
e += '--'
|
|
||||||
else:
|
|
||||||
e += '->'
|
|
||||||
e += label(n2)
|
|
||||||
e += "[style={}".format(style)
|
|
||||||
|
|
||||||
if elabel is not None:
|
if elabel is not None:
|
||||||
e += ",label={}".format(elabel)
|
e += ",label=\"{}\"".format(elabel)
|
||||||
e += "];\n"
|
e += "];\n"
|
||||||
|
|
||||||
r += e
|
r += e
|
||||||
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
|
|||||||
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
||||||
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
||||||
|
|
||||||
typ.add_constraint(tv1, tv2)
|
typ.add_constraint(TypesEqual(tv1, tv2))
|
||||||
return typ
|
return typ
|
||||||
|
|
||||||
|
|
||||||
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
|
|||||||
|
|
||||||
typ = get_type_env(typ_or_err)
|
typ = get_type_env(typ_or_err)
|
||||||
|
|
||||||
|
# Add any instruction specific constraints
|
||||||
|
for constr in inst.constraints:
|
||||||
|
typ.add_constraint(constr.translate(m))
|
||||||
|
|
||||||
return typ
|
return typ
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,26 @@ class ValueType(object):
|
|||||||
else:
|
else:
|
||||||
raise AttributeError("No type named '{}'".format(name))
|
raise AttributeError("No type named '{}'".format(name))
|
||||||
|
|
||||||
|
def lane_bits(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of bits in a lane."""
|
||||||
|
assert False, "Abstract"
|
||||||
|
|
||||||
|
def lane_count(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of lanes."""
|
||||||
|
assert False, "Abstract"
|
||||||
|
|
||||||
|
def wider_or_equal(self, other):
|
||||||
|
# type: (ValueType) -> bool
|
||||||
|
"""
|
||||||
|
Return true iff:
|
||||||
|
1. self and other have equal number of lanes
|
||||||
|
2. each lane in self has at least as many bits as a lane in other
|
||||||
|
"""
|
||||||
|
return (self.lane_count() == other.lane_count() and
|
||||||
|
self.lane_bits() >= other.lane_bits())
|
||||||
|
|
||||||
|
|
||||||
class ScalarType(ValueType):
|
class ScalarType(ValueType):
|
||||||
"""
|
"""
|
||||||
@@ -85,6 +105,11 @@ class ScalarType(ValueType):
|
|||||||
self._vectors[lanes] = v
|
self._vectors[lanes] = v
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
def lane_count(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of lanes."""
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
class VectorType(ValueType):
|
class VectorType(ValueType):
|
||||||
"""
|
"""
|
||||||
@@ -112,6 +137,16 @@ class VectorType(ValueType):
|
|||||||
return ('VectorType(base={}, lanes={})'
|
return ('VectorType(base={}, lanes={})'
|
||||||
.format(self.base.name, self.lanes))
|
.format(self.base.name, self.lanes))
|
||||||
|
|
||||||
|
def lane_count(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of lanes."""
|
||||||
|
return self.lanes
|
||||||
|
|
||||||
|
def lane_bits(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of bits in a lane."""
|
||||||
|
return self.base.lane_bits()
|
||||||
|
|
||||||
|
|
||||||
class IntType(ScalarType):
|
class IntType(ScalarType):
|
||||||
"""A concrete scalar integer type."""
|
"""A concrete scalar integer type."""
|
||||||
@@ -138,6 +173,11 @@ class IntType(ScalarType):
|
|||||||
else:
|
else:
|
||||||
return typ
|
return typ
|
||||||
|
|
||||||
|
def lane_bits(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of bits in a lane."""
|
||||||
|
return self.bits
|
||||||
|
|
||||||
|
|
||||||
class FloatType(ScalarType):
|
class FloatType(ScalarType):
|
||||||
"""A concrete scalar floating point type."""
|
"""A concrete scalar floating point type."""
|
||||||
@@ -164,6 +204,11 @@ class FloatType(ScalarType):
|
|||||||
else:
|
else:
|
||||||
return typ
|
return typ
|
||||||
|
|
||||||
|
def lane_bits(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of bits in a lane."""
|
||||||
|
return self.bits
|
||||||
|
|
||||||
|
|
||||||
class BoolType(ScalarType):
|
class BoolType(ScalarType):
|
||||||
"""A concrete scalar boolean type."""
|
"""A concrete scalar boolean type."""
|
||||||
@@ -189,3 +234,8 @@ class BoolType(ScalarType):
|
|||||||
return cast(BoolType, typ)
|
return cast(BoolType, typ)
|
||||||
else:
|
else:
|
||||||
return typ
|
return typ
|
||||||
|
|
||||||
|
def lane_bits(self):
|
||||||
|
# type: () -> int
|
||||||
|
"""Return the number of bits in a lane."""
|
||||||
|
return self.bits
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from __future__ import absolute_import
|
|||||||
from srcgen import Formatter
|
from srcgen import Formatter
|
||||||
from base import legalize, instructions
|
from base import legalize, instructions
|
||||||
from cdsl.ast import Var
|
from cdsl.ast import Var
|
||||||
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
|
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\
|
||||||
ConstrainTVInTypeset
|
InTypeset, WiderOrEq
|
||||||
from unique_table import UniqueTable
|
from unique_table import UniqueTable
|
||||||
from gen_instr import gen_typesets_table
|
from gen_instr import gen_typesets_table
|
||||||
from cdsl.typevar import TypeVar
|
from cdsl.typevar import TypeVar
|
||||||
@@ -66,7 +66,7 @@ def get_runtime_typechecks(xform):
|
|||||||
|
|
||||||
assert xform_ts.issubset(src_ts)
|
assert xform_ts.issubset(src_ts)
|
||||||
if src_ts != xform_ts:
|
if src_ts != xform_ts:
|
||||||
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
|
check_l.append(InTypeset(xform.ti[v], xform_ts))
|
||||||
|
|
||||||
# 2,3) Add any constraints that appear in xform.ti
|
# 2,3) Add any constraints that appear in xform.ti
|
||||||
check_l.extend(xform.ti.constraints)
|
check_l.extend(xform.ti.constraints)
|
||||||
@@ -81,6 +81,14 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
|||||||
"""
|
"""
|
||||||
def build_derived_expr(tv):
|
def build_derived_expr(tv):
|
||||||
# type: (TypeVar) -> str
|
# type: (TypeVar) -> str
|
||||||
|
"""
|
||||||
|
Build an expression of type Option<Type> corresponding to a concrete
|
||||||
|
type transformed by the sequence of derivation functions in tv.
|
||||||
|
|
||||||
|
We are using Option<Type>, as some constraints may cause an
|
||||||
|
over/underflow on patterns that do not match them. We want to capture
|
||||||
|
this without panicking at runtime.
|
||||||
|
"""
|
||||||
if not tv.is_derived:
|
if not tv.is_derived:
|
||||||
assert tv.name.startswith('typeof_')
|
assert tv.name.startswith('typeof_')
|
||||||
return "Some({})".format(tv.name)
|
return "Some({})".format(tv.name)
|
||||||
@@ -102,7 +110,8 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
|||||||
else:
|
else:
|
||||||
assert False, "Unknown derived function {}".format(tv.derived_func)
|
assert False, "Unknown derived function {}".format(tv.derived_func)
|
||||||
|
|
||||||
if (isinstance(check, ConstrainTVInTypeset)):
|
if (isinstance(check, InTypeset)):
|
||||||
|
assert not check.tv.is_derived
|
||||||
tv = check.tv.name
|
tv = check.tv.name
|
||||||
if check.ts not in type_sets.index:
|
if check.ts not in type_sets.index:
|
||||||
type_sets.add(check.ts)
|
type_sets.add(check.ts)
|
||||||
@@ -112,11 +121,28 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
|||||||
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
|
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
|
||||||
'};'):
|
'};'):
|
||||||
fmt.line('return false;')
|
fmt.line('return false;')
|
||||||
elif (isinstance(check, ConstrainTVsEqual)):
|
elif (isinstance(check, TypesEqual)):
|
||||||
tv1 = build_derived_expr(check.tv1)
|
with fmt.indented('{', '};'):
|
||||||
tv2 = build_derived_expr(check.tv2)
|
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
|
||||||
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
|
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
|
||||||
fmt.line('return false;')
|
|
||||||
|
fmt.comment('On overflow constraint doesn\'t appply')
|
||||||
|
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
|
||||||
|
fmt.line('return false;')
|
||||||
|
|
||||||
|
with fmt.indented('if a != b {', '};'):
|
||||||
|
fmt.line('return false;')
|
||||||
|
elif (isinstance(check, WiderOrEq)):
|
||||||
|
with fmt.indented('{', '};'):
|
||||||
|
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
|
||||||
|
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
|
||||||
|
|
||||||
|
fmt.comment('On overflow constraint doesn\'t appply')
|
||||||
|
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
|
||||||
|
fmt.line('return false;')
|
||||||
|
|
||||||
|
with fmt.indented('if !a.wider_or_equal(b) {', '};'):
|
||||||
|
fmt.line('return false;')
|
||||||
else:
|
else:
|
||||||
assert False, "Unknown check {}".format(check)
|
assert False, "Unknown check {}".format(check)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from unittest import TestCase
|
|||||||
from srcgen import Formatter
|
from srcgen import Formatter
|
||||||
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
|
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
|
||||||
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
|
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
|
||||||
iconst, b1, icmp, copy # noqa
|
iconst, b1, icmp, copy, sextend, uextend, ireduce, fdemote, fpromote # noqa
|
||||||
from base.legalize import narrow, expand # noqa
|
from base.legalize import narrow, expand # noqa
|
||||||
from base.immediates import intcc # noqa
|
from base.immediates import intcc # noqa
|
||||||
from cdsl.typevar import TypeVar, TypeSet
|
from cdsl.typevar import TypeVar, TypeSet
|
||||||
@@ -57,9 +57,32 @@ def equiv_check(tv1, tv2):
|
|||||||
# type: (TypeVar, TypeVar) -> CheckProducer
|
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||||
return lambda typesets: format_check(
|
return lambda typesets: format_check(
|
||||||
typesets,
|
typesets,
|
||||||
'if Some({}).map(|t: Type| -> t.as_bool()) != ' +
|
'{{\n' +
|
||||||
'Some({}).map(|t: Type| -> t.as_bool()) ' +
|
' let a = {};\n' +
|
||||||
'{{\n return false;\n}};\n', tv1, tv2)
|
' let b = {};\n' +
|
||||||
|
' if a.is_none() || b.is_none() {{\n' +
|
||||||
|
' return false;\n' +
|
||||||
|
' }};\n' +
|
||||||
|
' if a != b {{\n' +
|
||||||
|
' return false;\n' +
|
||||||
|
' }};\n' +
|
||||||
|
'}};\n', tv1, tv2)
|
||||||
|
|
||||||
|
|
||||||
|
def wider_check(tv1, tv2):
|
||||||
|
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||||
|
return lambda typesets: format_check(
|
||||||
|
typesets,
|
||||||
|
'{{\n' +
|
||||||
|
' let a = {};\n' +
|
||||||
|
' let b = {};\n' +
|
||||||
|
' if a.is_none() || b.is_none() {{\n' +
|
||||||
|
' return false;\n' +
|
||||||
|
' }};\n' +
|
||||||
|
' if !a.wider_or_equal(b) {{\n' +
|
||||||
|
' return false;\n' +
|
||||||
|
' }};\n' +
|
||||||
|
'}};\n', tv1, tv2)
|
||||||
|
|
||||||
|
|
||||||
def sequence(*args):
|
def sequence(*args):
|
||||||
@@ -138,8 +161,49 @@ class TestRuntimeChecks(TestCase):
|
|||||||
self.v5 << vselect(self.v1, self.v3, self.v4),
|
self.v5 << vselect(self.v1, self.v3, self.v4),
|
||||||
)
|
)
|
||||||
x = XForm(r, r)
|
x = XForm(r, r)
|
||||||
|
tv2_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
|
||||||
|
.format(self.v2.get_typevar().name)
|
||||||
|
tv3_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
|
||||||
|
.format(self.v3.get_typevar().name)
|
||||||
|
|
||||||
self.check_yo_check(
|
self.check_yo_check(
|
||||||
x, sequence(typeset_check(self.v3, ts),
|
x, sequence(typeset_check(self.v3, ts),
|
||||||
equiv_check(self.v2.get_typevar(),
|
equiv_check(tv2_exp, tv3_exp)))
|
||||||
self.v3.get_typevar())))
|
|
||||||
|
def test_reduce_extend(self):
|
||||||
|
# type: () -> None
|
||||||
|
r = Rtl(
|
||||||
|
self.v1 << uextend(self.v0),
|
||||||
|
self.v2 << ireduce(self.v1),
|
||||||
|
self.v3 << sextend(self.v2),
|
||||||
|
)
|
||||||
|
x = XForm(r, r)
|
||||||
|
|
||||||
|
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
|
||||||
|
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
|
||||||
|
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
|
||||||
|
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
|
||||||
|
|
||||||
|
self.check_yo_check(
|
||||||
|
x, sequence(wider_check(tv1_exp, tv0_exp),
|
||||||
|
wider_check(tv1_exp, tv2_exp),
|
||||||
|
wider_check(tv3_exp, tv2_exp)))
|
||||||
|
|
||||||
|
def test_demote_promote(self):
|
||||||
|
# type: () -> None
|
||||||
|
r = Rtl(
|
||||||
|
self.v1 << fpromote(self.v0),
|
||||||
|
self.v2 << fdemote(self.v1),
|
||||||
|
self.v3 << fpromote(self.v2),
|
||||||
|
)
|
||||||
|
x = XForm(r, r)
|
||||||
|
|
||||||
|
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
|
||||||
|
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
|
||||||
|
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
|
||||||
|
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
|
||||||
|
|
||||||
|
self.check_yo_check(
|
||||||
|
x, sequence(wider_check(tv1_exp, tv0_exp),
|
||||||
|
wider_check(tv1_exp, tv2_exp),
|
||||||
|
wider_check(tv3_exp, tv2_exp)))
|
||||||
|
|||||||
@@ -236,6 +236,13 @@ impl Type {
|
|||||||
pub fn index(self) -> usize {
|
pub fn index(self) -> usize {
|
||||||
self.0 as usize
|
self.0 as usize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// True iff:
|
||||||
|
/// 1) self.lane_count() == other.lane_count() and
|
||||||
|
/// 2) self.lane_bits() >= other.lane_bits()
|
||||||
|
pub fn wider_or_equal(self, other: Type) -> bool {
|
||||||
|
self.lane_count() == other.lane_count() && self.lane_bits() >= other.lane_bits()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Type {
|
impl Display for Type {
|
||||||
|
|||||||
Reference in New Issue
Block a user