Add fix for #114 (#115)

* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow

* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
d1m0
2017-07-12 08:51:55 -07:00
committed by Jakob Stoklund Olesen
parent de5501bc47
commit a9147ebd30
8 changed files with 471 additions and 132 deletions

View File

@@ -12,6 +12,7 @@ from base.types import i8, f32, f64, b1
from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32
from base.immediates import intcc, floatcc, memflags, regunit
from base import entities
from cdsl.ti import WiderOrEq
import base.formats # noqa
GROUP = InstructionGroup("base", "Shared base instruction set")
@@ -1405,7 +1406,7 @@ ireduce = Instruction(
and each lane must not have more bits that the input lanes. If the
input and output types are the same, this is a no-op.
""",
ins=x, outs=a)
ins=x, outs=a, constraints=WiderOrEq(Int, IntTo))
IntTo = TypeVar(
@@ -1427,7 +1428,7 @@ uextend = Instruction(
and each lane must not have fewer bits that the input lanes. If the
input and output types are the same, this is a no-op.
""",
ins=x, outs=a)
ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
sextend = Instruction(
'sextend', r"""
@@ -1441,7 +1442,7 @@ sextend = Instruction(
and each lane must not have fewer bits that the input lanes. If the
input and output types are the same, this is a no-op.
""",
ins=x, outs=a)
ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
FloatTo = TypeVar(
'FloatTo', 'A scalar or vector floating point number',
@@ -1457,14 +1458,14 @@ fpromote = Instruction(
Each lane in `x` is converted to the destination floating point format.
This is an exact operation.
Since Cretonne currently only supports two floating point formats, this
instruction always converts :type:`f32` to :type:`f64`. This may change
in the future.
Cretonne currently only supports two floating point formats
- :type:`f32` and :type:`f64`. This may change in the future.
The result type must have the same number of vector lanes as the input,
and the result lanes must be larger than the input lanes.
and the result lanes must not have fewer bits than the input lanes. If
the input and output types are the same, this is a no-op.
""",
ins=x, outs=a)
ins=x, outs=a, constraints=WiderOrEq(FloatTo, Float))
fdemote = Instruction(
'fdemote', r"""
@@ -1473,14 +1474,14 @@ fdemote = Instruction(
Each lane in `x` is converted to the destination floating point format
by rounding to nearest, ties to even.
Since Cretonne currently only supports two floating point formats, this
instruction always converts :type:`f64` to :type:`f32`. This may change
in the future.
Cretonne currently only supports two floating point formats
- :type:`f32` and :type:`f64`. This may change in the future.
The result type must have the same number of vector lanes as the input,
and the result lanes must be smaller than the input lanes.
and the result lanes must not have more bits than the input lanes. If
the input and output types are the same, this is a no-op.
""",
ins=x, outs=a)
ins=x, outs=a, constraints=WiderOrEq(Float, FloatTo))
x = Operand('x', Float)
a = Operand('a', IntTo)

View File

@@ -10,8 +10,10 @@ try:
if TYPE_CHECKING:
from .ast import Expr, Apply # noqa
from .typevar import TypeVar # noqa
from .ti import TypeConstraint # noqa
# List of operands for ins/outs:
OpList = Union[Sequence[Operand], Operand]
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
except ImportError:
pass
@@ -80,6 +82,7 @@ class Instruction(object):
operands and other operand kinds.
:param outs: Tuple of output operands. The output operands must be SSA
values or `variable_args`.
:param constraints: Tuple of instruction-specific TypeConstraints.
:param is_terminator: This is a terminator instruction.
:param is_branch: This is a branch instruction.
:param is_call: This is a call instruction.
@@ -102,13 +105,14 @@ class Instruction(object):
'can_trap': 'Can this instruction cause a trap?',
}
def __init__(self, name, doc, ins=(), outs=(), **kwargs):
# type: (str, str, OpList, OpList, **Any) -> None # noqa
def __init__(self, name, doc, ins=(), outs=(), constraints=(), **kwargs):
# type: (str, str, OpList, OpList, ConstrList, **Any) -> None
self.name = name
self.camel_name = camel_case(name)
self.__doc__ = doc
self.ins = self._to_operand_tuple(ins)
self.outs = self._to_operand_tuple(outs)
self.constraints = self._to_constraint_tuple(constraints)
self.format = InstructionFormat.lookup(self.ins, self.outs)
# Opcode number, assigned by gen_instr.py.
@@ -268,6 +272,23 @@ class Instruction(object):
assert isinstance(op, Operand)
return x
@staticmethod
def _to_constraint_tuple(x):
# type: (ConstrList) -> Tuple[TypeConstraint, ...]
"""
Allow a single TypeConstraint instance instead of the awkward singleton
tuple syntax.
"""
# import placed here to avoid circular dependency
from .ti import TypeConstraint # noqa
if isinstance(x, TypeConstraint):
x = (x,)
else:
x = tuple(x)
for op in x:
assert isinstance(op, TypeConstraint)
return x
def bind(self, *args):
# type: (*ValueType) -> BoundInstruction
"""

View File

@@ -1,13 +1,14 @@
from __future__ import absolute_import
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
b1, icmp, iadd_cout, iadd_cin, uextend, ireduce
b1, icmp, iadd_cout, iadd_cin, uextend, sextend, ireduce, fpromote, \
fdemote
from base.legalize import narrow, expand
from base.immediates import intcc
from base.types import i32, i8
from .typevar import TypeVar
from .ast import Var, Def
from .xform import Rtl, XForm
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
from .ti import ti_rtl, subst, TypeEnv, get_type_env, TypesEqual, WiderOrEq
from unittest import TestCase
from functools import reduce
@@ -52,9 +53,10 @@ def agree(me, other):
# Translate our constraints using m, and sort
me_equiv_constr = sorted([constr.translate(m)
for constr in me.constraints])
for constr in me.constraints], key=repr)
# Sort other's constraints
other_equiv_constr = sorted(other.constraints)
other_equiv_constr = sorted([constr.translate(other)
for constr in other.constraints], key=repr)
return me_equiv_constr == other_equiv_constr
@@ -78,7 +80,7 @@ def check_typing(got_or_err, expected, symtab=None):
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
# Rewrite the TVs in the input constraints to their XForm internal
# versions
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c]
c = [constr.translate(subst_m) for constr in c]
else:
# If no symtab, just convert m from Var->TypeVar map to a
# TypeVar->TypeVar map
@@ -209,7 +211,7 @@ class TestRTL(TypeCheckingBaseTest):
self.v3: txn,
self.v4: txn,
self.v5: txn,
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
}, [TypesEqual(ixn.as_bool(), txn.as_bool())]))
def test_vselect_vsplits(self):
# type: () -> None
@@ -319,6 +321,90 @@ class TestRTL(TypeCheckingBaseTest):
"Error: empty type created when unifying " +
"`typeof_v4` and `typeof_v5`")
def test_extend_reduce(self):
# type: () -> None
r = Rtl(
self.v1 << uextend(self.v0),
self.v2 << ireduce(self.v1),
self.v3 << sextend(self.v2),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
typing = typing.extract()
itype0 = TypeVar("t", "", ints=True, simd=(1, 256))
itype1 = TypeVar("t1", "", ints=True, simd=(1, 256))
itype2 = TypeVar("t2", "", ints=True, simd=(1, 256))
itype3 = TypeVar("t3", "", ints=True, simd=(1, 256))
check_typing(typing, ({
self.v0: itype0,
self.v1: itype1,
self.v2: itype2,
self.v3: itype3,
}, [WiderOrEq(itype1, itype0),
WiderOrEq(itype1, itype2),
WiderOrEq(itype3, itype2)]))
def test_extend_reduce_enumeration(self):
# type: () -> None
for op in (uextend, sextend, ireduce):
r = Rtl(
self.v1 << op(self.v0),
)
ti = TypeEnv()
typing = ti_rtl(r, ti).extract()
# The number of possible typings is 9 * (3+ 2*2 + 3) = 90
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
assert (len(l) == len(set(l)) and len(l) == 90)
for (tv0, tv1) in l:
typ0, typ1 = (tv0.singleton_type(), tv1.singleton_type())
if (op == ireduce):
assert typ0.wider_or_equal(typ1)
else:
assert typ1.wider_or_equal(typ0)
def test_fpromote_fdemote(self):
# type: () -> None
r = Rtl(
self.v1 << fpromote(self.v0),
self.v2 << fdemote(self.v1),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
typing = typing.extract()
ftype0 = TypeVar("t", "", floats=True, simd=(1, 256))
ftype1 = TypeVar("t1", "", floats=True, simd=(1, 256))
ftype2 = TypeVar("t2", "", floats=True, simd=(1, 256))
check_typing(typing, ({
self.v0: ftype0,
self.v1: ftype1,
self.v2: ftype2,
}, [WiderOrEq(ftype1, ftype0),
WiderOrEq(ftype1, ftype2)]))
def test_fpromote_fdemote_enumeration(self):
# type: () -> None
for op in (fpromote, fdemote):
r = Rtl(
self.v1 << op(self.v0),
)
ti = TypeEnv()
typing = ti_rtl(r, ti).extract()
# The number of possible typings is 9*(2 + 1) = 27
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
assert (len(l) == len(set(l)) and len(l) == 27)
for (tv0, tv1) in l:
(typ0, typ1) = (tv0.singleton_type(), tv1.singleton_type())
if (op == fdemote):
assert typ0.wider_or_equal(typ1)
else:
assert typ1.wider_or_equal(typ0)
class TestXForm(TypeCheckingBaseTest):
def test_iadd_cout(self):
@@ -453,7 +539,7 @@ class TestXForm(TypeCheckingBaseTest):
self.v3: i32t,
self.v4: i32t,
self.v5: i32t,
}, []), x.symtab)
}, [WiderOrEq(i32t, itype)]), x.symtab)
def test_bound_inst_inference1(self):
# Second example taken from issue #26
@@ -477,7 +563,7 @@ class TestXForm(TypeCheckingBaseTest):
self.v3: i32t,
self.v4: i32t,
self.v5: i32t,
}, []), x.symtab)
}, [WiderOrEq(i32t, itype)]), x.symtab)
def test_fully_bound_inst_inference(self):
# Second example taken from issue #26 with complete bounds
@@ -494,6 +580,7 @@ class TestXForm(TypeCheckingBaseTest):
i8t = TypeVar.singleton(i8)
i32t = TypeVar.singleton(i32)
# Note no constraints here since they are all trivial
check_typing(x.ti, ({
self.v0: i8t,
self.v1: i8t,

View File

@@ -8,7 +8,7 @@ from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable, List # noqa
from typing import Iterable, List, Any # noqa
from typing import cast
from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa
@@ -25,9 +25,72 @@ class TypeConstraint(object):
"""
Base class for all runtime-emittable type constraints.
"""
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
"""
Translate any TypeVars in the constraint according to the map or
TypeEnv m
"""
def translate_one(a):
# type: (Any) -> Any
if (isinstance(a, TypeVar)):
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
return a
res = None # type: TypeConstraint
res = self.__class__(*tuple(map(translate_one, self._args())))
return res
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, self.__class__)):
return False
assert isinstance(other, TypeConstraint) # help MyPy figure out other
return self._args() == other._args()
def is_concrete(self):
# type: () -> bool
"""
Return true iff all typevars in the constraint are singletons.
"""
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
def __hash__(self):
# type: () -> int
return hash(self._args())
def _args(self):
# type: () -> Tuple[Any,...]
"""
Return a tuple with the exact arguments passed to __init__ to create
this object.
"""
assert False, "Abstract"
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
assert False, "Abstract"
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert False, "Abstract"
def __repr__(self):
# type: () -> str
return (self.__class__.__name__ + '(' +
', '.join(map(str, self._args())) + ')')
class ConstrainTVsEqual(TypeConstraint):
class TypesEqual(TypeConstraint):
"""
Constraint specifying that two derived type vars must have the same runtime
type.
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
return self.tv1 == self.tv2 or \
(self.tv1.singleton_type() is not None and
self.tv2.singleton_type() is not None)
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
else:
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVsEqual)):
return False
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
def __hash__(self):
# type: () -> int
return hash((self.tv1, self.tv2))
""" See TypeConstraint.is_trivial() """
return self.tv1 == self.tv2 or self.is_concrete()
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv1.singleton_type() is not None and \
self.tv2.singleton_type() is not None
""" See TypeConstraint.eval() """
assert self.is_concrete()
return self.tv1.singleton_type() == self.tv2.singleton_type()
class ConstrainTVInTypeset(TypeConstraint):
class InTypeset(TypeConstraint):
"""
Constraint specifying that a type var must belong to some typeset.
"""
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
self.tv = tv
self.ts = ts
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv, self.ts)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
""" See TypeConstraint.is_trivial() """
tv_ts = self.tv.get_typeset().copy()
# Trivially True
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
if (tv_ts.size() == 0):
return True
return False
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVInTypeset(m[self.tv], self.ts)
else:
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVInTypeset)):
return False
return (self.tv, self.ts) == (other.tv, other.ts)
def __hash__(self):
# type: () -> int
return hash((self.tv, self.ts))
return self.is_concrete()
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv.singleton_type() is not None
""" See TypeConstraint.eval() """
assert self.is_concrete()
return self.tv.get_typeset().issubset(self.ts)
class WiderOrEq(TypeConstraint):
"""
Constraint specifying that a type var tv1 must be wider than or equal to
type var tv2 at runtime. This requires that:
1) They have the same number of lanes
2) In a lane tv1 has at least as many bits as tv2.
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
self.tv1 = tv1
self.tv2 = tv2
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
""" See TypeConstraint.is_trivial() """
# Trivially true
if (self.tv1 == self.tv2):
return True
ts1 = self.tv1.get_typeset()
ts2 = self.tv2.get_typeset()
def set_wider_or_equal(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
# Trivially True
if set_wider_or_equal(ts1.ints, ts2.ints) and\
set_wider_or_equal(ts1.floats, ts2.floats) and\
set_wider_or_equal(ts1.bools, ts2.bools):
return True
def set_narrower(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
# Trivially False
if set_narrower(ts1.ints, ts2.ints) and\
set_narrower(ts1.floats, ts2.floats) and\
set_narrower(ts1.bools, ts2.bools):
return True
# Trivially False
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
return True
return self.is_concrete()
def eval(self):
# type: () -> bool
""" See TypeConstraint.eval() """
assert self.is_concrete()
typ1 = self.tv1.singleton_type()
typ2 = self.tv2.singleton_type()
return typ1.wider_or_equal(typ2)
class TypeEnv(object):
"""
Class encapsulating the neccessary book keeping for type inference.
@@ -204,12 +285,11 @@ class TypeEnv(object):
self.type_map[tv1] = tv2
def add_constraint(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
def add_constraint(self, constr):
# type: (TypeConstraint) -> None
"""
Add a new equivalence constraint between tv1 and tv2
Add a new constraint
"""
constr = ConstrainTVsEqual(tv1, tv2)
if (constr not in self.constraints):
self.constraints.append(constr)
@@ -261,6 +341,7 @@ class TypeEnv(object):
Get the free typevars in the current type env.
"""
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
# Filter out None here due to singleton type vars
return sorted(filter(lambda x: x is not None, tvs),
key=lambda x: x.name)
@@ -326,17 +407,18 @@ class TypeEnv(object):
new_constraints = [] # type: List[TypeConstraint]
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
constr = constr.translate(self)
if constr.is_trivial() or constr in new_constraints:
continue
# Sanity: translated constraints should refer to only real vars
assert constr.tv1.free_typevar() in vars_tvs and\
constr.tv2.free_typevar() in vars_tvs
for arg in constr._args():
if (not isinstance(arg, TypeVar)):
continue
arg_free_tv = arg.free_typevar()
assert arg_free_tv is None or arg_free_tv in vars_tvs
new_constraints.append(constr)
@@ -372,9 +454,6 @@ class TypeEnv(object):
# Check if constraints are satisfied for this typing
failed = None
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
concrete_constr = constr.translate(m)
if not concrete_constr.eval():
failed = concrete_constr
@@ -401,22 +480,27 @@ class TypeEnv(object):
# Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
for (k, v) in self.type_map.items():
# Add all intermediate TVs appearing in edges
nodes.add(k)
nodes.add(v)
edges.add((k, v, "dotted", None))
edges.add((k, v, "dotted", "forward", None))
while (v.is_derived):
nodes.add(v.base)
edges.add((v, v.base, "solid", v.derived_func))
edges.add((v, v.base, "solid", "forward", v.derived_func))
v = v.base
for constr in self.constraints:
assert isinstance(constr, ConstrainTVsEqual)
if isinstance(constr, TypesEqual):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", None))
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
elif isinstance(constr, WiderOrEq):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
else:
assert False, "Can't display constraint {}".format(constr)
root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived])
@@ -428,17 +512,12 @@ class TypeEnv(object):
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
r += ";\n"
for (n1, n2, style, elabel) in edges:
e = label(n1)
if style == "dashed":
e += '--'
else:
e += '->'
e += label(n2)
e += "[style={}".format(style)
for (n1, n2, style, direction, elabel) in edges:
e = label(n1) + "->" + label(n2)
e += "[style={},dir={}".format(style, direction)
if elabel is not None:
e += ",label={}".format(elabel)
e += ",label=\"{}\"".format(elabel)
e += "];\n"
r += e
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
typ.add_constraint(tv1, tv2)
typ.add_constraint(TypesEqual(tv1, tv2))
return typ
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
typ = get_type_env(typ_or_err)
# Add any instruction specific constraints
for constr in inst.constraints:
typ.add_constraint(constr.translate(m))
return typ

View File

@@ -49,6 +49,26 @@ class ValueType(object):
else:
raise AttributeError("No type named '{}'".format(name))
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
assert False, "Abstract"
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
assert False, "Abstract"
def wider_or_equal(self, other):
# type: (ValueType) -> bool
"""
Return true iff:
1. self and other have equal number of lanes
2. each lane in self has at least as many bits as a lane in other
"""
return (self.lane_count() == other.lane_count() and
self.lane_bits() >= other.lane_bits())
class ScalarType(ValueType):
"""
@@ -85,6 +105,11 @@ class ScalarType(ValueType):
self._vectors[lanes] = v
return v
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
return 1
class VectorType(ValueType):
"""
@@ -112,6 +137,16 @@ class VectorType(ValueType):
return ('VectorType(base={}, lanes={})'
.format(self.base.name, self.lanes))
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
return self.lanes
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.base.lane_bits()
class IntType(ScalarType):
"""A concrete scalar integer type."""
@@ -138,6 +173,11 @@ class IntType(ScalarType):
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits
class FloatType(ScalarType):
"""A concrete scalar floating point type."""
@@ -164,6 +204,11 @@ class FloatType(ScalarType):
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits
class BoolType(ScalarType):
"""A concrete scalar boolean type."""
@@ -189,3 +234,8 @@ class BoolType(ScalarType):
return cast(BoolType, typ)
else:
return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits

View File

@@ -11,8 +11,8 @@ from __future__ import absolute_import
from srcgen import Formatter
from base import legalize, instructions
from cdsl.ast import Var
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
ConstrainTVInTypeset
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\
InTypeset, WiderOrEq
from unique_table import UniqueTable
from gen_instr import gen_typesets_table
from cdsl.typevar import TypeVar
@@ -66,7 +66,7 @@ def get_runtime_typechecks(xform):
assert xform_ts.issubset(src_ts)
if src_ts != xform_ts:
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
check_l.append(InTypeset(xform.ti[v], xform_ts))
# 2,3) Add any constraints that appear in xform.ti
check_l.extend(xform.ti.constraints)
@@ -81,6 +81,14 @@ def emit_runtime_typecheck(check, fmt, type_sets):
"""
def build_derived_expr(tv):
# type: (TypeVar) -> str
"""
Build an expression of type Option<Type> corresponding to a concrete
type transformed by the sequence of derivation functions in tv.
We are using Option<Type>, as some constraints may cause an
over/underflow on patterns that do not match them. We want to capture
this without panicking at runtime.
"""
if not tv.is_derived:
assert tv.name.startswith('typeof_')
return "Some({})".format(tv.name)
@@ -102,7 +110,8 @@ def emit_runtime_typecheck(check, fmt, type_sets):
else:
assert False, "Unknown derived function {}".format(tv.derived_func)
if (isinstance(check, ConstrainTVInTypeset)):
if (isinstance(check, InTypeset)):
assert not check.tv.is_derived
tv = check.tv.name
if check.ts not in type_sets.index:
type_sets.add(check.ts)
@@ -112,10 +121,27 @@ def emit_runtime_typecheck(check, fmt, type_sets):
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
'};'):
fmt.line('return false;')
elif (isinstance(check, ConstrainTVsEqual)):
tv1 = build_derived_expr(check.tv1)
tv2 = build_derived_expr(check.tv2)
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
elif (isinstance(check, TypesEqual)):
with fmt.indented('{', '};'):
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
fmt.comment('On overflow constraint doesn\'t appply')
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
fmt.line('return false;')
with fmt.indented('if a != b {', '};'):
fmt.line('return false;')
elif (isinstance(check, WiderOrEq)):
with fmt.indented('{', '};'):
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
fmt.comment('On overflow constraint doesn\'t appply')
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
fmt.line('return false;')
with fmt.indented('if !a.wider_or_equal(b) {', '};'):
fmt.line('return false;')
else:
assert False, "Unknown check {}".format(check)

View File

@@ -4,7 +4,7 @@ from unittest import TestCase
from srcgen import Formatter
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
iconst, b1, icmp, copy # noqa
iconst, b1, icmp, copy, sextend, uextend, ireduce, fdemote, fpromote # noqa
from base.legalize import narrow, expand # noqa
from base.immediates import intcc # noqa
from cdsl.typevar import TypeVar, TypeSet
@@ -57,9 +57,32 @@ def equiv_check(tv1, tv2):
# type: (TypeVar, TypeVar) -> CheckProducer
return lambda typesets: format_check(
typesets,
'if Some({}).map(|t: Type| -> t.as_bool()) != ' +
'Some({}).map(|t: Type| -> t.as_bool()) ' +
'{{\n return false;\n}};\n', tv1, tv2)
'{{\n' +
' let a = {};\n' +
' let b = {};\n' +
' if a.is_none() || b.is_none() {{\n' +
' return false;\n' +
' }};\n' +
' if a != b {{\n' +
' return false;\n' +
' }};\n' +
'}};\n', tv1, tv2)
def wider_check(tv1, tv2):
# type: (TypeVar, TypeVar) -> CheckProducer
return lambda typesets: format_check(
typesets,
'{{\n' +
' let a = {};\n' +
' let b = {};\n' +
' if a.is_none() || b.is_none() {{\n' +
' return false;\n' +
' }};\n' +
' if !a.wider_or_equal(b) {{\n' +
' return false;\n' +
' }};\n' +
'}};\n', tv1, tv2)
def sequence(*args):
@@ -138,8 +161,49 @@ class TestRuntimeChecks(TestCase):
self.v5 << vselect(self.v1, self.v3, self.v4),
)
x = XForm(r, r)
tv2_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
.format(self.v2.get_typevar().name)
tv3_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
.format(self.v3.get_typevar().name)
self.check_yo_check(
x, sequence(typeset_check(self.v3, ts),
equiv_check(self.v2.get_typevar(),
self.v3.get_typevar())))
equiv_check(tv2_exp, tv3_exp)))
def test_reduce_extend(self):
# type: () -> None
r = Rtl(
self.v1 << uextend(self.v0),
self.v2 << ireduce(self.v1),
self.v3 << sextend(self.v2),
)
x = XForm(r, r)
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
self.check_yo_check(
x, sequence(wider_check(tv1_exp, tv0_exp),
wider_check(tv1_exp, tv2_exp),
wider_check(tv3_exp, tv2_exp)))
def test_demote_promote(self):
# type: () -> None
r = Rtl(
self.v1 << fpromote(self.v0),
self.v2 << fdemote(self.v1),
self.v3 << fpromote(self.v2),
)
x = XForm(r, r)
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
self.check_yo_check(
x, sequence(wider_check(tv1_exp, tv0_exp),
wider_check(tv1_exp, tv2_exp),
wider_check(tv3_exp, tv2_exp)))

View File

@@ -236,6 +236,13 @@ impl Type {
pub fn index(self) -> usize {
self.0 as usize
}
/// True iff:
/// 1) self.lane_count() == other.lane_count() and
/// 2) self.lane_bits() >= other.lane_bits()
pub fn wider_or_equal(self, other: Type) -> bool {
self.lane_count() == other.lane_count() && self.lane_bits() >= other.lane_bits()
}
}
impl Display for Type {