diff --git a/lib/cretonne/meta/cdsl/ast.py b/lib/cretonne/meta/cdsl/ast.py index 6efc492cf5..0d87bf8914 100644 --- a/lib/cretonne/meta/cdsl/ast.py +++ b/lib/cretonne/meta/cdsl/ast.py @@ -10,7 +10,7 @@ from .typevar import TypeVar from .predicates import IsEqual, And try: - from typing import Union, Tuple, Sequence, TYPE_CHECKING # noqa + from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa if TYPE_CHECKING: from .operands import ImmediateKind # noqa from .predicates import PredNode # noqa @@ -18,6 +18,19 @@ except ImportError: pass +def replace_var(arg, m): + # type: (Expr, Dict[Var, Var]) -> Expr + """ + Given a var v return either m[v] or a new variable v' (and remember + m[v]=v'). Otherwise return the argument unchanged + """ + if isinstance(arg, Var): + new_arg = m.get(arg, Var(arg.name)) # type: Var + m[arg] = new_arg + return new_arg + return arg + + class Def(object): """ An AST definition associates a set of variables with the values produced by @@ -60,6 +73,21 @@ class Def(object): return "({}) << {!s}".format( ', '.join(map(str, self.defs)), self.expr) + def copy(self, m): + # type: (Dict[Var, Var]) -> Def + """ + Return a copy of this Def with vars replaced with fresh variables, + in accordance with the map m. Update m as neccessary. + """ + new_expr = self.expr.copy(m) + new_defs = [] # type: List[Var] + for v in self.defs: + new_v = replace_var(v, m) + assert(isinstance(new_v, Var)) + new_defs.append(new_v) + + return Def(tuple(new_defs), new_expr) + class Expr(object): """ @@ -303,6 +331,15 @@ class Apply(Expr): return pred + def copy(self, m): + # type: (Dict[Var, Var]) -> Apply + """ + Return a copy of this Expr with vars replaced with fresh variables, + in accordance with the map m. Update m as neccessary. + """ + return Apply(self.inst, tuple(map(lambda e: replace_var(e, m), + self.args))) + class Enumerator(Expr): """ diff --git a/lib/cretonne/meta/cdsl/test_ti.py b/lib/cretonne/meta/cdsl/test_ti.py index 396d5ef844..71f439864e 100644 --- a/lib/cretonne/meta/cdsl/test_ti.py +++ b/lib/cretonne/meta/cdsl/test_ti.py @@ -6,30 +6,17 @@ from base.immediates import intcc from .typevar import TypeVar from .ast import Var, Def from .xform import Rtl, XForm -from .ti import ti_rtl, subst, TypeEnv, get_type_env +from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual from unittest import TestCase from functools import reduce try: from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa - from .ti import Constraint from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa except ImportError: TYPE_CHECKING = False -def sort_constr(c): - # type: (Constraint) -> Constraint - """ - Sort the 2 typevars in a constraint by name for comparison - """ - r = tuple(sorted(c, key=lambda y: y.name)) - if TYPE_CHECKING: - return cast(Constraint, r) - else: - return r - - def agree(me, other): # type: (TypeEnv, TypeEnv) -> bool """ @@ -63,13 +50,10 @@ def agree(me, other): return False # Translate our constraints using m, and sort - me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints] - me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr]) - + me_equiv_constr = sorted([constr.translate(m) + for constr in me.constraints]) # Sort other's constraints - other_equiv_constr = sorted([sort_constr(x) for x in other.constraints], - key=lambda y: y[0].name) - + other_equiv_constr = sorted(other.constraints) return me_equiv_constr == other_equiv_constr @@ -224,7 +208,7 @@ class TestRTL(TypeCheckingBaseTest): 
self.v3: txn, self.v4: txn, self.v5: txn, - }, [(ixn.as_bool(), txn.as_bool())])) + }, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())])) def test_vselect_vsplits(self): # type: () -> None diff --git a/lib/cretonne/meta/cdsl/ti.py b/lib/cretonne/meta/cdsl/ti.py index ce2f05c750..fe602e1342 100644 --- a/lib/cretonne/meta/cdsl/ti.py +++ b/lib/cretonne/meta/cdsl/ti.py @@ -8,13 +8,12 @@ from itertools import product try: from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa - from typing import Iterable # noqa - from typing import cast, List + from typing import Iterable, List # noqa + from typing import cast from .xform import Rtl, XForm # noqa from .ast import Expr # noqa + from .typevar import TypeSet # noqa if TYPE_CHECKING: - Constraint = Tuple[TypeVar, TypeVar] - ConstraintList = List[Constraint] TypeMap = Dict[TypeVar, TypeVar] VarMap = Dict[Var, TypeVar] except ImportError: @@ -22,6 +21,122 @@ except ImportError: pass +class TypeConstraint(object): + """ + Base class for all runtime-emittable type constraints. + """ + + +class ConstrainTVsEqual(TypeConstraint): + """ + Constraint specifying that two derived type vars must have the same runtime + type. + """ + def __init__(self, tv1, tv2): + # type: (TypeVar, TypeVar) -> None + assert tv1.is_derived and tv2.is_derived + (self.tv1, self.tv2) = sorted([tv1, tv2], key=repr) + + def is_trivial(self): + # type: () -> bool + """ + Return true if this constrain is statically decidable. + """ + return self.tv1 == self.tv2 or \ + (self.tv1.singleton_type() is not None and + self.tv2.singleton_type() is not None) + + def translate(self, m): + # type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual + """ + Translate any TypeVars in the constraint according to the map m + """ + if isinstance(m, TypeEnv): + return ConstrainTVsEqual(m[self.tv1], m[self.tv2]) + else: + return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m)) + + def __eq__(self, other): + # type: (object) -> bool + if (not isinstance(other, ConstrainTVsEqual)): + return False + + return (self.tv1, self.tv2) == (other.tv1, other.tv2) + + def __hash__(self): + # type: () -> int + return hash((self.tv1, self.tv2)) + + def eval(self): + # type: () -> bool + """ + Evaluate this constraint. Should only be called when the constraint has + been translated to concrete types. + """ + assert self.tv1.singleton_type() is not None and \ + self.tv2.singleton_type() is not None + return self.tv1.singleton_type() == self.tv2.singleton_type() + + +class ConstrainTVInTypeset(TypeConstraint): + """ + Constraint specifying that a type var must belong to some typeset. + """ + def __init__(self, tv, ts): + # type: (TypeVar, TypeSet) -> None + assert not tv.is_derived and tv.name.startswith("typeof_") + self.tv = tv + self.ts = ts + + def is_trivial(self): + # type: () -> bool + """ + Return true if this constrain is statically decidable. 
+ """ + tv_ts = self.tv.get_typeset().copy() + + # Trivially True + if (tv_ts.issubset(self.ts)): + return True + + # Trivially false + tv_ts &= self.ts + if (tv_ts.size() == 0): + return True + + return False + + def translate(self, m): + # type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset + """ + Translate any TypeVars in the constraint according to the map m + """ + if isinstance(m, TypeEnv): + return ConstrainTVInTypeset(m[self.tv], self.ts) + else: + return ConstrainTVInTypeset(subst(self.tv, m), self.ts) + + def __eq__(self, other): + # type: (object) -> bool + if (not isinstance(other, ConstrainTVInTypeset)): + return False + + return (self.tv, self.ts) == (other.tv, other.ts) + + def __hash__(self): + # type: () -> int + return hash((self.tv, self.ts)) + + def eval(self): + # type: () -> bool + """ + Evaluate this constraint. Should only be called when the constraint has + been translated to concrete types. + """ + assert self.tv.singleton_type() is not None + return self.tv.get_typeset().issubset(self.ts) + + class TypeEnv(object): """ Class encapsulating the neccessary book keeping for type inference. @@ -43,13 +158,13 @@ class TypeEnv(object): RANK_INTERNAL = 0 def __init__(self, arg=None): - # type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None + # type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None self.ranks = {} # type: Dict[TypeVar, int] self.vars = set() # type: Set[Var] if arg is None: self.type_map = {} # type: TypeMap - self.constraints = [] # type: ConstraintList + self.constraints = [] # type: List[TypeConstraint] else: self.type_map, self.constraints = arg @@ -94,7 +209,9 @@ class TypeEnv(object): """ Add a new equivalence constraint between tv1 and tv2 """ - self.constraints.append((tv1, tv2)) + constr = ConstrainTVsEqual(tv1, tv2) + if (constr not in self.constraints): + self.constraints.append(constr) def get_uid(self): # type: () -> str @@ -206,15 +323,24 @@ class TypeEnv(object): """ vars_tvs = set([v.get_typevar() for v in self.vars]) new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]} - new_constraints = [(self[tv1], self[tv2]) - for (tv1, tv2) in self.constraints] - # Sanity: new constraints and the new type_map should only contain - # tvs associated with real vars - for (a, b) in new_constraints: - assert a.free_typevar() in vars_tvs and\ - b.free_typevar() in vars_tvs + new_constraints = [] # type: List[TypeConstraint] + for constr in self.constraints: + # Currently typeinference only generates ConstrainTVsEqual + # constraints + assert isinstance(constr, ConstrainTVsEqual) + constr = constr.translate(self) + if constr.is_trivial() or constr in new_constraints: + continue + + # Sanity: translated constraints should refer to only real vars + assert constr.tv1.free_typevar() in vars_tvs and\ + constr.tv2.free_typevar() in vars_tvs + + new_constraints.append(constr) + + # Sanity: translated typemap should refer to only real vars for (k, v) in new_type_map.items(): assert k in vars_tvs assert v.free_typevar() is None or v.free_typevar() in vars_tvs @@ -245,13 +371,13 @@ class TypeEnv(object): # Check if constraints are satisfied for this typing failed = None - for (tv1, tv2) in self.constraints: - tv1 = subst(tv1, m) - tv2 = subst(tv2, m) - assert tv1.get_typeset().size() == 1 and\ - tv2.get_typeset().size() == 1 - if (tv1.get_typeset() != tv2.get_typeset()): - failed = (tv1, tv2) + for constr in self.constraints: + # Currently typeinference only generates ConstrainTVsEqual + # constraints + assert isinstance(constr, 
ConstrainTVsEqual) + concrete_constr = constr.translate(m) + if not concrete_constr.eval(): + failed = concrete_constr break if (failed is not None): @@ -287,9 +413,10 @@ class TypeEnv(object): edges.add((v, v.base, "solid", v.derived_func)) v = v.base - for (a, b) in self.constraints: - assert a in nodes and b in nodes - edges.add((a, b, "dashed", None)) + for constr in self.constraints: + assert isinstance(constr, ConstrainTVsEqual) + assert constr.tv1 in nodes and constr.tv2 in nodes + edges.add((constr.tv1, constr.tv2, "dashed", None)) root_nodes = set([x for x in nodes if x not in self.type_map and not x.is_derived]) diff --git a/lib/cretonne/meta/cdsl/xform.py b/lib/cretonne/meta/cdsl/xform.py index fd3c1dff1f..f809a91ce6 100644 --- a/lib/cretonne/meta/cdsl/xform.py +++ b/lib/cretonne/meta/cdsl/xform.py @@ -41,6 +41,14 @@ class Rtl(object): # type: (*DefApply) -> None self.rtl = tuple(map(canonicalize_defapply, args)) + def copy(self, m): + # type: (Dict[Var, Var]) -> Rtl + """ + Return a copy of this rtl with all Vars substituted with copies or + according to m. Update m as neccessary. + """ + return Rtl(*[d.copy(m) for d in self.rtl]) + class XForm(object): """ diff --git a/lib/cretonne/meta/gen_instr.py b/lib/cretonne/meta/gen_instr.py index 2d67311ae2..242e5152d3 100644 --- a/lib/cretonne/meta/gen_instr.py +++ b/lib/cretonne/meta/gen_instr.py @@ -336,6 +336,25 @@ def get_constraint(op, ctrl_typevar, type_sets): return 'Same' +# TypeSet indexes are encoded in 8 bits, with `0xff` reserved. +typeset_limit = 0xff + + +def gen_typesets_table(fmt, type_sets): + # type: (srcgen.Formatter, UniqueTable) -> None + """ + Generate the table of ValueTypeSets described by type_sets. + """ + fmt.comment('Table of value type sets.') + assert len(type_sets.table) <= typeset_limit, "Too many type sets" + with fmt.indented( + 'const TYPE_SETS : [ValueTypeSet; {}] = [' + .format(len(type_sets.table)), '];'): + for ts in type_sets.table: + with fmt.indented('ValueTypeSet {', '},'): + ts.emit_fields(fmt) + + def gen_type_constraints(fmt, instrs): # type: (srcgen.Formatter, Sequence[Instruction]) -> None """ @@ -360,9 +379,6 @@ def gen_type_constraints(fmt, instrs): # Preload table with constraints for typical binops. operand_seqs.add(['Same'] * 3) - # TypeSet indexes are encoded in 8 bits, with `0xff` reserved. 
- typeset_limit = 0xff - fmt.comment('Table of opcode constraints.') with fmt.indented( 'const OPCODE_CONSTRAINTS : [OpcodeConstraints; {}] = [' @@ -418,14 +434,7 @@ def gen_type_constraints(fmt, instrs): fmt.line('typeset_offset: {},'.format(ctrl_typeset)) fmt.line('constraint_offset: {},'.format(offset)) - fmt.comment('Table of value type sets.') - assert len(type_sets.table) <= typeset_limit, "Too many type sets" - with fmt.indented( - 'const TYPE_SETS : [ValueTypeSet; {}] = [' - .format(len(type_sets.table)), '];'): - for ts in type_sets.table: - with fmt.indented('ValueTypeSet {', '},'): - ts.emit_fields(fmt) + gen_typesets_table(fmt, type_sets) fmt.comment('Table of operand constraint sequences.') with fmt.indented( diff --git a/lib/cretonne/meta/gen_legalizer.py b/lib/cretonne/meta/gen_legalizer.py index 1376e3bc21..ce2f559ef2 100644 --- a/lib/cretonne/meta/gen_legalizer.py +++ b/lib/cretonne/meta/gen_legalizer.py @@ -11,16 +11,116 @@ from __future__ import absolute_import from srcgen import Formatter from base import legalize, instructions from cdsl.ast import Var +from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\ + ConstrainTVInTypeset +from unique_table import UniqueTable +from gen_instr import gen_typesets_table +from cdsl.typevar import TypeVar try: - from typing import Sequence # noqa + from typing import Sequence, List, Dict # noqa from cdsl.isa import TargetISA # noqa from cdsl.ast import Def # noqa from cdsl.xform import XForm, XFormGroup # noqa + from cdsl.typevar import TypeSet # noqa + from cdsl.ti import TypeConstraint # noqa except ImportError: pass +def get_runtime_typechecks(xform): + # type: (XForm) -> List[TypeConstraint] + """ + Given a XForm build a list of runtime type checks neccessary to determine + if it applies. We have 2 types of runtime checks: + 1) typevar tv belongs to typeset T - needed for free tvs whose + typeset is constrainted by their use in the dst pattern + + 2) tv1 == tv2 where tv1 and tv2 are derived TVs - caused by unification + of non-bijective functions + """ + check_l = [] # type: List[TypeConstraint] + + # 1) Perform ti only on the source RTL. Accumulate any free tvs that have a + # different inferred type in src, compared to the type inferred for both + # src and dst. + symtab = {} # type: Dict[Var, Var] + src_copy = xform.src.copy(symtab) + src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv())) + + for v in xform.ti.vars: + if not v.has_free_typevar(): + continue + + # In rust the local variable containing a free TV associated with var v + # has name typeof_v. We rely on the python TVs having the same name. + assert "typeof_{}".format(v) == xform.ti[v].name + + if v not in symtab: + # We can have singleton vars defined only on dst. Ignore them + assert v.get_typevar().singleton_type() is not None + continue + + src_ts = src_typenv[symtab[v]].get_typeset() + xform_ts = xform.ti[v].get_typeset() + + assert xform_ts.issubset(src_ts) + if src_ts != xform_ts: + check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts)) + + # 2,3) Add any constraints that appear in xform.ti + check_l.extend(xform.ti.constraints) + + return check_l + + +def emit_runtime_typecheck(check, fmt, type_sets): + # type: (TypeConstraint, Formatter, UniqueTable) -> None + """ + Emit rust code for the given check. 
+ """ + def build_derived_expr(tv): + # type: (TypeVar) -> str + if not tv.is_derived: + assert tv.name.startswith('typeof_') + return "Some({})".format(tv.name) + + base_exp = build_derived_expr(tv.base) + if (tv.derived_func == TypeVar.LANEOF): + return "{}.map(|t: Type| -> t.lane_type())".format(base_exp) + elif (tv.derived_func == TypeVar.ASBOOL): + return "{}.map(|t: Type| -> t.as_bool())".format(base_exp) + elif (tv.derived_func == TypeVar.HALFWIDTH): + return "{}.and_then(|t: Type| -> t.half_width())".format(base_exp) + elif (tv.derived_func == TypeVar.DOUBLEWIDTH): + return "{}.and_then(|t: Type| -> t.double_width())"\ + .format(base_exp) + elif (tv.derived_func == TypeVar.HALFVECTOR): + return "{}.and_then(|t: Type| -> t.half_vector())".format(base_exp) + elif (tv.derived_func == TypeVar.DOUBLEVECTOR): + return "{}.and_then(|t: Type| -> t.by(2))".format(base_exp) + else: + assert False, "Unknown derived function {}".format(tv.derived_func) + + if (isinstance(check, ConstrainTVInTypeset)): + tv = check.tv.name + if check.ts not in type_sets.index: + type_sets.add(check.ts) + ts = type_sets.index[check.ts] + + fmt.comment("{} must belong to {}".format(tv, check.ts)) + with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv), + '};'): + fmt.line('return false;') + elif (isinstance(check, ConstrainTVsEqual)): + tv1 = build_derived_expr(check.tv1) + tv2 = build_derived_expr(check.tv2) + with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'): + fmt.line('return false;') + else: + assert False, "Unknown check {}".format(check) + + def unwrap_inst(iref, node, fmt): # type: (str, Def, Formatter) -> bool """ @@ -183,8 +283,8 @@ def emit_dst_inst(node, fmt): fmt.line('pos.next_inst();') -def gen_xform(xform, fmt): - # type: (XForm, Formatter) -> None +def gen_xform(xform, fmt, type_sets): + # type: (XForm, Formatter, UniqueTable) -> None """ Emit code for `xform`, assuming the the opcode of xform's root instruction has already been matched. @@ -203,6 +303,10 @@ def gen_xform(xform, fmt): instp = xform.src.rtl[0].expr.inst_predicate() assert instp is None, "Instruction predicates not supported in legalizer" + # Emit any runtime checks. + for check in get_runtime_typechecks(xform): + emit_runtime_typecheck(check, fmt, type_sets) + # Emit the destination pattern. for dst in xform.dst.rtl: emit_dst_inst(dst, fmt) @@ -213,8 +317,8 @@ def gen_xform(xform, fmt): fmt.line('assert_eq!(pos.remove_inst(), inst);') -def gen_xform_group(xgrp, fmt): - # type: (XFormGroup, Formatter) -> None +def gen_xform_group(xgrp, fmt, type_sets): + # type: (XFormGroup, Formatter, UniqueTable) -> None fmt.doc_comment("Legalize the instruction pointed to by `pos`.") fmt.line('#[allow(unused_variables,unused_assignments)]') with fmt.indented( @@ -231,7 +335,7 @@ def gen_xform_group(xgrp, fmt): inst = xform.src.rtl[0].expr.inst with fmt.indented( 'Opcode::{} => {{'.format(inst.camel_name), '}'): - gen_xform(xform, fmt) + gen_xform(xform, fmt, type_sets) # We'll assume there are uncovered opcodes. 
fmt.line('_ => return false,') fmt.line('true') @@ -240,6 +344,11 @@ def gen_xform_group(xgrp, fmt): def generate(isas, out_dir): # type: (Sequence[TargetISA], str) -> None fmt = Formatter() - gen_xform_group(legalize.narrow, fmt) - gen_xform_group(legalize.expand, fmt) + # Table of TypeSet instances + type_sets = UniqueTable() + + gen_xform_group(legalize.narrow, fmt, type_sets) + gen_xform_group(legalize.expand, fmt, type_sets) + + gen_typesets_table(fmt, type_sets) fmt.update_file('legalizer.rs', out_dir) diff --git a/lib/cretonne/meta/test_gen_legalizer.py b/lib/cretonne/meta/test_gen_legalizer.py new file mode 100644 index 0000000000..38f26959f4 --- /dev/null +++ b/lib/cretonne/meta/test_gen_legalizer.py @@ -0,0 +1,145 @@ +import doctest +import gen_legalizer +from unittest import TestCase +from srcgen import Formatter +from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck +from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \ + iconst, b1, icmp, copy # noqa +from base.legalize import narrow, expand # noqa +from base.immediates import intcc # noqa +from cdsl.typevar import TypeVar, TypeSet +from cdsl.ast import Var, Def # noqa +from cdsl.xform import Rtl, XForm # noqa +from cdsl.ti import ti_rtl, subst, TypeEnv, get_type_env # noqa +from unique_table import UniqueTable +from functools import reduce + +try: + from typing import Callable, TYPE_CHECKING, Iterable, Any # noqa + if TYPE_CHECKING: + CheckProducer = Callable[[UniqueTable], str] +except ImportError: + TYPE_CHECKING = False + + +def load_tests(loader, tests, ignore): + # type: (Any, Any, Any) -> Any + tests.addTests(doctest.DocTestSuite(gen_legalizer)) + return tests + + +def format_check(typesets, s, *args): + # type: (...) -> str + def transform(x): + # type: (Any) -> str + if isinstance(x, TypeSet): + return str(typesets.index[x]) + elif isinstance(x, TypeVar): + assert not x.is_derived + return x.name + else: + return str(x) + + dummy_s = s # type: str + args = tuple(map(lambda x: transform(x), args)) + return dummy_s.format(*args) + + +def typeset_check(v, ts): + # type: (Var, TypeSet) -> CheckProducer + return lambda typesets: format_check( + typesets, + 'if !TYPE_SETS[{}].contains(typeof_{}) ' + + '{{\n return false;\n}};\n', ts, v) + + +def equiv_check(tv1, tv2): + # type: (TypeVar, TypeVar) -> CheckProducer + return lambda typesets: format_check( + typesets, + 'if Some({}).map(|t: Type| -> t.as_bool()) != ' + + 'Some({}).map(|t: Type| -> t.as_bool()) ' + + '{{\n return false;\n}};\n', tv1, tv2) + + +def sequence(*args): + # type: (...) 
-> CheckProducer + dummy = args # type: Iterable[CheckProducer] + + def sequenceF(typesets): + # type: (UniqueTable) -> str + def strconcat(acc, el): + # type: (str, CheckProducer) -> str + return acc + el(typesets) + + return reduce(strconcat, dummy, "") + return sequenceF + + +class TestRuntimeChecks(TestCase): + + def setUp(self): + # type: () -> None + self.v0 = Var("v0") + self.v1 = Var("v1") + self.v2 = Var("v2") + self.v3 = Var("v3") + self.v4 = Var("v4") + self.v5 = Var("v5") + self.v6 = Var("v6") + self.v7 = Var("v7") + self.v8 = Var("v8") + self.v9 = Var("v9") + self.imm0 = Var("imm0") + self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True, + scalars=False, simd=True) + self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True, + scalars=False, simd=True) + self.b1 = TypeVar.singleton(b1) + + def check_yo_check(self, xform, expected_f): + # type: (XForm, CheckProducer) -> None + fmt = Formatter() + type_sets = UniqueTable() + for check in get_runtime_typechecks(xform): + emit_runtime_typecheck(check, fmt, type_sets) + + # Remove comments + got = "".join([l for l in fmt.lines if not l.strip().startswith("//")]) + expected = expected_f(type_sets) + self.assertEqual(got, expected) + + def test_width_check(self): + # type: () -> None + x = XForm(Rtl(self.v0 << copy(self.v1)), + Rtl((self.v2, self.v3) << isplit(self.v1), + self.v0 << iconcat(self.v2, self.v3))) + + WideInt = TypeSet(lanes=(1, 256), ints=(16, 64)) + self.check_yo_check(x, typeset_check(self.v1, WideInt)) + + def test_lanes_check(self): + # type: () -> None + x = XForm(Rtl(self.v0 << copy(self.v1)), + Rtl((self.v2, self.v3) << vsplit(self.v1), + self.v0 << vconcat(self.v2, self.v3))) + + WideVec = TypeSet(lanes=(2, 256), ints=(8, 64), floats=(32, 64), + bools=(1, 64)) + self.check_yo_check(x, typeset_check(self.v1, WideVec)) + + def test_vselect_imm(self): + # type: () -> None + ts = TypeSet(lanes=(2, 256), ints=(8, 64), + floats=(32, 64), bools=(8, 64)) + r = Rtl( + self.v0 << iconst(self.imm0), + self.v1 << icmp(intcc.eq, self.v2, self.v0), + self.v5 << vselect(self.v1, self.v3, self.v4), + ) + x = XForm(r, r) + + self.check_yo_check( + x, sequence(typeset_check(self.v3, ts), + equiv_check(self.v2.get_typevar(), + self.v3.get_typevar()))) diff --git a/lib/cretonne/src/ir/instructions.rs b/lib/cretonne/src/ir/instructions.rs index 35a200de08..cf914b853a 100644 --- a/lib/cretonne/src/ir/instructions.rs +++ b/lib/cretonne/src/ir/instructions.rs @@ -506,10 +506,14 @@ type BitSet16 = BitSet; /// A value type set describes the permitted set of types for a type variable. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct ValueTypeSet { - lanes: BitSet16, - ints: BitSet8, - floats: BitSet8, - bools: BitSet8, + /// Allowed lane sizes + pub lanes: BitSet16, + /// Allowed int widths + pub ints: BitSet8, + /// Allowed float widths + pub floats: BitSet8, + /// Allowed bool widths + pub bools: BitSet8, } impl ValueTypeSet { diff --git a/lib/cretonne/src/legalizer/mod.rs b/lib/cretonne/src/legalizer/mod.rs index 803f5808c2..f21f7e1d91 100644 --- a/lib/cretonne/src/legalizer/mod.rs +++ b/lib/cretonne/src/legalizer/mod.rs @@ -18,6 +18,8 @@ use flowgraph::ControlFlowGraph; use ir::{Function, Cursor, DataFlowGraph, InstructionData, Opcode, InstBuilder}; use ir::condcodes::IntCC; use isa::{TargetIsa, Legalize}; +use bitset::BitSet; +use ir::instructions::ValueTypeSet; mod boundary; mod split;
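
Note on the new copy machinery in cdsl/ast.py and cdsl/xform.py: Rtl.copy(m) rebuilds every Def and Apply with fresh Vars via replace_var(), recording the old-to-new mapping in m so repeated occurrences of the same variable stay consistent across the copied pattern. A minimal sketch of the intended use (the instruction iadd and the variable names here are only illustrative, not part of the patch):

    from cdsl.ast import Var
    from cdsl.xform import Rtl
    from base.instructions import iadd

    x, y, a = Var('x'), Var('y'), Var('a')
    r = Rtl(a << iadd(x, y))

    m = {}                 # filled in with original Var -> fresh Var
    r_copy = r.copy(m)     # same shape, but every Var is a new object
    assert m[a] is not a and m[a].name == a.name

get_runtime_typechecks() relies on exactly this: it runs type inference on a fresh copy of xform.src so the inferred typesets can be compared against xform.ti without disturbing the XForm's own type environment.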
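
The generator side composes the same way the check_yo_check helper in test_gen_legalizer.py does: get_runtime_typechecks() gathers the constraints an XForm needs at legalization time, and emit_runtime_typecheck() turns each one into a Rust guard. A minimal sketch, assuming the hypothetical wrapper name emit_checks_for:

    from srcgen import Formatter
    from unique_table import UniqueTable
    from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
    from cdsl.xform import XForm  # noqa

    def emit_checks_for(xform):
        # type: (XForm) -> str
        """Return the Rust source for xform's runtime type checks."""
        fmt = Formatter()
        # Table of ValueTypeSet instances; gen_typesets_table() later emits it
        # as the TYPE_SETS constant that the generated guards index into.
        type_sets = UniqueTable()
        for check in get_runtime_typechecks(xform):
            emit_runtime_typecheck(check, fmt, type_sets)
        return "".join(fmt.lines)

Each ConstrainTVInTypeset check becomes a TYPE_SETS[i].contains(typeof_v) test, and each ConstrainTVsEqual check compares the two Option<Type> expressions built by build_derived_expr(), so a failing guard makes the generated legalization function return false before any destination instructions are emitted.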