* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow
* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
de5501bc47
commit
a9147ebd30
@@ -12,6 +12,7 @@ from base.types import i8, f32, f64, b1
|
||||
from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32
|
||||
from base.immediates import intcc, floatcc, memflags, regunit
|
||||
from base import entities
|
||||
from cdsl.ti import WiderOrEq
|
||||
import base.formats # noqa
|
||||
|
||||
GROUP = InstructionGroup("base", "Shared base instruction set")
|
||||
@@ -1405,7 +1406,7 @@ ireduce = Instruction(
|
||||
and each lane must not have more bits that the input lanes. If the
|
||||
input and output types are the same, this is a no-op.
|
||||
""",
|
||||
ins=x, outs=a)
|
||||
ins=x, outs=a, constraints=WiderOrEq(Int, IntTo))
|
||||
|
||||
|
||||
IntTo = TypeVar(
|
||||
@@ -1427,7 +1428,7 @@ uextend = Instruction(
|
||||
and each lane must not have fewer bits that the input lanes. If the
|
||||
input and output types are the same, this is a no-op.
|
||||
""",
|
||||
ins=x, outs=a)
|
||||
ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
|
||||
|
||||
sextend = Instruction(
|
||||
'sextend', r"""
|
||||
@@ -1441,7 +1442,7 @@ sextend = Instruction(
|
||||
and each lane must not have fewer bits that the input lanes. If the
|
||||
input and output types are the same, this is a no-op.
|
||||
""",
|
||||
ins=x, outs=a)
|
||||
ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
|
||||
|
||||
FloatTo = TypeVar(
|
||||
'FloatTo', 'A scalar or vector floating point number',
|
||||
@@ -1457,14 +1458,14 @@ fpromote = Instruction(
|
||||
Each lane in `x` is converted to the destination floating point format.
|
||||
This is an exact operation.
|
||||
|
||||
Since Cretonne currently only supports two floating point formats, this
|
||||
instruction always converts :type:`f32` to :type:`f64`. This may change
|
||||
in the future.
|
||||
Cretonne currently only supports two floating point formats
|
||||
- :type:`f32` and :type:`f64`. This may change in the future.
|
||||
|
||||
The result type must have the same number of vector lanes as the input,
|
||||
and the result lanes must be larger than the input lanes.
|
||||
and the result lanes must not have fewer bits than the input lanes. If
|
||||
the input and output types are the same, this is a no-op.
|
||||
""",
|
||||
ins=x, outs=a)
|
||||
ins=x, outs=a, constraints=WiderOrEq(FloatTo, Float))
|
||||
|
||||
fdemote = Instruction(
|
||||
'fdemote', r"""
|
||||
@@ -1473,14 +1474,14 @@ fdemote = Instruction(
|
||||
Each lane in `x` is converted to the destination floating point format
|
||||
by rounding to nearest, ties to even.
|
||||
|
||||
Since Cretonne currently only supports two floating point formats, this
|
||||
instruction always converts :type:`f64` to :type:`f32`. This may change
|
||||
in the future.
|
||||
Cretonne currently only supports two floating point formats
|
||||
- :type:`f32` and :type:`f64`. This may change in the future.
|
||||
|
||||
The result type must have the same number of vector lanes as the input,
|
||||
and the result lanes must be smaller than the input lanes.
|
||||
and the result lanes must not have more bits than the input lanes. If
|
||||
the input and output types are the same, this is a no-op.
|
||||
""",
|
||||
ins=x, outs=a)
|
||||
ins=x, outs=a, constraints=WiderOrEq(Float, FloatTo))
|
||||
|
||||
x = Operand('x', Float)
|
||||
a = Operand('a', IntTo)
|
||||
|
||||
@@ -10,8 +10,10 @@ try:
|
||||
if TYPE_CHECKING:
|
||||
from .ast import Expr, Apply # noqa
|
||||
from .typevar import TypeVar # noqa
|
||||
from .ti import TypeConstraint # noqa
|
||||
# List of operands for ins/outs:
|
||||
OpList = Union[Sequence[Operand], Operand]
|
||||
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
|
||||
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -80,6 +82,7 @@ class Instruction(object):
|
||||
operands and other operand kinds.
|
||||
:param outs: Tuple of output operands. The output operands must be SSA
|
||||
values or `variable_args`.
|
||||
:param constraints: Tuple of instruction-specific TypeConstraints.
|
||||
:param is_terminator: This is a terminator instruction.
|
||||
:param is_branch: This is a branch instruction.
|
||||
:param is_call: This is a call instruction.
|
||||
@@ -102,13 +105,14 @@ class Instruction(object):
|
||||
'can_trap': 'Can this instruction cause a trap?',
|
||||
}
|
||||
|
||||
def __init__(self, name, doc, ins=(), outs=(), **kwargs):
|
||||
# type: (str, str, OpList, OpList, **Any) -> None # noqa
|
||||
def __init__(self, name, doc, ins=(), outs=(), constraints=(), **kwargs):
|
||||
# type: (str, str, OpList, OpList, ConstrList, **Any) -> None
|
||||
self.name = name
|
||||
self.camel_name = camel_case(name)
|
||||
self.__doc__ = doc
|
||||
self.ins = self._to_operand_tuple(ins)
|
||||
self.outs = self._to_operand_tuple(outs)
|
||||
self.constraints = self._to_constraint_tuple(constraints)
|
||||
self.format = InstructionFormat.lookup(self.ins, self.outs)
|
||||
|
||||
# Opcode number, assigned by gen_instr.py.
|
||||
@@ -268,6 +272,23 @@ class Instruction(object):
|
||||
assert isinstance(op, Operand)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _to_constraint_tuple(x):
|
||||
# type: (ConstrList) -> Tuple[TypeConstraint, ...]
|
||||
"""
|
||||
Allow a single TypeConstraint instance instead of the awkward singleton
|
||||
tuple syntax.
|
||||
"""
|
||||
# import placed here to avoid circular dependency
|
||||
from .ti import TypeConstraint # noqa
|
||||
if isinstance(x, TypeConstraint):
|
||||
x = (x,)
|
||||
else:
|
||||
x = tuple(x)
|
||||
for op in x:
|
||||
assert isinstance(op, TypeConstraint)
|
||||
return x
|
||||
|
||||
def bind(self, *args):
|
||||
# type: (*ValueType) -> BoundInstruction
|
||||
"""
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import absolute_import
|
||||
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
|
||||
b1, icmp, iadd_cout, iadd_cin, uextend, ireduce
|
||||
b1, icmp, iadd_cout, iadd_cin, uextend, sextend, ireduce, fpromote, \
|
||||
fdemote
|
||||
from base.legalize import narrow, expand
|
||||
from base.immediates import intcc
|
||||
from base.types import i32, i8
|
||||
from .typevar import TypeVar
|
||||
from .ast import Var, Def
|
||||
from .xform import Rtl, XForm
|
||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
|
||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env, TypesEqual, WiderOrEq
|
||||
from unittest import TestCase
|
||||
from functools import reduce
|
||||
|
||||
@@ -52,9 +53,10 @@ def agree(me, other):
|
||||
|
||||
# Translate our constraints using m, and sort
|
||||
me_equiv_constr = sorted([constr.translate(m)
|
||||
for constr in me.constraints])
|
||||
for constr in me.constraints], key=repr)
|
||||
# Sort other's constraints
|
||||
other_equiv_constr = sorted(other.constraints)
|
||||
other_equiv_constr = sorted([constr.translate(other)
|
||||
for constr in other.constraints], key=repr)
|
||||
return me_equiv_constr == other_equiv_constr
|
||||
|
||||
|
||||
@@ -78,7 +80,7 @@ def check_typing(got_or_err, expected, symtab=None):
|
||||
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
|
||||
# Rewrite the TVs in the input constraints to their XForm internal
|
||||
# versions
|
||||
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c]
|
||||
c = [constr.translate(subst_m) for constr in c]
|
||||
else:
|
||||
# If no symtab, just convert m from Var->TypeVar map to a
|
||||
# TypeVar->TypeVar map
|
||||
@@ -209,7 +211,7 @@ class TestRTL(TypeCheckingBaseTest):
|
||||
self.v3: txn,
|
||||
self.v4: txn,
|
||||
self.v5: txn,
|
||||
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
|
||||
}, [TypesEqual(ixn.as_bool(), txn.as_bool())]))
|
||||
|
||||
def test_vselect_vsplits(self):
|
||||
# type: () -> None
|
||||
@@ -319,6 +321,90 @@ class TestRTL(TypeCheckingBaseTest):
|
||||
"Error: empty type created when unifying " +
|
||||
"`typeof_v4` and `typeof_v5`")
|
||||
|
||||
def test_extend_reduce(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v1 << uextend(self.v0),
|
||||
self.v2 << ireduce(self.v1),
|
||||
self.v3 << sextend(self.v2),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
typing = typing.extract()
|
||||
|
||||
itype0 = TypeVar("t", "", ints=True, simd=(1, 256))
|
||||
itype1 = TypeVar("t1", "", ints=True, simd=(1, 256))
|
||||
itype2 = TypeVar("t2", "", ints=True, simd=(1, 256))
|
||||
itype3 = TypeVar("t3", "", ints=True, simd=(1, 256))
|
||||
|
||||
check_typing(typing, ({
|
||||
self.v0: itype0,
|
||||
self.v1: itype1,
|
||||
self.v2: itype2,
|
||||
self.v3: itype3,
|
||||
}, [WiderOrEq(itype1, itype0),
|
||||
WiderOrEq(itype1, itype2),
|
||||
WiderOrEq(itype3, itype2)]))
|
||||
|
||||
def test_extend_reduce_enumeration(self):
|
||||
# type: () -> None
|
||||
for op in (uextend, sextend, ireduce):
|
||||
r = Rtl(
|
||||
self.v1 << op(self.v0),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti).extract()
|
||||
|
||||
# The number of possible typings is 9 * (3+ 2*2 + 3) = 90
|
||||
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
|
||||
assert (len(l) == len(set(l)) and len(l) == 90)
|
||||
for (tv0, tv1) in l:
|
||||
typ0, typ1 = (tv0.singleton_type(), tv1.singleton_type())
|
||||
if (op == ireduce):
|
||||
assert typ0.wider_or_equal(typ1)
|
||||
else:
|
||||
assert typ1.wider_or_equal(typ0)
|
||||
|
||||
def test_fpromote_fdemote(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v1 << fpromote(self.v0),
|
||||
self.v2 << fdemote(self.v1),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
typing = typing.extract()
|
||||
|
||||
ftype0 = TypeVar("t", "", floats=True, simd=(1, 256))
|
||||
ftype1 = TypeVar("t1", "", floats=True, simd=(1, 256))
|
||||
ftype2 = TypeVar("t2", "", floats=True, simd=(1, 256))
|
||||
|
||||
check_typing(typing, ({
|
||||
self.v0: ftype0,
|
||||
self.v1: ftype1,
|
||||
self.v2: ftype2,
|
||||
}, [WiderOrEq(ftype1, ftype0),
|
||||
WiderOrEq(ftype1, ftype2)]))
|
||||
|
||||
def test_fpromote_fdemote_enumeration(self):
|
||||
# type: () -> None
|
||||
for op in (fpromote, fdemote):
|
||||
r = Rtl(
|
||||
self.v1 << op(self.v0),
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti).extract()
|
||||
|
||||
# The number of possible typings is 9*(2 + 1) = 27
|
||||
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
|
||||
assert (len(l) == len(set(l)) and len(l) == 27)
|
||||
for (tv0, tv1) in l:
|
||||
(typ0, typ1) = (tv0.singleton_type(), tv1.singleton_type())
|
||||
if (op == fdemote):
|
||||
assert typ0.wider_or_equal(typ1)
|
||||
else:
|
||||
assert typ1.wider_or_equal(typ0)
|
||||
|
||||
|
||||
class TestXForm(TypeCheckingBaseTest):
|
||||
def test_iadd_cout(self):
|
||||
@@ -453,7 +539,7 @@ class TestXForm(TypeCheckingBaseTest):
|
||||
self.v3: i32t,
|
||||
self.v4: i32t,
|
||||
self.v5: i32t,
|
||||
}, []), x.symtab)
|
||||
}, [WiderOrEq(i32t, itype)]), x.symtab)
|
||||
|
||||
def test_bound_inst_inference1(self):
|
||||
# Second example taken from issue #26
|
||||
@@ -477,7 +563,7 @@ class TestXForm(TypeCheckingBaseTest):
|
||||
self.v3: i32t,
|
||||
self.v4: i32t,
|
||||
self.v5: i32t,
|
||||
}, []), x.symtab)
|
||||
}, [WiderOrEq(i32t, itype)]), x.symtab)
|
||||
|
||||
def test_fully_bound_inst_inference(self):
|
||||
# Second example taken from issue #26 with complete bounds
|
||||
@@ -494,6 +580,7 @@ class TestXForm(TypeCheckingBaseTest):
|
||||
i8t = TypeVar.singleton(i8)
|
||||
i32t = TypeVar.singleton(i32)
|
||||
|
||||
# Note no constraints here since they are all trivial
|
||||
check_typing(x.ti, ({
|
||||
self.v0: i8t,
|
||||
self.v1: i8t,
|
||||
|
||||
@@ -8,7 +8,7 @@ from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable, List # noqa
|
||||
from typing import Iterable, List, Any # noqa
|
||||
from typing import cast
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
@@ -25,9 +25,72 @@ class TypeConstraint(object):
|
||||
"""
|
||||
Base class for all runtime-emittable type constraints.
|
||||
"""
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map or
|
||||
TypeEnv m
|
||||
"""
|
||||
def translate_one(a):
|
||||
# type: (Any) -> Any
|
||||
if (isinstance(a, TypeVar)):
|
||||
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
|
||||
return a
|
||||
|
||||
res = None # type: TypeConstraint
|
||||
res = self.__class__(*tuple(map(translate_one, self._args())))
|
||||
return res
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, self.__class__)):
|
||||
return False
|
||||
|
||||
assert isinstance(other, TypeConstraint) # help MyPy figure out other
|
||||
return self._args() == other._args()
|
||||
|
||||
def is_concrete(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true iff all typevars in the constraint are singletons.
|
||||
"""
|
||||
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
|
||||
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash(self._args())
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
"""
|
||||
Return a tuple with the exact arguments passed to __init__ to create
|
||||
this object.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return (self.__class__.__name__ + '(' +
|
||||
', '.join(map(str, self._args())) + ')')
|
||||
|
||||
|
||||
class ConstrainTVsEqual(TypeConstraint):
|
||||
class TypesEqual(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that two derived type vars must have the same runtime
|
||||
type.
|
||||
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
|
||||
assert tv1.is_derived and tv2.is_derived
|
||||
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
return self.tv1 == self.tv2 or \
|
||||
(self.tv1.singleton_type() is not None and
|
||||
self.tv2.singleton_type() is not None)
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map m
|
||||
"""
|
||||
if isinstance(m, TypeEnv):
|
||||
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
|
||||
else:
|
||||
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, ConstrainTVsEqual)):
|
||||
return False
|
||||
|
||||
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash((self.tv1, self.tv2))
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
return self.tv1 == self.tv2 or self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert self.tv1.singleton_type() is not None and \
|
||||
self.tv2.singleton_type() is not None
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||
|
||||
|
||||
class ConstrainTVInTypeset(TypeConstraint):
|
||||
class InTypeset(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var must belong to some typeset.
|
||||
"""
|
||||
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
|
||||
self.tv = tv
|
||||
self.ts = ts
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv, self.ts)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
tv_ts = self.tv.get_typeset().copy()
|
||||
|
||||
# Trivially True
|
||||
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
|
||||
if (tv_ts.size() == 0):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map m
|
||||
"""
|
||||
if isinstance(m, TypeEnv):
|
||||
return ConstrainTVInTypeset(m[self.tv], self.ts)
|
||||
else:
|
||||
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, ConstrainTVInTypeset)):
|
||||
return False
|
||||
|
||||
return (self.tv, self.ts) == (other.tv, other.ts)
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash((self.tv, self.ts))
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert self.tv.singleton_type() is not None
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
return self.tv.get_typeset().issubset(self.ts)
|
||||
|
||||
|
||||
class WiderOrEq(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var tv1 must be wider than or equal to
|
||||
type var tv2 at runtime. This requires that:
|
||||
1) They have the same number of lanes
|
||||
2) In a lane tv1 has at least as many bits as tv2.
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
self.tv1 = tv1
|
||||
self.tv2 = tv2
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
# Trivially true
|
||||
if (self.tv1 == self.tv2):
|
||||
return True
|
||||
|
||||
ts1 = self.tv1.get_typeset()
|
||||
ts2 = self.tv2.get_typeset()
|
||||
|
||||
def set_wider_or_equal(s1, s2):
|
||||
# type: (Set[int], Set[int]) -> bool
|
||||
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
|
||||
|
||||
# Trivially True
|
||||
if set_wider_or_equal(ts1.ints, ts2.ints) and\
|
||||
set_wider_or_equal(ts1.floats, ts2.floats) and\
|
||||
set_wider_or_equal(ts1.bools, ts2.bools):
|
||||
return True
|
||||
|
||||
def set_narrower(s1, s2):
|
||||
# type: (Set[int], Set[int]) -> bool
|
||||
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
|
||||
|
||||
# Trivially False
|
||||
if set_narrower(ts1.ints, ts2.ints) and\
|
||||
set_narrower(ts1.floats, ts2.floats) and\
|
||||
set_narrower(ts1.bools, ts2.bools):
|
||||
return True
|
||||
|
||||
# Trivially False
|
||||
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
|
||||
return True
|
||||
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
typ1 = self.tv1.singleton_type()
|
||||
typ2 = self.tv2.singleton_type()
|
||||
|
||||
return typ1.wider_or_equal(typ2)
|
||||
|
||||
|
||||
class TypeEnv(object):
|
||||
"""
|
||||
Class encapsulating the neccessary book keeping for type inference.
|
||||
@@ -204,12 +285,11 @@ class TypeEnv(object):
|
||||
|
||||
self.type_map[tv1] = tv2
|
||||
|
||||
def add_constraint(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
def add_constraint(self, constr):
|
||||
# type: (TypeConstraint) -> None
|
||||
"""
|
||||
Add a new equivalence constraint between tv1 and tv2
|
||||
Add a new constraint
|
||||
"""
|
||||
constr = ConstrainTVsEqual(tv1, tv2)
|
||||
if (constr not in self.constraints):
|
||||
self.constraints.append(constr)
|
||||
|
||||
@@ -261,6 +341,7 @@ class TypeEnv(object):
|
||||
Get the free typevars in the current type env.
|
||||
"""
|
||||
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
||||
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
|
||||
# Filter out None here due to singleton type vars
|
||||
return sorted(filter(lambda x: x is not None, tvs),
|
||||
key=lambda x: x.name)
|
||||
@@ -326,17 +407,18 @@ class TypeEnv(object):
|
||||
|
||||
new_constraints = [] # type: List[TypeConstraint]
|
||||
for constr in self.constraints:
|
||||
# Currently typeinference only generates ConstrainTVsEqual
|
||||
# constraints
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
constr = constr.translate(self)
|
||||
|
||||
if constr.is_trivial() or constr in new_constraints:
|
||||
continue
|
||||
|
||||
# Sanity: translated constraints should refer to only real vars
|
||||
assert constr.tv1.free_typevar() in vars_tvs and\
|
||||
constr.tv2.free_typevar() in vars_tvs
|
||||
for arg in constr._args():
|
||||
if (not isinstance(arg, TypeVar)):
|
||||
continue
|
||||
|
||||
arg_free_tv = arg.free_typevar()
|
||||
assert arg_free_tv is None or arg_free_tv in vars_tvs
|
||||
|
||||
new_constraints.append(constr)
|
||||
|
||||
@@ -372,9 +454,6 @@ class TypeEnv(object):
|
||||
# Check if constraints are satisfied for this typing
|
||||
failed = None
|
||||
for constr in self.constraints:
|
||||
# Currently typeinference only generates ConstrainTVsEqual
|
||||
# constraints
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
concrete_constr = constr.translate(m)
|
||||
if not concrete_constr.eval():
|
||||
failed = concrete_constr
|
||||
@@ -401,22 +480,27 @@ class TypeEnv(object):
|
||||
# Add all registered TVs (as some of them may be singleton nodes not
|
||||
# appearing in the graph
|
||||
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
|
||||
|
||||
for (k, v) in self.type_map.items():
|
||||
# Add all intermediate TVs appearing in edges
|
||||
nodes.add(k)
|
||||
nodes.add(v)
|
||||
edges.add((k, v, "dotted", None))
|
||||
edges.add((k, v, "dotted", "forward", None))
|
||||
while (v.is_derived):
|
||||
nodes.add(v.base)
|
||||
edges.add((v, v.base, "solid", v.derived_func))
|
||||
edges.add((v, v.base, "solid", "forward", v.derived_func))
|
||||
v = v.base
|
||||
|
||||
for constr in self.constraints:
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", None))
|
||||
if isinstance(constr, TypesEqual):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
|
||||
elif isinstance(constr, WiderOrEq):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
|
||||
else:
|
||||
assert False, "Can't display constraint {}".format(constr)
|
||||
|
||||
root_nodes = set([x for x in nodes
|
||||
if x not in self.type_map and not x.is_derived])
|
||||
@@ -428,17 +512,12 @@ class TypeEnv(object):
|
||||
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
|
||||
r += ";\n"
|
||||
|
||||
for (n1, n2, style, elabel) in edges:
|
||||
e = label(n1)
|
||||
if style == "dashed":
|
||||
e += '--'
|
||||
else:
|
||||
e += '->'
|
||||
e += label(n2)
|
||||
e += "[style={}".format(style)
|
||||
for (n1, n2, style, direction, elabel) in edges:
|
||||
e = label(n1) + "->" + label(n2)
|
||||
e += "[style={},dir={}".format(style, direction)
|
||||
|
||||
if elabel is not None:
|
||||
e += ",label={}".format(elabel)
|
||||
e += ",label=\"{}\"".format(elabel)
|
||||
e += "];\n"
|
||||
|
||||
r += e
|
||||
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
|
||||
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
||||
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
||||
|
||||
typ.add_constraint(tv1, tv2)
|
||||
typ.add_constraint(TypesEqual(tv1, tv2))
|
||||
return typ
|
||||
|
||||
|
||||
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
# Add any instruction specific constraints
|
||||
for constr in inst.constraints:
|
||||
typ.add_constraint(constr.translate(m))
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
|
||||
@@ -49,6 +49,26 @@ class ValueType(object):
|
||||
else:
|
||||
raise AttributeError("No type named '{}'".format(name))
|
||||
|
||||
def lane_bits(self):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
assert False, "Abstract"
|
||||
|
||||
def lane_count(self):
|
||||
# type: () -> int
|
||||
"""Return the number of lanes."""
|
||||
assert False, "Abstract"
|
||||
|
||||
def wider_or_equal(self, other):
|
||||
# type: (ValueType) -> bool
|
||||
"""
|
||||
Return true iff:
|
||||
1. self and other have equal number of lanes
|
||||
2. each lane in self has at least as many bits as a lane in other
|
||||
"""
|
||||
return (self.lane_count() == other.lane_count() and
|
||||
self.lane_bits() >= other.lane_bits())
|
||||
|
||||
|
||||
class ScalarType(ValueType):
|
||||
"""
|
||||
@@ -85,6 +105,11 @@ class ScalarType(ValueType):
|
||||
self._vectors[lanes] = v
|
||||
return v
|
||||
|
||||
def lane_count(self):
|
||||
# type: () -> int
|
||||
"""Return the number of lanes."""
|
||||
return 1
|
||||
|
||||
|
||||
class VectorType(ValueType):
|
||||
"""
|
||||
@@ -112,6 +137,16 @@ class VectorType(ValueType):
|
||||
return ('VectorType(base={}, lanes={})'
|
||||
.format(self.base.name, self.lanes))
|
||||
|
||||
def lane_count(self):
|
||||
# type: () -> int
|
||||
"""Return the number of lanes."""
|
||||
return self.lanes
|
||||
|
||||
def lane_bits(self):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
return self.base.lane_bits()
|
||||
|
||||
|
||||
class IntType(ScalarType):
|
||||
"""A concrete scalar integer type."""
|
||||
@@ -138,6 +173,11 @@ class IntType(ScalarType):
|
||||
else:
|
||||
return typ
|
||||
|
||||
def lane_bits(self):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
return self.bits
|
||||
|
||||
|
||||
class FloatType(ScalarType):
|
||||
"""A concrete scalar floating point type."""
|
||||
@@ -164,6 +204,11 @@ class FloatType(ScalarType):
|
||||
else:
|
||||
return typ
|
||||
|
||||
def lane_bits(self):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
return self.bits
|
||||
|
||||
|
||||
class BoolType(ScalarType):
|
||||
"""A concrete scalar boolean type."""
|
||||
@@ -189,3 +234,8 @@ class BoolType(ScalarType):
|
||||
return cast(BoolType, typ)
|
||||
else:
|
||||
return typ
|
||||
|
||||
def lane_bits(self):
|
||||
# type: () -> int
|
||||
"""Return the number of bits in a lane."""
|
||||
return self.bits
|
||||
|
||||
@@ -11,8 +11,8 @@ from __future__ import absolute_import
|
||||
from srcgen import Formatter
|
||||
from base import legalize, instructions
|
||||
from cdsl.ast import Var
|
||||
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
|
||||
ConstrainTVInTypeset
|
||||
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\
|
||||
InTypeset, WiderOrEq
|
||||
from unique_table import UniqueTable
|
||||
from gen_instr import gen_typesets_table
|
||||
from cdsl.typevar import TypeVar
|
||||
@@ -66,7 +66,7 @@ def get_runtime_typechecks(xform):
|
||||
|
||||
assert xform_ts.issubset(src_ts)
|
||||
if src_ts != xform_ts:
|
||||
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
|
||||
check_l.append(InTypeset(xform.ti[v], xform_ts))
|
||||
|
||||
# 2,3) Add any constraints that appear in xform.ti
|
||||
check_l.extend(xform.ti.constraints)
|
||||
@@ -81,6 +81,14 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
||||
"""
|
||||
def build_derived_expr(tv):
|
||||
# type: (TypeVar) -> str
|
||||
"""
|
||||
Build an expression of type Option<Type> corresponding to a concrete
|
||||
type transformed by the sequence of derivation functions in tv.
|
||||
|
||||
We are using Option<Type>, as some constraints may cause an
|
||||
over/underflow on patterns that do not match them. We want to capture
|
||||
this without panicking at runtime.
|
||||
"""
|
||||
if not tv.is_derived:
|
||||
assert tv.name.startswith('typeof_')
|
||||
return "Some({})".format(tv.name)
|
||||
@@ -102,7 +110,8 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
||||
else:
|
||||
assert False, "Unknown derived function {}".format(tv.derived_func)
|
||||
|
||||
if (isinstance(check, ConstrainTVInTypeset)):
|
||||
if (isinstance(check, InTypeset)):
|
||||
assert not check.tv.is_derived
|
||||
tv = check.tv.name
|
||||
if check.ts not in type_sets.index:
|
||||
type_sets.add(check.ts)
|
||||
@@ -112,11 +121,28 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
||||
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
|
||||
'};'):
|
||||
fmt.line('return false;')
|
||||
elif (isinstance(check, ConstrainTVsEqual)):
|
||||
tv1 = build_derived_expr(check.tv1)
|
||||
tv2 = build_derived_expr(check.tv2)
|
||||
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
|
||||
fmt.line('return false;')
|
||||
elif (isinstance(check, TypesEqual)):
|
||||
with fmt.indented('{', '};'):
|
||||
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
|
||||
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
|
||||
|
||||
fmt.comment('On overflow constraint doesn\'t appply')
|
||||
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
|
||||
fmt.line('return false;')
|
||||
|
||||
with fmt.indented('if a != b {', '};'):
|
||||
fmt.line('return false;')
|
||||
elif (isinstance(check, WiderOrEq)):
|
||||
with fmt.indented('{', '};'):
|
||||
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
|
||||
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
|
||||
|
||||
fmt.comment('On overflow constraint doesn\'t appply')
|
||||
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
|
||||
fmt.line('return false;')
|
||||
|
||||
with fmt.indented('if !a.wider_or_equal(b) {', '};'):
|
||||
fmt.line('return false;')
|
||||
else:
|
||||
assert False, "Unknown check {}".format(check)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest import TestCase
|
||||
from srcgen import Formatter
|
||||
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
|
||||
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
|
||||
iconst, b1, icmp, copy # noqa
|
||||
iconst, b1, icmp, copy, sextend, uextend, ireduce, fdemote, fpromote # noqa
|
||||
from base.legalize import narrow, expand # noqa
|
||||
from base.immediates import intcc # noqa
|
||||
from cdsl.typevar import TypeVar, TypeSet
|
||||
@@ -57,9 +57,32 @@ def equiv_check(tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||
return lambda typesets: format_check(
|
||||
typesets,
|
||||
'if Some({}).map(|t: Type| -> t.as_bool()) != ' +
|
||||
'Some({}).map(|t: Type| -> t.as_bool()) ' +
|
||||
'{{\n return false;\n}};\n', tv1, tv2)
|
||||
'{{\n' +
|
||||
' let a = {};\n' +
|
||||
' let b = {};\n' +
|
||||
' if a.is_none() || b.is_none() {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
' if a != b {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
'}};\n', tv1, tv2)
|
||||
|
||||
|
||||
def wider_check(tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||
return lambda typesets: format_check(
|
||||
typesets,
|
||||
'{{\n' +
|
||||
' let a = {};\n' +
|
||||
' let b = {};\n' +
|
||||
' if a.is_none() || b.is_none() {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
' if !a.wider_or_equal(b) {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
'}};\n', tv1, tv2)
|
||||
|
||||
|
||||
def sequence(*args):
|
||||
@@ -138,8 +161,49 @@ class TestRuntimeChecks(TestCase):
|
||||
self.v5 << vselect(self.v1, self.v3, self.v4),
|
||||
)
|
||||
x = XForm(r, r)
|
||||
tv2_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
|
||||
.format(self.v2.get_typevar().name)
|
||||
tv3_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
|
||||
.format(self.v3.get_typevar().name)
|
||||
|
||||
self.check_yo_check(
|
||||
x, sequence(typeset_check(self.v3, ts),
|
||||
equiv_check(self.v2.get_typevar(),
|
||||
self.v3.get_typevar())))
|
||||
equiv_check(tv2_exp, tv3_exp)))
|
||||
|
||||
def test_reduce_extend(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v1 << uextend(self.v0),
|
||||
self.v2 << ireduce(self.v1),
|
||||
self.v3 << sextend(self.v2),
|
||||
)
|
||||
x = XForm(r, r)
|
||||
|
||||
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
|
||||
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
|
||||
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
|
||||
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
|
||||
|
||||
self.check_yo_check(
|
||||
x, sequence(wider_check(tv1_exp, tv0_exp),
|
||||
wider_check(tv1_exp, tv2_exp),
|
||||
wider_check(tv3_exp, tv2_exp)))
|
||||
|
||||
def test_demote_promote(self):
|
||||
# type: () -> None
|
||||
r = Rtl(
|
||||
self.v1 << fpromote(self.v0),
|
||||
self.v2 << fdemote(self.v1),
|
||||
self.v3 << fpromote(self.v2),
|
||||
)
|
||||
x = XForm(r, r)
|
||||
|
||||
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
|
||||
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
|
||||
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
|
||||
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
|
||||
|
||||
self.check_yo_check(
|
||||
x, sequence(wider_check(tv1_exp, tv0_exp),
|
||||
wider_check(tv1_exp, tv2_exp),
|
||||
wider_check(tv3_exp, tv2_exp)))
|
||||
|
||||
@@ -236,6 +236,13 @@ impl Type {
|
||||
pub fn index(self) -> usize {
|
||||
self.0 as usize
|
||||
}
|
||||
|
||||
/// True iff:
|
||||
/// 1) self.lane_count() == other.lane_count() and
|
||||
/// 2) self.lane_bits() >= other.lane_bits()
|
||||
pub fn wider_or_equal(self, other: Type) -> bool {
|
||||
self.lane_count() == other.lane_count() && self.lane_bits() >= other.lane_bits()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Type {
|
||||
|
||||
Reference in New Issue
Block a user