Emit runtime type checks in legalizer.rs (#112)

* Emit runtime type checks in legalizer.rs
This commit is contained in:
d1m0
2017-07-10 15:28:32 -07:00
committed by Jakob Stoklund Olesen
parent 464f2625d4
commit 98f822f347
9 changed files with 494 additions and 69 deletions

View File

@@ -10,7 +10,7 @@ from .typevar import TypeVar
from .predicates import IsEqual, And from .predicates import IsEqual, And
try: try:
from typing import Union, Tuple, Sequence, TYPE_CHECKING # noqa from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
from .operands import ImmediateKind # noqa from .operands import ImmediateKind # noqa
from .predicates import PredNode # noqa from .predicates import PredNode # noqa
@@ -18,6 +18,19 @@ except ImportError:
pass pass
def replace_var(arg, m):
# type: (Expr, Dict[Var, Var]) -> Expr
"""
Given a var v return either m[v] or a new variable v' (and remember
m[v]=v'). Otherwise return the argument unchanged
"""
if isinstance(arg, Var):
new_arg = m.get(arg, Var(arg.name)) # type: Var
m[arg] = new_arg
return new_arg
return arg
class Def(object): class Def(object):
""" """
An AST definition associates a set of variables with the values produced by An AST definition associates a set of variables with the values produced by
@@ -60,6 +73,21 @@ class Def(object):
return "({}) << {!s}".format( return "({}) << {!s}".format(
', '.join(map(str, self.defs)), self.expr) ', '.join(map(str, self.defs)), self.expr)
def copy(self, m):
# type: (Dict[Var, Var]) -> Def
"""
Return a copy of this Def with vars replaced with fresh variables,
in accordance with the map m. Update m as neccessary.
"""
new_expr = self.expr.copy(m)
new_defs = [] # type: List[Var]
for v in self.defs:
new_v = replace_var(v, m)
assert(isinstance(new_v, Var))
new_defs.append(new_v)
return Def(tuple(new_defs), new_expr)
class Expr(object): class Expr(object):
""" """
@@ -303,6 +331,15 @@ class Apply(Expr):
return pred return pred
def copy(self, m):
# type: (Dict[Var, Var]) -> Apply
"""
Return a copy of this Expr with vars replaced with fresh variables,
in accordance with the map m. Update m as neccessary.
"""
return Apply(self.inst, tuple(map(lambda e: replace_var(e, m),
self.args)))
class Enumerator(Expr): class Enumerator(Expr):
""" """

View File

@@ -6,30 +6,17 @@ from base.immediates import intcc
from .typevar import TypeVar from .typevar import TypeVar
from .ast import Var, Def from .ast import Var, Def
from .xform import Rtl, XForm from .xform import Rtl, XForm
from .ti import ti_rtl, subst, TypeEnv, get_type_env from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
from unittest import TestCase from unittest import TestCase
from functools import reduce from functools import reduce
try: try:
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
from .ti import Constraint
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
except ImportError: except ImportError:
TYPE_CHECKING = False TYPE_CHECKING = False
def sort_constr(c):
# type: (Constraint) -> Constraint
"""
Sort the 2 typevars in a constraint by name for comparison
"""
r = tuple(sorted(c, key=lambda y: y.name))
if TYPE_CHECKING:
return cast(Constraint, r)
else:
return r
def agree(me, other): def agree(me, other):
# type: (TypeEnv, TypeEnv) -> bool # type: (TypeEnv, TypeEnv) -> bool
""" """
@@ -63,13 +50,10 @@ def agree(me, other):
return False return False
# Translate our constraints using m, and sort # Translate our constraints using m, and sort
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints] me_equiv_constr = sorted([constr.translate(m)
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr]) for constr in me.constraints])
# Sort other's constraints # Sort other's constraints
other_equiv_constr = sorted([sort_constr(x) for x in other.constraints], other_equiv_constr = sorted(other.constraints)
key=lambda y: y[0].name)
return me_equiv_constr == other_equiv_constr return me_equiv_constr == other_equiv_constr
@@ -224,7 +208,7 @@ class TestRTL(TypeCheckingBaseTest):
self.v3: txn, self.v3: txn,
self.v4: txn, self.v4: txn,
self.v5: txn, self.v5: txn,
}, [(ixn.as_bool(), txn.as_bool())])) }, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
def test_vselect_vsplits(self): def test_vselect_vsplits(self):
# type: () -> None # type: () -> None

View File

@@ -8,13 +8,12 @@ from itertools import product
try: try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable # noqa from typing import Iterable, List # noqa
from typing import cast, List from typing import cast
from .xform import Rtl, XForm # noqa from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa from .ast import Expr # noqa
from .typevar import TypeSet # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
Constraint = Tuple[TypeVar, TypeVar]
ConstraintList = List[Constraint]
TypeMap = Dict[TypeVar, TypeVar] TypeMap = Dict[TypeVar, TypeVar]
VarMap = Dict[Var, TypeVar] VarMap = Dict[Var, TypeVar]
except ImportError: except ImportError:
@@ -22,6 +21,122 @@ except ImportError:
pass pass
class TypeConstraint(object):
"""
Base class for all runtime-emittable type constraints.
"""
class ConstrainTVsEqual(TypeConstraint):
"""
Constraint specifying that two derived type vars must have the same runtime
type.
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
return self.tv1 == self.tv2 or \
(self.tv1.singleton_type() is not None and
self.tv2.singleton_type() is not None)
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
else:
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVsEqual)):
return False
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
def __hash__(self):
# type: () -> int
return hash((self.tv1, self.tv2))
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv1.singleton_type() is not None and \
self.tv2.singleton_type() is not None
return self.tv1.singleton_type() == self.tv2.singleton_type()
class ConstrainTVInTypeset(TypeConstraint):
"""
Constraint specifying that a type var must belong to some typeset.
"""
def __init__(self, tv, ts):
# type: (TypeVar, TypeSet) -> None
assert not tv.is_derived and tv.name.startswith("typeof_")
self.tv = tv
self.ts = ts
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
tv_ts = self.tv.get_typeset().copy()
# Trivially True
if (tv_ts.issubset(self.ts)):
return True
# Trivially false
tv_ts &= self.ts
if (tv_ts.size() == 0):
return True
return False
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVInTypeset(m[self.tv], self.ts)
else:
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVInTypeset)):
return False
return (self.tv, self.ts) == (other.tv, other.ts)
def __hash__(self):
# type: () -> int
return hash((self.tv, self.ts))
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv.singleton_type() is not None
return self.tv.get_typeset().issubset(self.ts)
class TypeEnv(object): class TypeEnv(object):
""" """
Class encapsulating the neccessary book keeping for type inference. Class encapsulating the neccessary book keeping for type inference.
@@ -43,13 +158,13 @@ class TypeEnv(object):
RANK_INTERNAL = 0 RANK_INTERNAL = 0
def __init__(self, arg=None): def __init__(self, arg=None):
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None # type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
self.ranks = {} # type: Dict[TypeVar, int] self.ranks = {} # type: Dict[TypeVar, int]
self.vars = set() # type: Set[Var] self.vars = set() # type: Set[Var]
if arg is None: if arg is None:
self.type_map = {} # type: TypeMap self.type_map = {} # type: TypeMap
self.constraints = [] # type: ConstraintList self.constraints = [] # type: List[TypeConstraint]
else: else:
self.type_map, self.constraints = arg self.type_map, self.constraints = arg
@@ -94,7 +209,9 @@ class TypeEnv(object):
""" """
Add a new equivalence constraint between tv1 and tv2 Add a new equivalence constraint between tv1 and tv2
""" """
self.constraints.append((tv1, tv2)) constr = ConstrainTVsEqual(tv1, tv2)
if (constr not in self.constraints):
self.constraints.append(constr)
def get_uid(self): def get_uid(self):
# type: () -> str # type: () -> str
@@ -206,15 +323,24 @@ class TypeEnv(object):
""" """
vars_tvs = set([v.get_typevar() for v in self.vars]) vars_tvs = set([v.get_typevar() for v in self.vars])
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]} new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
new_constraints = [(self[tv1], self[tv2])
for (tv1, tv2) in self.constraints]
# Sanity: new constraints and the new type_map should only contain new_constraints = [] # type: List[TypeConstraint]
# tvs associated with real vars for constr in self.constraints:
for (a, b) in new_constraints: # Currently typeinference only generates ConstrainTVsEqual
assert a.free_typevar() in vars_tvs and\ # constraints
b.free_typevar() in vars_tvs assert isinstance(constr, ConstrainTVsEqual)
constr = constr.translate(self)
if constr.is_trivial() or constr in new_constraints:
continue
# Sanity: translated constraints should refer to only real vars
assert constr.tv1.free_typevar() in vars_tvs and\
constr.tv2.free_typevar() in vars_tvs
new_constraints.append(constr)
# Sanity: translated typemap should refer to only real vars
for (k, v) in new_type_map.items(): for (k, v) in new_type_map.items():
assert k in vars_tvs assert k in vars_tvs
assert v.free_typevar() is None or v.free_typevar() in vars_tvs assert v.free_typevar() is None or v.free_typevar() in vars_tvs
@@ -245,13 +371,13 @@ class TypeEnv(object):
# Check if constraints are satisfied for this typing # Check if constraints are satisfied for this typing
failed = None failed = None
for (tv1, tv2) in self.constraints: for constr in self.constraints:
tv1 = subst(tv1, m) # Currently typeinference only generates ConstrainTVsEqual
tv2 = subst(tv2, m) # constraints
assert tv1.get_typeset().size() == 1 and\ assert isinstance(constr, ConstrainTVsEqual)
tv2.get_typeset().size() == 1 concrete_constr = constr.translate(m)
if (tv1.get_typeset() != tv2.get_typeset()): if not concrete_constr.eval():
failed = (tv1, tv2) failed = concrete_constr
break break
if (failed is not None): if (failed is not None):
@@ -287,9 +413,10 @@ class TypeEnv(object):
edges.add((v, v.base, "solid", v.derived_func)) edges.add((v, v.base, "solid", v.derived_func))
v = v.base v = v.base
for (a, b) in self.constraints: for constr in self.constraints:
assert a in nodes and b in nodes assert isinstance(constr, ConstrainTVsEqual)
edges.add((a, b, "dashed", None)) assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", None))
root_nodes = set([x for x in nodes root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived]) if x not in self.type_map and not x.is_derived])

View File

@@ -41,6 +41,14 @@ class Rtl(object):
# type: (*DefApply) -> None # type: (*DefApply) -> None
self.rtl = tuple(map(canonicalize_defapply, args)) self.rtl = tuple(map(canonicalize_defapply, args))
def copy(self, m):
# type: (Dict[Var, Var]) -> Rtl
"""
Return a copy of this rtl with all Vars substituted with copies or
according to m. Update m as neccessary.
"""
return Rtl(*[d.copy(m) for d in self.rtl])
class XForm(object): class XForm(object):
""" """

View File

@@ -336,6 +336,25 @@ def get_constraint(op, ctrl_typevar, type_sets):
return 'Same' return 'Same'
# TypeSet indexes are encoded in 8 bits, with `0xff` reserved.
typeset_limit = 0xff
def gen_typesets_table(fmt, type_sets):
# type: (srcgen.Formatter, UniqueTable) -> None
"""
Generate the table of ValueTypeSets described by type_sets.
"""
fmt.comment('Table of value type sets.')
assert len(type_sets.table) <= typeset_limit, "Too many type sets"
with fmt.indented(
'const TYPE_SETS : [ValueTypeSet; {}] = ['
.format(len(type_sets.table)), '];'):
for ts in type_sets.table:
with fmt.indented('ValueTypeSet {', '},'):
ts.emit_fields(fmt)
def gen_type_constraints(fmt, instrs): def gen_type_constraints(fmt, instrs):
# type: (srcgen.Formatter, Sequence[Instruction]) -> None # type: (srcgen.Formatter, Sequence[Instruction]) -> None
""" """
@@ -360,9 +379,6 @@ def gen_type_constraints(fmt, instrs):
# Preload table with constraints for typical binops. # Preload table with constraints for typical binops.
operand_seqs.add(['Same'] * 3) operand_seqs.add(['Same'] * 3)
# TypeSet indexes are encoded in 8 bits, with `0xff` reserved.
typeset_limit = 0xff
fmt.comment('Table of opcode constraints.') fmt.comment('Table of opcode constraints.')
with fmt.indented( with fmt.indented(
'const OPCODE_CONSTRAINTS : [OpcodeConstraints; {}] = [' 'const OPCODE_CONSTRAINTS : [OpcodeConstraints; {}] = ['
@@ -418,14 +434,7 @@ def gen_type_constraints(fmt, instrs):
fmt.line('typeset_offset: {},'.format(ctrl_typeset)) fmt.line('typeset_offset: {},'.format(ctrl_typeset))
fmt.line('constraint_offset: {},'.format(offset)) fmt.line('constraint_offset: {},'.format(offset))
fmt.comment('Table of value type sets.') gen_typesets_table(fmt, type_sets)
assert len(type_sets.table) <= typeset_limit, "Too many type sets"
with fmt.indented(
'const TYPE_SETS : [ValueTypeSet; {}] = ['
.format(len(type_sets.table)), '];'):
for ts in type_sets.table:
with fmt.indented('ValueTypeSet {', '},'):
ts.emit_fields(fmt)
fmt.comment('Table of operand constraint sequences.') fmt.comment('Table of operand constraint sequences.')
with fmt.indented( with fmt.indented(

View File

@@ -11,16 +11,116 @@ from __future__ import absolute_import
from srcgen import Formatter from srcgen import Formatter
from base import legalize, instructions from base import legalize, instructions
from cdsl.ast import Var from cdsl.ast import Var
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
ConstrainTVInTypeset
from unique_table import UniqueTable
from gen_instr import gen_typesets_table
from cdsl.typevar import TypeVar
try: try:
from typing import Sequence # noqa from typing import Sequence, List, Dict # noqa
from cdsl.isa import TargetISA # noqa from cdsl.isa import TargetISA # noqa
from cdsl.ast import Def # noqa from cdsl.ast import Def # noqa
from cdsl.xform import XForm, XFormGroup # noqa from cdsl.xform import XForm, XFormGroup # noqa
from cdsl.typevar import TypeSet # noqa
from cdsl.ti import TypeConstraint # noqa
except ImportError: except ImportError:
pass pass
def get_runtime_typechecks(xform):
# type: (XForm) -> List[TypeConstraint]
"""
Given a XForm build a list of runtime type checks neccessary to determine
if it applies. We have 2 types of runtime checks:
1) typevar tv belongs to typeset T - needed for free tvs whose
typeset is constrainted by their use in the dst pattern
2) tv1 == tv2 where tv1 and tv2 are derived TVs - caused by unification
of non-bijective functions
"""
check_l = [] # type: List[TypeConstraint]
# 1) Perform ti only on the source RTL. Accumulate any free tvs that have a
# different inferred type in src, compared to the type inferred for both
# src and dst.
symtab = {} # type: Dict[Var, Var]
src_copy = xform.src.copy(symtab)
src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv()))
for v in xform.ti.vars:
if not v.has_free_typevar():
continue
# In rust the local variable containing a free TV associated with var v
# has name typeof_v. We rely on the python TVs having the same name.
assert "typeof_{}".format(v) == xform.ti[v].name
if v not in symtab:
# We can have singleton vars defined only on dst. Ignore them
assert v.get_typevar().singleton_type() is not None
continue
src_ts = src_typenv[symtab[v]].get_typeset()
xform_ts = xform.ti[v].get_typeset()
assert xform_ts.issubset(src_ts)
if src_ts != xform_ts:
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
# 2,3) Add any constraints that appear in xform.ti
check_l.extend(xform.ti.constraints)
return check_l
def emit_runtime_typecheck(check, fmt, type_sets):
# type: (TypeConstraint, Formatter, UniqueTable) -> None
"""
Emit rust code for the given check.
"""
def build_derived_expr(tv):
# type: (TypeVar) -> str
if not tv.is_derived:
assert tv.name.startswith('typeof_')
return "Some({})".format(tv.name)
base_exp = build_derived_expr(tv.base)
if (tv.derived_func == TypeVar.LANEOF):
return "{}.map(|t: Type| -> t.lane_type())".format(base_exp)
elif (tv.derived_func == TypeVar.ASBOOL):
return "{}.map(|t: Type| -> t.as_bool())".format(base_exp)
elif (tv.derived_func == TypeVar.HALFWIDTH):
return "{}.and_then(|t: Type| -> t.half_width())".format(base_exp)
elif (tv.derived_func == TypeVar.DOUBLEWIDTH):
return "{}.and_then(|t: Type| -> t.double_width())"\
.format(base_exp)
elif (tv.derived_func == TypeVar.HALFVECTOR):
return "{}.and_then(|t: Type| -> t.half_vector())".format(base_exp)
elif (tv.derived_func == TypeVar.DOUBLEVECTOR):
return "{}.and_then(|t: Type| -> t.by(2))".format(base_exp)
else:
assert False, "Unknown derived function {}".format(tv.derived_func)
if (isinstance(check, ConstrainTVInTypeset)):
tv = check.tv.name
if check.ts not in type_sets.index:
type_sets.add(check.ts)
ts = type_sets.index[check.ts]
fmt.comment("{} must belong to {}".format(tv, check.ts))
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
'};'):
fmt.line('return false;')
elif (isinstance(check, ConstrainTVsEqual)):
tv1 = build_derived_expr(check.tv1)
tv2 = build_derived_expr(check.tv2)
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
fmt.line('return false;')
else:
assert False, "Unknown check {}".format(check)
def unwrap_inst(iref, node, fmt): def unwrap_inst(iref, node, fmt):
# type: (str, Def, Formatter) -> bool # type: (str, Def, Formatter) -> bool
""" """
@@ -183,8 +283,8 @@ def emit_dst_inst(node, fmt):
fmt.line('pos.next_inst();') fmt.line('pos.next_inst();')
def gen_xform(xform, fmt): def gen_xform(xform, fmt, type_sets):
# type: (XForm, Formatter) -> None # type: (XForm, Formatter, UniqueTable) -> None
""" """
Emit code for `xform`, assuming the the opcode of xform's root instruction Emit code for `xform`, assuming the the opcode of xform's root instruction
has already been matched. has already been matched.
@@ -203,6 +303,10 @@ def gen_xform(xform, fmt):
instp = xform.src.rtl[0].expr.inst_predicate() instp = xform.src.rtl[0].expr.inst_predicate()
assert instp is None, "Instruction predicates not supported in legalizer" assert instp is None, "Instruction predicates not supported in legalizer"
# Emit any runtime checks.
for check in get_runtime_typechecks(xform):
emit_runtime_typecheck(check, fmt, type_sets)
# Emit the destination pattern. # Emit the destination pattern.
for dst in xform.dst.rtl: for dst in xform.dst.rtl:
emit_dst_inst(dst, fmt) emit_dst_inst(dst, fmt)
@@ -213,8 +317,8 @@ def gen_xform(xform, fmt):
fmt.line('assert_eq!(pos.remove_inst(), inst);') fmt.line('assert_eq!(pos.remove_inst(), inst);')
def gen_xform_group(xgrp, fmt): def gen_xform_group(xgrp, fmt, type_sets):
# type: (XFormGroup, Formatter) -> None # type: (XFormGroup, Formatter, UniqueTable) -> None
fmt.doc_comment("Legalize the instruction pointed to by `pos`.") fmt.doc_comment("Legalize the instruction pointed to by `pos`.")
fmt.line('#[allow(unused_variables,unused_assignments)]') fmt.line('#[allow(unused_variables,unused_assignments)]')
with fmt.indented( with fmt.indented(
@@ -231,7 +335,7 @@ def gen_xform_group(xgrp, fmt):
inst = xform.src.rtl[0].expr.inst inst = xform.src.rtl[0].expr.inst
with fmt.indented( with fmt.indented(
'Opcode::{} => {{'.format(inst.camel_name), '}'): 'Opcode::{} => {{'.format(inst.camel_name), '}'):
gen_xform(xform, fmt) gen_xform(xform, fmt, type_sets)
# We'll assume there are uncovered opcodes. # We'll assume there are uncovered opcodes.
fmt.line('_ => return false,') fmt.line('_ => return false,')
fmt.line('true') fmt.line('true')
@@ -240,6 +344,11 @@ def gen_xform_group(xgrp, fmt):
def generate(isas, out_dir): def generate(isas, out_dir):
# type: (Sequence[TargetISA], str) -> None # type: (Sequence[TargetISA], str) -> None
fmt = Formatter() fmt = Formatter()
gen_xform_group(legalize.narrow, fmt) # Table of TypeSet instances
gen_xform_group(legalize.expand, fmt) type_sets = UniqueTable()
gen_xform_group(legalize.narrow, fmt, type_sets)
gen_xform_group(legalize.expand, fmt, type_sets)
gen_typesets_table(fmt, type_sets)
fmt.update_file('legalizer.rs', out_dir) fmt.update_file('legalizer.rs', out_dir)

View File

@@ -0,0 +1,145 @@
import doctest
import gen_legalizer
from unittest import TestCase
from srcgen import Formatter
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
iconst, b1, icmp, copy # noqa
from base.legalize import narrow, expand # noqa
from base.immediates import intcc # noqa
from cdsl.typevar import TypeVar, TypeSet
from cdsl.ast import Var, Def # noqa
from cdsl.xform import Rtl, XForm # noqa
from cdsl.ti import ti_rtl, subst, TypeEnv, get_type_env # noqa
from unique_table import UniqueTable
from functools import reduce
try:
from typing import Callable, TYPE_CHECKING, Iterable, Any # noqa
if TYPE_CHECKING:
CheckProducer = Callable[[UniqueTable], str]
except ImportError:
TYPE_CHECKING = False
def load_tests(loader, tests, ignore):
# type: (Any, Any, Any) -> Any
tests.addTests(doctest.DocTestSuite(gen_legalizer))
return tests
def format_check(typesets, s, *args):
# type: (...) -> str
def transform(x):
# type: (Any) -> str
if isinstance(x, TypeSet):
return str(typesets.index[x])
elif isinstance(x, TypeVar):
assert not x.is_derived
return x.name
else:
return str(x)
dummy_s = s # type: str
args = tuple(map(lambda x: transform(x), args))
return dummy_s.format(*args)
def typeset_check(v, ts):
# type: (Var, TypeSet) -> CheckProducer
return lambda typesets: format_check(
typesets,
'if !TYPE_SETS[{}].contains(typeof_{}) ' +
'{{\n return false;\n}};\n', ts, v)
def equiv_check(tv1, tv2):
# type: (TypeVar, TypeVar) -> CheckProducer
return lambda typesets: format_check(
typesets,
'if Some({}).map(|t: Type| -> t.as_bool()) != ' +
'Some({}).map(|t: Type| -> t.as_bool()) ' +
'{{\n return false;\n}};\n', tv1, tv2)
def sequence(*args):
# type: (...) -> CheckProducer
dummy = args # type: Iterable[CheckProducer]
def sequenceF(typesets):
# type: (UniqueTable) -> str
def strconcat(acc, el):
# type: (str, CheckProducer) -> str
return acc + el(typesets)
return reduce(strconcat, dummy, "")
return sequenceF
class TestRuntimeChecks(TestCase):
def setUp(self):
# type: () -> None
self.v0 = Var("v0")
self.v1 = Var("v1")
self.v2 = Var("v2")
self.v3 = Var("v3")
self.v4 = Var("v4")
self.v5 = Var("v5")
self.v6 = Var("v6")
self.v7 = Var("v7")
self.v8 = Var("v8")
self.v9 = Var("v9")
self.imm0 = Var("imm0")
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True,
scalars=False, simd=True)
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
scalars=False, simd=True)
self.b1 = TypeVar.singleton(b1)
def check_yo_check(self, xform, expected_f):
# type: (XForm, CheckProducer) -> None
fmt = Formatter()
type_sets = UniqueTable()
for check in get_runtime_typechecks(xform):
emit_runtime_typecheck(check, fmt, type_sets)
# Remove comments
got = "".join([l for l in fmt.lines if not l.strip().startswith("//")])
expected = expected_f(type_sets)
self.assertEqual(got, expected)
def test_width_check(self):
# type: () -> None
x = XForm(Rtl(self.v0 << copy(self.v1)),
Rtl((self.v2, self.v3) << isplit(self.v1),
self.v0 << iconcat(self.v2, self.v3)))
WideInt = TypeSet(lanes=(1, 256), ints=(16, 64))
self.check_yo_check(x, typeset_check(self.v1, WideInt))
def test_lanes_check(self):
# type: () -> None
x = XForm(Rtl(self.v0 << copy(self.v1)),
Rtl((self.v2, self.v3) << vsplit(self.v1),
self.v0 << vconcat(self.v2, self.v3)))
WideVec = TypeSet(lanes=(2, 256), ints=(8, 64), floats=(32, 64),
bools=(1, 64))
self.check_yo_check(x, typeset_check(self.v1, WideVec))
def test_vselect_imm(self):
# type: () -> None
ts = TypeSet(lanes=(2, 256), ints=(8, 64),
floats=(32, 64), bools=(8, 64))
r = Rtl(
self.v0 << iconst(self.imm0),
self.v1 << icmp(intcc.eq, self.v2, self.v0),
self.v5 << vselect(self.v1, self.v3, self.v4),
)
x = XForm(r, r)
self.check_yo_check(
x, sequence(typeset_check(self.v3, ts),
equiv_check(self.v2.get_typevar(),
self.v3.get_typevar())))

View File

@@ -506,10 +506,14 @@ type BitSet16 = BitSet<u16>;
/// A value type set describes the permitted set of types for a type variable. /// A value type set describes the permitted set of types for a type variable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ValueTypeSet { pub struct ValueTypeSet {
lanes: BitSet16, /// Allowed lane sizes
ints: BitSet8, pub lanes: BitSet16,
floats: BitSet8, /// Allowed int widths
bools: BitSet8, pub ints: BitSet8,
/// Allowed float widths
pub floats: BitSet8,
/// Allowed bool widths
pub bools: BitSet8,
} }
impl ValueTypeSet { impl ValueTypeSet {

View File

@@ -18,6 +18,8 @@ use flowgraph::ControlFlowGraph;
use ir::{Function, Cursor, DataFlowGraph, InstructionData, Opcode, InstBuilder}; use ir::{Function, Cursor, DataFlowGraph, InstructionData, Opcode, InstBuilder};
use ir::condcodes::IntCC; use ir::condcodes::IntCC;
use isa::{TargetIsa, Legalize}; use isa::{TargetIsa, Legalize};
use bitset::BitSet;
use ir::instructions::ValueTypeSet;
mod boundary; mod boundary;
mod split; mod split;