diff --git a/lib/cretonne/meta/base/instructions.py b/lib/cretonne/meta/base/instructions.py index dc9accf1cc..109bc73f9e 100644 --- a/lib/cretonne/meta/base/instructions.py +++ b/lib/cretonne/meta/base/instructions.py @@ -12,6 +12,7 @@ from base.types import i8, f32, f64, b1 from base.immediates import imm64, uimm8, ieee32, ieee64, offset32, uoffset32 from base.immediates import intcc, floatcc, memflags, regunit from base import entities +from cdsl.ti import WiderOrEq import base.formats # noqa GROUP = InstructionGroup("base", "Shared base instruction set") @@ -1405,7 +1406,7 @@ ireduce = Instruction( and each lane must not have more bits that the input lanes. If the input and output types are the same, this is a no-op. """, - ins=x, outs=a) + ins=x, outs=a, constraints=WiderOrEq(Int, IntTo)) IntTo = TypeVar( @@ -1427,7 +1428,7 @@ uextend = Instruction( and each lane must not have fewer bits that the input lanes. If the input and output types are the same, this is a no-op. """, - ins=x, outs=a) + ins=x, outs=a, constraints=WiderOrEq(IntTo, Int)) sextend = Instruction( 'sextend', r""" @@ -1441,7 +1442,7 @@ sextend = Instruction( and each lane must not have fewer bits that the input lanes. If the input and output types are the same, this is a no-op. """, - ins=x, outs=a) + ins=x, outs=a, constraints=WiderOrEq(IntTo, Int)) FloatTo = TypeVar( 'FloatTo', 'A scalar or vector floating point number', @@ -1457,14 +1458,14 @@ fpromote = Instruction( Each lane in `x` is converted to the destination floating point format. This is an exact operation. - Since Cretonne currently only supports two floating point formats, this - instruction always converts :type:`f32` to :type:`f64`. This may change - in the future. + Cretonne currently only supports two floating point formats + - :type:`f32` and :type:`f64`. This may change in the future. The result type must have the same number of vector lanes as the input, - and the result lanes must be larger than the input lanes. 
+ and the result lanes must not have fewer bits than the input lanes. If + the input and output types are the same, this is a no-op. """, - ins=x, outs=a) + ins=x, outs=a, constraints=WiderOrEq(FloatTo, Float)) fdemote = Instruction( 'fdemote', r""" @@ -1473,14 +1474,14 @@ fdemote = Instruction( Each lane in `x` is converted to the destination floating point format by rounding to nearest, ties to even. - Since Cretonne currently only supports two floating point formats, this - instruction always converts :type:`f64` to :type:`f32`. This may change - in the future. + Cretonne currently only supports two floating point formats + - :type:`f32` and :type:`f64`. This may change in the future. The result type must have the same number of vector lanes as the input, - and the result lanes must be smaller than the input lanes. + and the result lanes must not have more bits than the input lanes. If + the input and output types are the same, this is a no-op. """, - ins=x, outs=a) + ins=x, outs=a, constraints=WiderOrEq(Float, FloatTo)) x = Operand('x', Float) a = Operand('a', IntTo) diff --git a/lib/cretonne/meta/cdsl/instructions.py b/lib/cretonne/meta/cdsl/instructions.py index 22c989bd65..511459fdbe 100644 --- a/lib/cretonne/meta/cdsl/instructions.py +++ b/lib/cretonne/meta/cdsl/instructions.py @@ -10,8 +10,10 @@ try: if TYPE_CHECKING: from .ast import Expr, Apply # noqa from .typevar import TypeVar # noqa + from .ti import TypeConstraint # noqa # List of operands for ins/outs: OpList = Union[Sequence[Operand], Operand] + ConstrList = Union[Sequence[TypeConstraint], TypeConstraint] MaybeBoundInst = Union['Instruction', 'BoundInstruction'] except ImportError: pass @@ -80,6 +82,7 @@ class Instruction(object): operands and other operand kinds. :param outs: Tuple of output operands. The output operands must be SSA values or `variable_args`. + :param constraints: Tuple of instruction-specific TypeConstraints. :param is_terminator: This is a terminator instruction. 
@staticmethod
def _to_constraint_tuple(x):
    # type: (ConstrList) -> Tuple[TypeConstraint, ...]
    """
    Normalize the `constraints` argument into a tuple of TypeConstraints.

    A lone TypeConstraint is accepted as shorthand for a one-element tuple;
    anything else is converted with `tuple()`. Every element of the result
    is asserted to be a TypeConstraint.
    """
    # Local import: a module-level import would be a circular dependency.
    from .ti import TypeConstraint  # noqa

    constraints = (x,) if isinstance(x, TypeConstraint) else tuple(x)
    assert all(isinstance(c, TypeConstraint) for c in constraints)
    return constraints
def test_extend_reduce_enumeration(self):
    # type: () -> None
    """
    Enumerate all concrete typings of a single uextend/sextend/ireduce and
    check that each one respects the width ordering of its instruction.
    """
    for op in (uextend, sextend, ireduce):
        rtl = Rtl(self.v1 << op(self.v0))
        typing = ti_rtl(rtl, TypeEnv()).extract()

        # The number of possible typings is 9 * (3+ 2*2 + 3) = 90
        pairs = [(t[self.v0], t[self.v1])
                 for t in typing.concrete_typings()]
        # No typing may be produced twice.
        assert len(pairs) == len(set(pairs))
        assert len(pairs) == 90

        for (in_tv, out_tv) in pairs:
            in_typ = in_tv.singleton_type()
            out_typ = out_tv.singleton_type()
            # ireduce shrinks its input; the extends grow it.
            if op == ireduce:
                wider, narrower = in_typ, out_typ
            else:
                wider, narrower = out_typ, in_typ
            assert wider.wider_or_equal(narrower)
def test_fpromote_fdemote_enumeration(self):
    # type: () -> None
    """
    Enumerate all concrete typings of a single fpromote/fdemote and check
    that each one respects the width ordering of its instruction.
    """
    for op in (fpromote, fdemote):
        rtl = Rtl(self.v1 << op(self.v0))
        typing = ti_rtl(rtl, TypeEnv()).extract()

        # The number of possible typings is 9*(2 + 1) = 27
        pairs = [(t[self.v0], t[self.v1])
                 for t in typing.concrete_typings()]
        # No typing may be produced twice.
        assert len(pairs) == len(set(pairs))
        assert len(pairs) == 27

        for (in_tv, out_tv) in pairs:
            in_typ = in_tv.singleton_type()
            out_typ = out_tv.singleton_type()
            # fdemote shrinks its input; fpromote grows it.
            if op == fdemote:
                wider, narrower = in_typ, out_typ
            else:
                wider, narrower = out_typ, in_typ
            assert wider.wider_or_equal(narrower)
class TypeConstraint(object):
    """
    Base class for all runtime-emittable type constraints.

    Subclasses provide `_args()`, `is_trivial()` and `eval()`. Translation,
    equality, hashing and `repr()` are all derived from `_args()` here.
    """
    def translate(self, m):
        # type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
        """
        Translate any TypeVars in the constraint according to the map or
        TypeEnv m. Arguments that are not TypeVars are passed through as-is.
        """
        def translate_one(a):
            # type: (Any) -> Any
            if isinstance(a, TypeVar):
                return m[a] if isinstance(m, TypeEnv) else subst(a, m)
            return a

        # Rebuild an instance of the same concrete class from the translated
        # constructor arguments.
        res = None  # type: TypeConstraint
        res = self.__class__(*[translate_one(a) for a in self._args()])
        return res

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, self.__class__):
            return False

        assert isinstance(other, TypeConstraint)  # help MyPy figure out other
        return self._args() == other._args()

    def __ne__(self, other):
        # type: (object) -> bool
        # Python 2 does not derive != from __eq__; without this, `a != b`
        # would compare identities and disagree with `a == b`.
        return not self.__eq__(other)

    def __hash__(self):
        # type: () -> int
        return hash(self._args())

    def is_concrete(self):
        # type: () -> bool
        """
        Return true iff all typevars in the constraint are singletons.
        """
        return all(a.singleton_type() is not None
                   for a in self._args() if isinstance(a, TypeVar))

    def _args(self):
        # type: () -> Tuple[Any,...]
        """
        Return a tuple with the exact arguments passed to __init__ to create
        this object.
        """
        assert False, "Abstract"

    def is_trivial(self):
        # type: () -> bool
        """
        Return true if this constraint is statically decidable.
        """
        assert False, "Abstract"

    def eval(self):
        # type: () -> bool
        """
        Evaluate this constraint. Should only be called when the constraint has
        been translated to concrete types.
        """
        assert False, "Abstract"

    def __repr__(self):
        # type: () -> str
        return (self.__class__.__name__ + '(' +
                ', '.join(map(str, self._args())) + ')')
@@ -37,48 +100,24 @@ class ConstrainTVsEqual(TypeConstraint): assert tv1.is_derived and tv2.is_derived (self.tv1, self.tv2) = sorted([tv1, tv2], key=repr) + def _args(self): + # type: () -> Tuple[Any,...] + """ See TypeConstraint._args() """ + return (self.tv1, self.tv2) + def is_trivial(self): # type: () -> bool - """ - Return true if this constrain is statically decidable. - """ - return self.tv1 == self.tv2 or \ - (self.tv1.singleton_type() is not None and - self.tv2.singleton_type() is not None) - - def translate(self, m): - # type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual - """ - Translate any TypeVars in the constraint according to the map m - """ - if isinstance(m, TypeEnv): - return ConstrainTVsEqual(m[self.tv1], m[self.tv2]) - else: - return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m)) - - def __eq__(self, other): - # type: (object) -> bool - if (not isinstance(other, ConstrainTVsEqual)): - return False - - return (self.tv1, self.tv2) == (other.tv1, other.tv2) - - def __hash__(self): - # type: () -> int - return hash((self.tv1, self.tv2)) + """ See TypeConstraint.is_trivial() """ + return self.tv1 == self.tv2 or self.is_concrete() def eval(self): # type: () -> bool - """ - Evaluate this constraint. Should only be called when the constraint has - been translated to concrete types. - """ - assert self.tv1.singleton_type() is not None and \ - self.tv2.singleton_type() is not None + """ See TypeConstraint.eval() """ + assert self.is_concrete() return self.tv1.singleton_type() == self.tv2.singleton_type() -class ConstrainTVInTypeset(TypeConstraint): +class InTypeset(TypeConstraint): """ Constraint specifying that a type var must belong to some typeset. """ @@ -88,11 +127,14 @@ class ConstrainTVInTypeset(TypeConstraint): self.tv = tv self.ts = ts + def _args(self): + # type: () -> Tuple[Any,...] 
class WiderOrEq(TypeConstraint):
    """
    Constraint specifying that a type var tv1 must be wider than or equal to
    type var tv2 at runtime. This requires that:
        1) They have the same number of lanes
        2) In a lane tv1 has at least as many bits as tv2.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """
        See TypeConstraint.is_trivial()

        Returns True when the outcome can be decided from the two typesets
        alone (whether it is true or false), or when both sides are concrete.
        """
        # Trivially true: a type is always wider than or equal to itself.
        if self.tv1 == self.tv2:
            return True

        ts1 = self.tv1.get_typeset()
        ts2 = self.tv2.get_typeset()

        # Trivially false: no lane count is possible on both sides.
        if len(ts1.lanes.intersection(ts2.lanes)) == 0:
            return True

        # eval() compares lane widths irrespective of the lane kind, so the
        # static decision has to look at every width either side can take.
        bits1 = ts1.ints.union(ts1.floats, ts1.bools)
        bits2 = ts2.ints.union(ts2.floats, ts2.bools)

        if len(bits1) > 0 and len(bits2) > 0:
            # Trivially false: even the widest tv1 lane is narrower than the
            # narrowest tv2 lane.
            if max(bits1) < min(bits2):
                return True

            # Trivially true: both sides are pinned to the same single lane
            # count, and every tv1 width is at least every tv2 width.
            same_lanes = len(ts1.lanes) == 1 and ts1.lanes == ts2.lanes
            if same_lanes and min(bits1) >= max(bits2):
                return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        typ1 = self.tv1.singleton_type()
        typ2 = self.tv2.singleton_type()

        return typ1.wider_or_equal(typ2)
""" tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()]) + tvs = tvs.union(set([self[v].free_typevar() for v in self.vars])) # Filter out None here due to singleton type vars return sorted(filter(lambda x: x is not None, tvs), key=lambda x: x.name) @@ -326,17 +407,18 @@ class TypeEnv(object): new_constraints = [] # type: List[TypeConstraint] for constr in self.constraints: - # Currently typeinference only generates ConstrainTVsEqual - # constraints - assert isinstance(constr, ConstrainTVsEqual) constr = constr.translate(self) if constr.is_trivial() or constr in new_constraints: continue # Sanity: translated constraints should refer to only real vars - assert constr.tv1.free_typevar() in vars_tvs and\ - constr.tv2.free_typevar() in vars_tvs + for arg in constr._args(): + if (not isinstance(arg, TypeVar)): + continue + + arg_free_tv = arg.free_typevar() + assert arg_free_tv is None or arg_free_tv in vars_tvs new_constraints.append(constr) @@ -372,9 +454,6 @@ class TypeEnv(object): # Check if constraints are satisfied for this typing failed = None for constr in self.constraints: - # Currently typeinference only generates ConstrainTVsEqual - # constraints - assert isinstance(constr, ConstrainTVsEqual) concrete_constr = constr.translate(m) if not concrete_constr.eval(): failed = concrete_constr @@ -401,22 +480,27 @@ class TypeEnv(object): # Add all registered TVs (as some of them may be singleton nodes not # appearing in the graph nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar] - edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, Optional[str]]] + edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa for (k, v) in self.type_map.items(): # Add all intermediate TVs appearing in edges nodes.add(k) nodes.add(v) - edges.add((k, v, "dotted", None)) + edges.add((k, v, "dotted", "forward", None)) while (v.is_derived): nodes.add(v.base) - edges.add((v, v.base, "solid", v.derived_func)) + edges.add((v, v.base, 
"solid", "forward", v.derived_func)) v = v.base for constr in self.constraints: - assert isinstance(constr, ConstrainTVsEqual) - assert constr.tv1 in nodes and constr.tv2 in nodes - edges.add((constr.tv1, constr.tv2, "dashed", None)) + if isinstance(constr, TypesEqual): + assert constr.tv1 in nodes and constr.tv2 in nodes + edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal")) + elif isinstance(constr, WiderOrEq): + assert constr.tv1 in nodes and constr.tv2 in nodes + edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">=")) + else: + assert False, "Can't display constraint {}".format(constr) root_nodes = set([x for x in nodes if x not in self.type_map and not x.is_derived]) @@ -428,17 +512,12 @@ class TypeEnv(object): r += "[xlabel=\"{}\"]".format(self[n].get_typeset()) r += ";\n" - for (n1, n2, style, elabel) in edges: - e = label(n1) - if style == "dashed": - e += '--' - else: - e += '->' - e += label(n2) - e += "[style={}".format(style) + for (n1, n2, style, direction, elabel) in edges: + e = label(n1) + "->" + label(n2) + e += "[style={},dir={}".format(style, direction) if elabel is not None: - e += ",label={}".format(elabel) + e += ",label=\"{}\"".format(elabel) e += "];\n" r += e @@ -589,7 +668,7 @@ def unify(tv1, tv2, typ): inv_f = TypeVar.inverse_func(tv1.derived_func) return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ) - typ.add_constraint(tv1, tv2) + typ.add_constraint(TypesEqual(tv1, tv2)) return typ @@ -648,6 +727,10 @@ def ti_def(definition, typ): typ = get_type_env(typ_or_err) + # Add any instruction specific constraints + for constr in inst.constraints: + typ.add_constraint(constr.translate(m)) + return typ diff --git a/lib/cretonne/meta/cdsl/types.py b/lib/cretonne/meta/cdsl/types.py index eeb5325ae8..cebe472fbf 100644 --- a/lib/cretonne/meta/cdsl/types.py +++ b/lib/cretonne/meta/cdsl/types.py @@ -49,6 +49,26 @@ class ValueType(object): else: raise AttributeError("No type named '{}'".format(name)) + def 
def wider_or_equal(self, other):
    # type: (ValueType) -> bool
    """
    Tell whether `self` is at least as wide as `other`.

    True iff both types have the same number of lanes and each lane of
    `self` has at least as many bits as a lane of `other`.
    """
    # Types with different lane counts are never comparable.
    if self.lane_count() != other.lane_count():
        return False
    return self.lane_bits() >= other.lane_bits()
--- a/lib/cretonne/meta/gen_legalizer.py +++ b/lib/cretonne/meta/gen_legalizer.py @@ -11,8 +11,8 @@ from __future__ import absolute_import from srcgen import Formatter from base import legalize, instructions from cdsl.ast import Var -from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\ - ConstrainTVInTypeset +from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\ + InTypeset, WiderOrEq from unique_table import UniqueTable from gen_instr import gen_typesets_table from cdsl.typevar import TypeVar @@ -66,7 +66,7 @@ def get_runtime_typechecks(xform): assert xform_ts.issubset(src_ts) if src_ts != xform_ts: - check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts)) + check_l.append(InTypeset(xform.ti[v], xform_ts)) # 2,3) Add any constraints that appear in xform.ti check_l.extend(xform.ti.constraints) @@ -81,6 +81,14 @@ def emit_runtime_typecheck(check, fmt, type_sets): """ def build_derived_expr(tv): # type: (TypeVar) -> str + """ + Build an expression of type Option corresponding to a concrete + type transformed by the sequence of derivation functions in tv. + + We are using Option, as some constraints may cause an + over/underflow on patterns that do not match them. We want to capture + this without panicking at runtime. 
+ """ if not tv.is_derived: assert tv.name.startswith('typeof_') return "Some({})".format(tv.name) @@ -102,7 +110,8 @@ def emit_runtime_typecheck(check, fmt, type_sets): else: assert False, "Unknown derived function {}".format(tv.derived_func) - if (isinstance(check, ConstrainTVInTypeset)): + if (isinstance(check, InTypeset)): + assert not check.tv.is_derived tv = check.tv.name if check.ts not in type_sets.index: type_sets.add(check.ts) @@ -112,11 +121,28 @@ def emit_runtime_typecheck(check, fmt, type_sets): with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv), '};'): fmt.line('return false;') - elif (isinstance(check, ConstrainTVsEqual)): - tv1 = build_derived_expr(check.tv1) - tv2 = build_derived_expr(check.tv2) - with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'): - fmt.line('return false;') + elif (isinstance(check, TypesEqual)): + with fmt.indented('{', '};'): + fmt.line('let a = {};'.format(build_derived_expr(check.tv1))) + fmt.line('let b = {};'.format(build_derived_expr(check.tv2))) + + fmt.comment('On overflow constraint doesn\'t appply') + with fmt.indented('if a.is_none() || b.is_none() {', '};'): + fmt.line('return false;') + + with fmt.indented('if a != b {', '};'): + fmt.line('return false;') + elif (isinstance(check, WiderOrEq)): + with fmt.indented('{', '};'): + fmt.line('let a = {};'.format(build_derived_expr(check.tv1))) + fmt.line('let b = {};'.format(build_derived_expr(check.tv2))) + + fmt.comment('On overflow constraint doesn\'t appply') + with fmt.indented('if a.is_none() || b.is_none() {', '};'): + fmt.line('return false;') + + with fmt.indented('if !a.wider_or_equal(b) {', '};'): + fmt.line('return false;') else: assert False, "Unknown check {}".format(check) diff --git a/lib/cretonne/meta/test_gen_legalizer.py b/lib/cretonne/meta/test_gen_legalizer.py index 38f26959f4..538ff69b63 100644 --- a/lib/cretonne/meta/test_gen_legalizer.py +++ b/lib/cretonne/meta/test_gen_legalizer.py @@ -4,7 +4,7 @@ from unittest import 
TestCase from srcgen import Formatter from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \ - iconst, b1, icmp, copy # noqa + iconst, b1, icmp, copy, sextend, uextend, ireduce, fdemote, fpromote # noqa from base.legalize import narrow, expand # noqa from base.immediates import intcc # noqa from cdsl.typevar import TypeVar, TypeSet @@ -57,9 +57,32 @@ def equiv_check(tv1, tv2): # type: (TypeVar, TypeVar) -> CheckProducer return lambda typesets: format_check( typesets, - 'if Some({}).map(|t: Type| -> t.as_bool()) != ' + - 'Some({}).map(|t: Type| -> t.as_bool()) ' + - '{{\n return false;\n}};\n', tv1, tv2) + '{{\n' + + ' let a = {};\n' + + ' let b = {};\n' + + ' if a.is_none() || b.is_none() {{\n' + + ' return false;\n' + + ' }};\n' + + ' if a != b {{\n' + + ' return false;\n' + + ' }};\n' + + '}};\n', tv1, tv2) + + +def wider_check(tv1, tv2): + # type: (TypeVar, TypeVar) -> CheckProducer + return lambda typesets: format_check( + typesets, + '{{\n' + + ' let a = {};\n' + + ' let b = {};\n' + + ' if a.is_none() || b.is_none() {{\n' + + ' return false;\n' + + ' }};\n' + + ' if !a.wider_or_equal(b) {{\n' + + ' return false;\n' + + ' }};\n' + + '}};\n', tv1, tv2) def sequence(*args): @@ -138,8 +161,49 @@ class TestRuntimeChecks(TestCase): self.v5 << vselect(self.v1, self.v3, self.v4), ) x = XForm(r, r) + tv2_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\ + .format(self.v2.get_typevar().name) + tv3_exp = 'Some({}).map(|t: Type| -> t.as_bool())'\ + .format(self.v3.get_typevar().name) self.check_yo_check( x, sequence(typeset_check(self.v3, ts), - equiv_check(self.v2.get_typevar(), - self.v3.get_typevar()))) + equiv_check(tv2_exp, tv3_exp))) + + def test_reduce_extend(self): + # type: () -> None + r = Rtl( + self.v1 << uextend(self.v0), + self.v2 << ireduce(self.v1), + self.v3 << sextend(self.v2), + ) + x = XForm(r, r) + + tv0_exp = 'Some({})'.format(self.v0.get_typevar().name) + 
tv1_exp = 'Some({})'.format(self.v1.get_typevar().name) + tv2_exp = 'Some({})'.format(self.v2.get_typevar().name) + tv3_exp = 'Some({})'.format(self.v3.get_typevar().name) + + self.check_yo_check( + x, sequence(wider_check(tv1_exp, tv0_exp), + wider_check(tv1_exp, tv2_exp), + wider_check(tv3_exp, tv2_exp))) + + def test_demote_promote(self): + # type: () -> None + r = Rtl( + self.v1 << fpromote(self.v0), + self.v2 << fdemote(self.v1), + self.v3 << fpromote(self.v2), + ) + x = XForm(r, r) + + tv0_exp = 'Some({})'.format(self.v0.get_typevar().name) + tv1_exp = 'Some({})'.format(self.v1.get_typevar().name) + tv2_exp = 'Some({})'.format(self.v2.get_typevar().name) + tv3_exp = 'Some({})'.format(self.v3.get_typevar().name) + + self.check_yo_check( + x, sequence(wider_check(tv1_exp, tv0_exp), + wider_check(tv1_exp, tv2_exp), + wider_check(tv3_exp, tv2_exp))) diff --git a/lib/cretonne/src/ir/types.rs b/lib/cretonne/src/ir/types.rs index e6379be2e6..0007154025 100644 --- a/lib/cretonne/src/ir/types.rs +++ b/lib/cretonne/src/ir/types.rs @@ -236,6 +236,13 @@ impl Type { pub fn index(self) -> usize { self.0 as usize } + + /// True iff: + /// 1) self.lane_count() == other.lane_count() and + /// 2) self.lane_bits() >= other.lane_bits() + pub fn wider_or_equal(self, other: Type) -> bool { + self.lane_count() == other.lane_count() && self.lane_bits() >= other.lane_bits() + } } impl Display for Type {