Add fix for #114 (#115)

* Reduce code duplication in TypeConstraint subclasses; Add ConstrainWiderOrEqual to ti and to ireduce,{s,u}extend and f{promote,demote}; Fix bug in emitting constraint edges in TypeEnv.dot(); Modify runtime constraint checks to reject match when they encounter overflow

* Rename Constrain types to something shorter; Move lane_bits/lane_counts in subclasses of ValueType; Add wider_or_eq function in rust and python;
This commit is contained in:
d1m0
2017-07-12 08:51:55 -07:00
committed by Jakob Stoklund Olesen
parent de5501bc47
commit a9147ebd30
8 changed files with 471 additions and 132 deletions

View File

@@ -12,6 +12,7 @@ from base.types import i8, f32, f64, b1
from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32 from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32
from base.immediates import intcc, floatcc, memflags, regunit from base.immediates import intcc, floatcc, memflags, regunit
from base import entities from base import entities
from cdsl.ti import WiderOrEq
import base.formats # noqa import base.formats # noqa
GROUP = InstructionGroup("base", "Shared base instruction set") GROUP = InstructionGroup("base", "Shared base instruction set")
@@ -1405,7 +1406,7 @@ ireduce = Instruction(
and each lane must not have more bits that the input lanes. If the and each lane must not have more bits that the input lanes. If the
input and output types are the same, this is a no-op. input and output types are the same, this is a no-op.
""", """,
ins=x, outs=a) ins=x, outs=a, constraints=WiderOrEq(Int, IntTo))
IntTo = TypeVar( IntTo = TypeVar(
@@ -1427,7 +1428,7 @@ uextend = Instruction(
and each lane must not have fewer bits that the input lanes. If the and each lane must not have fewer bits that the input lanes. If the
input and output types are the same, this is a no-op. input and output types are the same, this is a no-op.
""", """,
ins=x, outs=a) ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
sextend = Instruction( sextend = Instruction(
'sextend', r""" 'sextend', r"""
@@ -1441,7 +1442,7 @@ sextend = Instruction(
and each lane must not have fewer bits that the input lanes. If the and each lane must not have fewer bits that the input lanes. If the
input and output types are the same, this is a no-op. input and output types are the same, this is a no-op.
""", """,
ins=x, outs=a) ins=x, outs=a, constraints=WiderOrEq(IntTo, Int))
FloatTo = TypeVar( FloatTo = TypeVar(
'FloatTo', 'A scalar or vector floating point number', 'FloatTo', 'A scalar or vector floating point number',
@@ -1457,14 +1458,14 @@ fpromote = Instruction(
Each lane in `x` is converted to the destination floating point format. Each lane in `x` is converted to the destination floating point format.
This is an exact operation. This is an exact operation.
Since Cretonne currently only supports two floating point formats, this Cretonne currently only supports two floating point formats
instruction always converts :type:`f32` to :type:`f64`. This may change - :type:`f32` and :type:`f64`. This may change in the future.
in the future.
The result type must have the same number of vector lanes as the input, The result type must have the same number of vector lanes as the input,
and the result lanes must be larger than the input lanes. and the result lanes must not have fewer bits than the input lanes. If
the input and output types are the same, this is a no-op.
""", """,
ins=x, outs=a) ins=x, outs=a, constraints=WiderOrEq(FloatTo, Float))
fdemote = Instruction( fdemote = Instruction(
'fdemote', r""" 'fdemote', r"""
@@ -1473,14 +1474,14 @@ fdemote = Instruction(
Each lane in `x` is converted to the destination floating point format Each lane in `x` is converted to the destination floating point format
by rounding to nearest, ties to even. by rounding to nearest, ties to even.
Since Cretonne currently only supports two floating point formats, this Cretonne currently only supports two floating point formats
instruction always converts :type:`f64` to :type:`f32`. This may change - :type:`f32` and :type:`f64`. This may change in the future.
in the future.
The result type must have the same number of vector lanes as the input, The result type must have the same number of vector lanes as the input,
and the result lanes must be smaller than the input lanes. and the result lanes must not have more bits than the input lanes. If
the input and output types are the same, this is a no-op.
""", """,
ins=x, outs=a) ins=x, outs=a, constraints=WiderOrEq(Float, FloatTo))
x = Operand('x', Float) x = Operand('x', Float)
a = Operand('a', IntTo) a = Operand('a', IntTo)

View File

@@ -10,8 +10,10 @@ try:
if TYPE_CHECKING: if TYPE_CHECKING:
from .ast import Expr, Apply # noqa from .ast import Expr, Apply # noqa
from .typevar import TypeVar # noqa from .typevar import TypeVar # noqa
from .ti import TypeConstraint # noqa
# List of operands for ins/outs: # List of operands for ins/outs:
OpList = Union[Sequence[Operand], Operand] OpList = Union[Sequence[Operand], Operand]
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
MaybeBoundInst = Union['Instruction', 'BoundInstruction'] MaybeBoundInst = Union['Instruction', 'BoundInstruction']
except ImportError: except ImportError:
pass pass
@@ -80,6 +82,7 @@ class Instruction(object):
operands and other operand kinds. operands and other operand kinds.
:param outs: Tuple of output operands. The output operands must be SSA :param outs: Tuple of output operands. The output operands must be SSA
values or `variable_args`. values or `variable_args`.
:param constraints: Tuple of instruction-specific TypeConstraints.
:param is_terminator: This is a terminator instruction. :param is_terminator: This is a terminator instruction.
:param is_branch: This is a branch instruction. :param is_branch: This is a branch instruction.
:param is_call: This is a call instruction. :param is_call: This is a call instruction.
@@ -102,13 +105,14 @@ class Instruction(object):
'can_trap': 'Can this instruction cause a trap?', 'can_trap': 'Can this instruction cause a trap?',
} }
def __init__(self, name, doc, ins=(), outs=(), **kwargs): def __init__(self, name, doc, ins=(), outs=(), constraints=(), **kwargs):
# type: (str, str, OpList, OpList, **Any) -> None # noqa # type: (str, str, OpList, OpList, ConstrList, **Any) -> None
self.name = name self.name = name
self.camel_name = camel_case(name) self.camel_name = camel_case(name)
self.__doc__ = doc self.__doc__ = doc
self.ins = self._to_operand_tuple(ins) self.ins = self._to_operand_tuple(ins)
self.outs = self._to_operand_tuple(outs) self.outs = self._to_operand_tuple(outs)
self.constraints = self._to_constraint_tuple(constraints)
self.format = InstructionFormat.lookup(self.ins, self.outs) self.format = InstructionFormat.lookup(self.ins, self.outs)
# Opcode number, assigned by gen_instr.py. # Opcode number, assigned by gen_instr.py.
@@ -268,6 +272,23 @@ class Instruction(object):
assert isinstance(op, Operand) assert isinstance(op, Operand)
return x return x
@staticmethod
def _to_constraint_tuple(x):
# type: (ConstrList) -> Tuple[TypeConstraint, ...]
"""
Allow a single TypeConstraint instance instead of the awkward singleton
tuple syntax.
"""
# import placed here to avoid circular dependency
from .ti import TypeConstraint # noqa
if isinstance(x, TypeConstraint):
x = (x,)
else:
x = tuple(x)
for op in x:
assert isinstance(op, TypeConstraint)
return x
def bind(self, *args): def bind(self, *args):
# type: (*ValueType) -> BoundInstruction # type: (*ValueType) -> BoundInstruction
""" """

View File

@@ -1,13 +1,14 @@
from __future__ import absolute_import from __future__ import absolute_import
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\ from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint,\
b1, icmp, iadd_cout, iadd_cin, uextend, ireduce b1, icmp, iadd_cout, iadd_cin, uextend, sextend, ireduce, fpromote, \
fdemote
from base.legalize import narrow, expand from base.legalize import narrow, expand
from base.immediates import intcc from base.immediates import intcc
from base.types import i32, i8 from base.types import i32, i8
from .typevar import TypeVar from .typevar import TypeVar
from .ast import Var, Def from .ast import Var, Def
from .xform import Rtl, XForm from .xform import Rtl, XForm
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual from .ti import ti_rtl, subst, TypeEnv, get_type_env, TypesEqual, WiderOrEq
from unittest import TestCase from unittest import TestCase
from functools import reduce from functools import reduce
@@ -52,9 +53,10 @@ def agree(me, other):
# Translate our constraints using m, and sort # Translate our constraints using m, and sort
me_equiv_constr = sorted([constr.translate(m) me_equiv_constr = sorted([constr.translate(m)
for constr in me.constraints]) for constr in me.constraints], key=repr)
# Sort other's constraints # Sort other's constraints
other_equiv_constr = sorted(other.constraints) other_equiv_constr = sorted([constr.translate(other)
for constr in other.constraints], key=repr)
return me_equiv_constr == other_equiv_constr return me_equiv_constr == other_equiv_constr
@@ -78,7 +80,7 @@ def check_typing(got_or_err, expected, symtab=None):
tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()} tv_m = {subst(k.get_typevar(), subst_m): v for (k, v) in m.items()}
# Rewrite the TVs in the input constraints to their XForm internal # Rewrite the TVs in the input constraints to their XForm internal
# versions # versions
c = [(subst(a, subst_m), subst(b, subst_m)) for (a, b) in c] c = [constr.translate(subst_m) for constr in c]
else: else:
# If no symtab, just convert m from Var->TypeVar map to a # If no symtab, just convert m from Var->TypeVar map to a
# TypeVar->TypeVar map # TypeVar->TypeVar map
@@ -209,7 +211,7 @@ class TestRTL(TypeCheckingBaseTest):
self.v3: txn, self.v3: txn,
self.v4: txn, self.v4: txn,
self.v5: txn, self.v5: txn,
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())])) }, [TypesEqual(ixn.as_bool(), txn.as_bool())]))
def test_vselect_vsplits(self): def test_vselect_vsplits(self):
# type: () -> None # type: () -> None
@@ -319,6 +321,90 @@ class TestRTL(TypeCheckingBaseTest):
"Error: empty type created when unifying " + "Error: empty type created when unifying " +
"`typeof_v4` and `typeof_v5`") "`typeof_v4` and `typeof_v5`")
def test_extend_reduce(self):
# type: () -> None
r = Rtl(
self.v1 << uextend(self.v0),
self.v2 << ireduce(self.v1),
self.v3 << sextend(self.v2),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
typing = typing.extract()
itype0 = TypeVar("t", "", ints=True, simd=(1, 256))
itype1 = TypeVar("t1", "", ints=True, simd=(1, 256))
itype2 = TypeVar("t2", "", ints=True, simd=(1, 256))
itype3 = TypeVar("t3", "", ints=True, simd=(1, 256))
check_typing(typing, ({
self.v0: itype0,
self.v1: itype1,
self.v2: itype2,
self.v3: itype3,
}, [WiderOrEq(itype1, itype0),
WiderOrEq(itype1, itype2),
WiderOrEq(itype3, itype2)]))
def test_extend_reduce_enumeration(self):
# type: () -> None
for op in (uextend, sextend, ireduce):
r = Rtl(
self.v1 << op(self.v0),
)
ti = TypeEnv()
typing = ti_rtl(r, ti).extract()
# The number of possible typings is 9 * (3+ 2*2 + 3) = 90
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
assert (len(l) == len(set(l)) and len(l) == 90)
for (tv0, tv1) in l:
typ0, typ1 = (tv0.singleton_type(), tv1.singleton_type())
if (op == ireduce):
assert typ0.wider_or_equal(typ1)
else:
assert typ1.wider_or_equal(typ0)
def test_fpromote_fdemote(self):
# type: () -> None
r = Rtl(
self.v1 << fpromote(self.v0),
self.v2 << fdemote(self.v1),
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
typing = typing.extract()
ftype0 = TypeVar("t", "", floats=True, simd=(1, 256))
ftype1 = TypeVar("t1", "", floats=True, simd=(1, 256))
ftype2 = TypeVar("t2", "", floats=True, simd=(1, 256))
check_typing(typing, ({
self.v0: ftype0,
self.v1: ftype1,
self.v2: ftype2,
}, [WiderOrEq(ftype1, ftype0),
WiderOrEq(ftype1, ftype2)]))
def test_fpromote_fdemote_enumeration(self):
# type: () -> None
for op in (fpromote, fdemote):
r = Rtl(
self.v1 << op(self.v0),
)
ti = TypeEnv()
typing = ti_rtl(r, ti).extract()
# The number of possible typings is 9*(2 + 1) = 27
l = [(t[self.v0], t[self.v1]) for t in typing.concrete_typings()]
assert (len(l) == len(set(l)) and len(l) == 27)
for (tv0, tv1) in l:
(typ0, typ1) = (tv0.singleton_type(), tv1.singleton_type())
if (op == fdemote):
assert typ0.wider_or_equal(typ1)
else:
assert typ1.wider_or_equal(typ0)
class TestXForm(TypeCheckingBaseTest): class TestXForm(TypeCheckingBaseTest):
def test_iadd_cout(self): def test_iadd_cout(self):
@@ -453,7 +539,7 @@ class TestXForm(TypeCheckingBaseTest):
self.v3: i32t, self.v3: i32t,
self.v4: i32t, self.v4: i32t,
self.v5: i32t, self.v5: i32t,
}, []), x.symtab) }, [WiderOrEq(i32t, itype)]), x.symtab)
def test_bound_inst_inference1(self): def test_bound_inst_inference1(self):
# Second example taken from issue #26 # Second example taken from issue #26
@@ -477,7 +563,7 @@ class TestXForm(TypeCheckingBaseTest):
self.v3: i32t, self.v3: i32t,
self.v4: i32t, self.v4: i32t,
self.v5: i32t, self.v5: i32t,
}, []), x.symtab) }, [WiderOrEq(i32t, itype)]), x.symtab)
def test_fully_bound_inst_inference(self): def test_fully_bound_inst_inference(self):
# Second example taken from issue #26 with complete bounds # Second example taken from issue #26 with complete bounds
@@ -494,6 +580,7 @@ class TestXForm(TypeCheckingBaseTest):
i8t = TypeVar.singleton(i8) i8t = TypeVar.singleton(i8)
i32t = TypeVar.singleton(i32) i32t = TypeVar.singleton(i32)
# Note no constraints here since they are all trivial
check_typing(x.ti, ({ check_typing(x.ti, ({
self.v0: i8t, self.v0: i8t,
self.v1: i8t, self.v1: i8t,

View File

@@ -8,7 +8,7 @@ from itertools import product
try: try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable, List # noqa from typing import Iterable, List, Any # noqa
from typing import cast from typing import cast
from .xform import Rtl, XForm # noqa from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa from .ast import Expr # noqa
@@ -25,9 +25,72 @@ class TypeConstraint(object):
""" """
Base class for all runtime-emittable type constraints. Base class for all runtime-emittable type constraints.
""" """
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
"""
Translate any TypeVars in the constraint according to the map or
TypeEnv m
"""
def translate_one(a):
# type: (Any) -> Any
if (isinstance(a, TypeVar)):
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
return a
res = None # type: TypeConstraint
res = self.__class__(*tuple(map(translate_one, self._args())))
return res
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, self.__class__)):
return False
assert isinstance(other, TypeConstraint) # help MyPy figure out other
return self._args() == other._args()
def is_concrete(self):
# type: () -> bool
"""
Return true iff all typevars in the constraint are singletons.
"""
tvs = filter(lambda x: isinstance(x, TypeVar), self._args())
return [] == list(filter(lambda x: x.singleton_type() is None, tvs))
def __hash__(self):
# type: () -> int
return hash(self._args())
def _args(self):
# type: () -> Tuple[Any,...]
"""
Return a tuple with the exact arguments passed to __init__ to create
this object.
"""
assert False, "Abstract"
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
assert False, "Abstract"
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert False, "Abstract"
def __repr__(self):
# type: () -> str
return (self.__class__.__name__ + '(' +
', '.join(map(str, self._args())) + ')')
class ConstrainTVsEqual(TypeConstraint): class TypesEqual(TypeConstraint):
""" """
Constraint specifying that two derived type vars must have the same runtime Constraint specifying that two derived type vars must have the same runtime
type. type.
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint):
assert tv1.is_derived and tv2.is_derived assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr) (self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self): def is_trivial(self):
# type: () -> bool # type: () -> bool
""" """ See TypeConstraint.is_trivial() """
Return true if this constrain is statically decidable. return self.tv1 == self.tv2 or self.is_concrete()
"""
return self.tv1 == self.tv2 or \
(self.tv1.singleton_type() is not None and
self.tv2.singleton_type() is not None)
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
else:
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVsEqual)):
return False
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
def __hash__(self):
# type: () -> int
return hash((self.tv1, self.tv2))
def eval(self): def eval(self):
# type: () -> bool # type: () -> bool
""" """ See TypeConstraint.eval() """
Evaluate this constraint. Should only be called when the constraint has assert self.is_concrete()
been translated to concrete types.
"""
assert self.tv1.singleton_type() is not None and \
self.tv2.singleton_type() is not None
return self.tv1.singleton_type() == self.tv2.singleton_type() return self.tv1.singleton_type() == self.tv2.singleton_type()
class ConstrainTVInTypeset(TypeConstraint): class InTypeset(TypeConstraint):
""" """
Constraint specifying that a type var must belong to some typeset. Constraint specifying that a type var must belong to some typeset.
""" """
@@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint):
self.tv = tv self.tv = tv
self.ts = ts self.ts = ts
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv, self.ts)
def is_trivial(self): def is_trivial(self):
# type: () -> bool # type: () -> bool
""" """ See TypeConstraint.is_trivial() """
Return true if this constrain is statically decidable.
"""
tv_ts = self.tv.get_typeset().copy() tv_ts = self.tv.get_typeset().copy()
# Trivially True # Trivially True
@@ -104,39 +146,78 @@ class ConstrainTVInTypeset(TypeConstraint):
if (tv_ts.size() == 0): if (tv_ts.size() == 0):
return True return True
return False return self.is_concrete()
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVInTypeset(m[self.tv], self.ts)
else:
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVInTypeset)):
return False
return (self.tv, self.ts) == (other.tv, other.ts)
def __hash__(self):
# type: () -> int
return hash((self.tv, self.ts))
def eval(self): def eval(self):
# type: () -> bool # type: () -> bool
""" """ See TypeConstraint.eval() """
Evaluate this constraint. Should only be called when the constraint has assert self.is_concrete()
been translated to concrete types.
"""
assert self.tv.singleton_type() is not None
return self.tv.get_typeset().issubset(self.ts) return self.tv.get_typeset().issubset(self.ts)
class WiderOrEq(TypeConstraint):
"""
Constraint specifying that a type var tv1 must be wider than or equal to
type var tv2 at runtime. This requires that:
1) They have the same number of lanes
2) In a lane tv1 has at least as many bits as tv2.
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
self.tv1 = tv1
self.tv2 = tv2
def _args(self):
# type: () -> Tuple[Any,...]
""" See TypeConstraint._args() """
return (self.tv1, self.tv2)
def is_trivial(self):
# type: () -> bool
""" See TypeConstraint.is_trivial() """
# Trivially true
if (self.tv1 == self.tv2):
return True
ts1 = self.tv1.get_typeset()
ts2 = self.tv2.get_typeset()
def set_wider_or_equal(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
# Trivially True
if set_wider_or_equal(ts1.ints, ts2.ints) and\
set_wider_or_equal(ts1.floats, ts2.floats) and\
set_wider_or_equal(ts1.bools, ts2.bools):
return True
def set_narrower(s1, s2):
# type: (Set[int], Set[int]) -> bool
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
# Trivially False
if set_narrower(ts1.ints, ts2.ints) and\
set_narrower(ts1.floats, ts2.floats) and\
set_narrower(ts1.bools, ts2.bools):
return True
# Trivially False
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
return True
return self.is_concrete()
def eval(self):
# type: () -> bool
""" See TypeConstraint.eval() """
assert self.is_concrete()
typ1 = self.tv1.singleton_type()
typ2 = self.tv2.singleton_type()
return typ1.wider_or_equal(typ2)
class TypeEnv(object): class TypeEnv(object):
""" """
Class encapsulating the neccessary book keeping for type inference. Class encapsulating the neccessary book keeping for type inference.
@@ -204,12 +285,11 @@ class TypeEnv(object):
self.type_map[tv1] = tv2 self.type_map[tv1] = tv2
def add_constraint(self, tv1, tv2): def add_constraint(self, constr):
# type: (TypeVar, TypeVar) -> None # type: (TypeConstraint) -> None
""" """
Add a new equivalence constraint between tv1 and tv2 Add a new constraint
""" """
constr = ConstrainTVsEqual(tv1, tv2)
if (constr not in self.constraints): if (constr not in self.constraints):
self.constraints.append(constr) self.constraints.append(constr)
@@ -261,6 +341,7 @@ class TypeEnv(object):
Get the free typevars in the current type env. Get the free typevars in the current type env.
""" """
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()]) tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
# Filter out None here due to singleton type vars # Filter out None here due to singleton type vars
return sorted(filter(lambda x: x is not None, tvs), return sorted(filter(lambda x: x is not None, tvs),
key=lambda x: x.name) key=lambda x: x.name)
@@ -326,17 +407,18 @@ class TypeEnv(object):
new_constraints = [] # type: List[TypeConstraint] new_constraints = [] # type: List[TypeConstraint]
for constr in self.constraints: for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
constr = constr.translate(self) constr = constr.translate(self)
if constr.is_trivial() or constr in new_constraints: if constr.is_trivial() or constr in new_constraints:
continue continue
# Sanity: translated constraints should refer to only real vars # Sanity: translated constraints should refer to only real vars
assert constr.tv1.free_typevar() in vars_tvs and\ for arg in constr._args():
constr.tv2.free_typevar() in vars_tvs if (not isinstance(arg, TypeVar)):
continue
arg_free_tv = arg.free_typevar()
assert arg_free_tv is None or arg_free_tv in vars_tvs
new_constraints.append(constr) new_constraints.append(constr)
@@ -372,9 +454,6 @@ class TypeEnv(object):
# Check if constraints are satisfied for this typing # Check if constraints are satisfied for this typing
failed = None failed = None
for constr in self.constraints: for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
concrete_constr = constr.translate(m) concrete_constr = constr.translate(m)
if not concrete_constr.eval(): if not concrete_constr.eval():
failed = concrete_constr failed = concrete_constr
@@ -401,22 +480,27 @@ class TypeEnv(object):
# Add all registered TVs (as some of them may be singleton nodes not # Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph # appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar] nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]] edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
for (k, v) in self.type_map.items(): for (k, v) in self.type_map.items():
# Add all intermediate TVs appearing in edges # Add all intermediate TVs appearing in edges
nodes.add(k) nodes.add(k)
nodes.add(v) nodes.add(v)
edges.add((k, v, "dotted", None)) edges.add((k, v, "dotted", "forward", None))
while (v.is_derived): while (v.is_derived):
nodes.add(v.base) nodes.add(v.base)
edges.add((v, v.base, "solid", v.derived_func)) edges.add((v, v.base, "solid", "forward", v.derived_func))
v = v.base v = v.base
for constr in self.constraints: for constr in self.constraints:
assert isinstance(constr, ConstrainTVsEqual) if isinstance(constr, TypesEqual):
assert constr.tv1 in nodes and constr.tv2 in nodes assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", None)) edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
elif isinstance(constr, WiderOrEq):
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
else:
assert False, "Can't display constraint {}".format(constr)
root_nodes = set([x for x in nodes root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived]) if x not in self.type_map and not x.is_derived])
@@ -428,17 +512,12 @@ class TypeEnv(object):
r += "[xlabel=\"{}\"]".format(self[n].get_typeset()) r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
r += ";\n" r += ";\n"
for (n1, n2, style, elabel) in edges: for (n1, n2, style, direction, elabel) in edges:
e = label(n1) e = label(n1) + "->" + label(n2)
if style == "dashed": e += "[style={},dir={}".format(style, direction)
e += '--'
else:
e += '->'
e += label(n2)
e += "[style={}".format(style)
if elabel is not None: if elabel is not None:
e += ",label={}".format(elabel) e += ",label=\"{}\"".format(elabel)
e += "];\n" e += "];\n"
r += e r += e
@@ -589,7 +668,7 @@ def unify(tv1, tv2, typ):
inv_f = TypeVar.inverse_func(tv1.derived_func) inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ) return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
typ.add_constraint(tv1, tv2) typ.add_constraint(TypesEqual(tv1, tv2))
return typ return typ
@@ -648,6 +727,10 @@ def ti_def(definition, typ):
typ = get_type_env(typ_or_err) typ = get_type_env(typ_or_err)
# Add any instruction specific constraints
for constr in inst.constraints:
typ.add_constraint(constr.translate(m))
return typ return typ

View File

@@ -49,6 +49,26 @@ class ValueType(object):
else: else:
raise AttributeError("No type named '{}'".format(name)) raise AttributeError("No type named '{}'".format(name))
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
assert False, "Abstract"
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
assert False, "Abstract"
def wider_or_equal(self, other):
# type: (ValueType) -> bool
"""
Return true iff:
1. self and other have equal number of lanes
2. each lane in self has at least as many bits as a lane in other
"""
return (self.lane_count() == other.lane_count() and
self.lane_bits() >= other.lane_bits())
class ScalarType(ValueType): class ScalarType(ValueType):
""" """
@@ -85,6 +105,11 @@ class ScalarType(ValueType):
self._vectors[lanes] = v self._vectors[lanes] = v
return v return v
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
return 1
class VectorType(ValueType): class VectorType(ValueType):
""" """
@@ -112,6 +137,16 @@ class VectorType(ValueType):
return ('VectorType(base={}, lanes={})' return ('VectorType(base={}, lanes={})'
.format(self.base.name, self.lanes)) .format(self.base.name, self.lanes))
def lane_count(self):
# type: () -> int
"""Return the number of lanes."""
return self.lanes
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.base.lane_bits()
class IntType(ScalarType): class IntType(ScalarType):
"""A concrete scalar integer type.""" """A concrete scalar integer type."""
@@ -138,6 +173,11 @@ class IntType(ScalarType):
else: else:
return typ return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits
class FloatType(ScalarType): class FloatType(ScalarType):
"""A concrete scalar floating point type.""" """A concrete scalar floating point type."""
@@ -164,6 +204,11 @@ class FloatType(ScalarType):
else: else:
return typ return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits
class BoolType(ScalarType): class BoolType(ScalarType):
"""A concrete scalar boolean type.""" """A concrete scalar boolean type."""
@@ -189,3 +234,8 @@ class BoolType(ScalarType):
return cast(BoolType, typ) return cast(BoolType, typ)
else: else:
return typ return typ
def lane_bits(self):
# type: () -> int
"""Return the number of bits in a lane."""
return self.bits

View File

@@ -11,8 +11,8 @@ from __future__ import absolute_import
from srcgen import Formatter from srcgen import Formatter
from base import legalize, instructions from base import legalize, instructions
from cdsl.ast import Var from cdsl.ast import Var
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\ from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\
ConstrainTVInTypeset InTypeset, WiderOrEq
from unique_table import UniqueTable from unique_table import UniqueTable
from gen_instr import gen_typesets_table from gen_instr import gen_typesets_table
from cdsl.typevar import TypeVar from cdsl.typevar import TypeVar
@@ -66,7 +66,7 @@ def get_runtime_typechecks(xform):
assert xform_ts.issubset(src_ts) assert xform_ts.issubset(src_ts)
if src_ts != xform_ts: if src_ts != xform_ts:
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts)) check_l.append(InTypeset(xform.ti[v], xform_ts))
# 2,3) Add any constraints that appear in xform.ti # 2,3) Add any constraints that appear in xform.ti
check_l.extend(xform.ti.constraints) check_l.extend(xform.ti.constraints)
@@ -81,6 +81,14 @@ def emit_runtime_typecheck(check, fmt, type_sets):
""" """
def build_derived_expr(tv): def build_derived_expr(tv):
# type: (TypeVar) -> str # type: (TypeVar) -> str
"""
Build an expression of type Option<Type> corresponding to a concrete
type transformed by the sequence of derivation functions in tv.
We are using Option<Type>, as some constraints may cause an
over/underflow on patterns that do not match them. We want to capture
this without panicking at runtime.
"""
if not tv.is_derived: if not tv.is_derived:
assert tv.name.startswith('typeof_') assert tv.name.startswith('typeof_')
return "Some({})".format(tv.name) return "Some({})".format(tv.name)
@@ -102,7 +110,8 @@ def emit_runtime_typecheck(check, fmt, type_sets):
else: else:
assert False, "Unknown derived function {}".format(tv.derived_func) assert False, "Unknown derived function {}".format(tv.derived_func)
if (isinstance(check, ConstrainTVInTypeset)): if (isinstance(check, InTypeset)):
assert not check.tv.is_derived
tv = check.tv.name tv = check.tv.name
if check.ts not in type_sets.index: if check.ts not in type_sets.index:
type_sets.add(check.ts) type_sets.add(check.ts)
@@ -112,11 +121,28 @@ def emit_runtime_typecheck(check, fmt, type_sets):
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv), with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
'};'): '};'):
fmt.line('return false;') fmt.line('return false;')
elif (isinstance(check, ConstrainTVsEqual)): elif (isinstance(check, TypesEqual)):
tv1 = build_derived_expr(check.tv1) with fmt.indented('{', '};'):
tv2 = build_derived_expr(check.tv2) fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'): fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
fmt.line('return false;')
fmt.comment('On overflow constraint doesn\'t appply')
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
fmt.line('return false;')
with fmt.indented('if a != b {', '};'):
fmt.line('return false;')
elif (isinstance(check, WiderOrEq)):
with fmt.indented('{', '};'):
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
fmt.comment('On overflow constraint doesn\'t appply')
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
fmt.line('return false;')
with fmt.indented('if !a.wider_or_equal(b) {', '};'):
fmt.line('return false;')
else: else:
assert False, "Unknown check {}".format(check) assert False, "Unknown check {}".format(check)

View File

@@ -4,7 +4,7 @@ from unittest import TestCase
from srcgen import Formatter from srcgen import Formatter
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \ from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
iconst, b1, icmp, copy # noqa iconst, b1, icmp, copy, sextend, uextend, ireduce, fdemote, fpromote # noqa
from base.legalize import narrow, expand # noqa from base.legalize import narrow, expand # noqa
from base.immediates import intcc # noqa from base.immediates import intcc # noqa
from cdsl.typevar import TypeVar, TypeSet from cdsl.typevar import TypeVar, TypeSet
@@ -57,9 +57,32 @@ def equiv_check(tv1, tv2):
# type: (TypeVar, TypeVar) -> CheckProducer # type: (TypeVar, TypeVar) -> CheckProducer
return lambda typesets: format_check( return lambda typesets: format_check(
typesets, typesets,
'if Some({}).map(|t: Type| -> t.as_bool()) != ' + '{{\n' +
'Some({}).map(|t: Type| -> t.as_bool()) ' + ' let a = {};\n' +
'{{\n return false;\n}};\n', tv1, tv2) ' let b = {};\n' +
' if a.is_none() || b.is_none() {{\n' +
' return false;\n' +
' }};\n' +
' if a != b {{\n' +
' return false;\n' +
' }};\n' +
'}};\n', tv1, tv2)
def wider_check(tv1, tv2):
# type: (TypeVar, TypeVar) -> CheckProducer
return lambda typesets: format_check(
typesets,
'{{\n' +
' let a = {};\n' +
' let b = {};\n' +
' if a.is_none() || b.is_none() {{\n' +
' return false;\n' +
' }};\n' +
' if !a.wider_or_equal(b) {{\n' +
' return false;\n' +
' }};\n' +
'}};\n', tv1, tv2)
def sequence(*args): def sequence(*args):
@@ -138,8 +161,49 @@ class TestRuntimeChecks(TestCase):
self.v5 << vselect(self.v1, self.v3, self.v4), self.v5 << vselect(self.v1, self.v3, self.v4),
) )
x = XForm(r, r) x = XForm(r, r)
tv2_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
.format(self.v2.get_typevar().name)
tv3_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\
.format(self.v3.get_typevar().name)
self.check_yo_check( self.check_yo_check(
x, sequence(typeset_check(self.v3, ts), x, sequence(typeset_check(self.v3, ts),
equiv_check(self.v2.get_typevar(), equiv_check(tv2_exp, tv3_exp)))
self.v3.get_typevar())))
def test_reduce_extend(self):
# type: () -> None
r = Rtl(
self.v1 << uextend(self.v0),
self.v2 << ireduce(self.v1),
self.v3 << sextend(self.v2),
)
x = XForm(r, r)
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
self.check_yo_check(
x, sequence(wider_check(tv1_exp, tv0_exp),
wider_check(tv1_exp, tv2_exp),
wider_check(tv3_exp, tv2_exp)))
def test_demote_promote(self):
# type: () -> None
r = Rtl(
self.v1 << fpromote(self.v0),
self.v2 << fdemote(self.v1),
self.v3 << fpromote(self.v2),
)
x = XForm(r, r)
tv0_exp = 'Some({})'.format(self.v0.get_typevar().name)
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name)
tv2_exp = 'Some({})'.format(self.v2.get_typevar().name)
tv3_exp = 'Some({})'.format(self.v3.get_typevar().name)
self.check_yo_check(
x, sequence(wider_check(tv1_exp, tv0_exp),
wider_check(tv1_exp, tv2_exp),
wider_check(tv3_exp, tv2_exp)))

View File

@@ -236,6 +236,13 @@ impl Type {
pub fn index(self) -> usize { pub fn index(self) -> usize {
self.0 as usize self.0 as usize
} }
/// True iff:
/// 1) self.lane_count() == other.lane_count() and
/// 2) self.lane_bits() >= other.lane_bits()
pub fn wider_or_equal(self, other: Type) -> bool {
self.lane_count() == other.lane_count() && self.lane_bits() >= other.lane_bits()
}
} }
impl Display for Type { impl Display for Type {