Emit runtime type checks in legalizer.rs (#112)
* Emit runtime type checks in legalizer.rs
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
464f2625d4
commit
98f822f347
@@ -10,7 +10,7 @@ from .typevar import TypeVar
|
|||||||
from .predicates import IsEqual, And
|
from .predicates import IsEqual, And
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Union, Tuple, Sequence, TYPE_CHECKING # noqa
|
from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .operands import ImmediateKind # noqa
|
from .operands import ImmediateKind # noqa
|
||||||
from .predicates import PredNode # noqa
|
from .predicates import PredNode # noqa
|
||||||
@@ -18,6 +18,19 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def replace_var(arg, m):
|
||||||
|
# type: (Expr, Dict[Var, Var]) -> Expr
|
||||||
|
"""
|
||||||
|
Given a var v return either m[v] or a new variable v' (and remember
|
||||||
|
m[v]=v'). Otherwise return the argument unchanged
|
||||||
|
"""
|
||||||
|
if isinstance(arg, Var):
|
||||||
|
new_arg = m.get(arg, Var(arg.name)) # type: Var
|
||||||
|
m[arg] = new_arg
|
||||||
|
return new_arg
|
||||||
|
return arg
|
||||||
|
|
||||||
|
|
||||||
class Def(object):
|
class Def(object):
|
||||||
"""
|
"""
|
||||||
An AST definition associates a set of variables with the values produced by
|
An AST definition associates a set of variables with the values produced by
|
||||||
@@ -60,6 +73,21 @@ class Def(object):
|
|||||||
return "({}) << {!s}".format(
|
return "({}) << {!s}".format(
|
||||||
', '.join(map(str, self.defs)), self.expr)
|
', '.join(map(str, self.defs)), self.expr)
|
||||||
|
|
||||||
|
def copy(self, m):
|
||||||
|
# type: (Dict[Var, Var]) -> Def
|
||||||
|
"""
|
||||||
|
Return a copy of this Def with vars replaced with fresh variables,
|
||||||
|
in accordance with the map m. Update m as neccessary.
|
||||||
|
"""
|
||||||
|
new_expr = self.expr.copy(m)
|
||||||
|
new_defs = [] # type: List[Var]
|
||||||
|
for v in self.defs:
|
||||||
|
new_v = replace_var(v, m)
|
||||||
|
assert(isinstance(new_v, Var))
|
||||||
|
new_defs.append(new_v)
|
||||||
|
|
||||||
|
return Def(tuple(new_defs), new_expr)
|
||||||
|
|
||||||
|
|
||||||
class Expr(object):
|
class Expr(object):
|
||||||
"""
|
"""
|
||||||
@@ -303,6 +331,15 @@ class Apply(Expr):
|
|||||||
|
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
|
def copy(self, m):
|
||||||
|
# type: (Dict[Var, Var]) -> Apply
|
||||||
|
"""
|
||||||
|
Return a copy of this Expr with vars replaced with fresh variables,
|
||||||
|
in accordance with the map m. Update m as neccessary.
|
||||||
|
"""
|
||||||
|
return Apply(self.inst, tuple(map(lambda e: replace_var(e, m),
|
||||||
|
self.args)))
|
||||||
|
|
||||||
|
|
||||||
class Enumerator(Expr):
|
class Enumerator(Expr):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,30 +6,17 @@ from base.immediates import intcc
|
|||||||
from .typevar import TypeVar
|
from .typevar import TypeVar
|
||||||
from .ast import Var, Def
|
from .ast import Var, Def
|
||||||
from .xform import Rtl, XForm
|
from .xform import Rtl, XForm
|
||||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env
|
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
|
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
|
||||||
from .ti import Constraint
|
|
||||||
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
|
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
TYPE_CHECKING = False
|
TYPE_CHECKING = False
|
||||||
|
|
||||||
|
|
||||||
def sort_constr(c):
|
|
||||||
# type: (Constraint) -> Constraint
|
|
||||||
"""
|
|
||||||
Sort the 2 typevars in a constraint by name for comparison
|
|
||||||
"""
|
|
||||||
r = tuple(sorted(c, key=lambda y: y.name))
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
return cast(Constraint, r)
|
|
||||||
else:
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
def agree(me, other):
|
def agree(me, other):
|
||||||
# type: (TypeEnv, TypeEnv) -> bool
|
# type: (TypeEnv, TypeEnv) -> bool
|
||||||
"""
|
"""
|
||||||
@@ -63,13 +50,10 @@ def agree(me, other):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Translate our constraints using m, and sort
|
# Translate our constraints using m, and sort
|
||||||
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
|
me_equiv_constr = sorted([constr.translate(m)
|
||||||
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
|
for constr in me.constraints])
|
||||||
|
|
||||||
# Sort other's constraints
|
# Sort other's constraints
|
||||||
other_equiv_constr = sorted([sort_constr(x) for x in other.constraints],
|
other_equiv_constr = sorted(other.constraints)
|
||||||
key=lambda y: y[0].name)
|
|
||||||
|
|
||||||
return me_equiv_constr == other_equiv_constr
|
return me_equiv_constr == other_equiv_constr
|
||||||
|
|
||||||
|
|
||||||
@@ -224,7 +208,7 @@ class TestRTL(TypeCheckingBaseTest):
|
|||||||
self.v3: txn,
|
self.v3: txn,
|
||||||
self.v4: txn,
|
self.v4: txn,
|
||||||
self.v5: txn,
|
self.v5: txn,
|
||||||
}, [(ixn.as_bool(), txn.as_bool())]))
|
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
|
||||||
|
|
||||||
def test_vselect_vsplits(self):
|
def test_vselect_vsplits(self):
|
||||||
# type: () -> None
|
# type: () -> None
|
||||||
|
|||||||
@@ -8,13 +8,12 @@ from itertools import product
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||||
from typing import Iterable # noqa
|
from typing import Iterable, List # noqa
|
||||||
from typing import cast, List
|
from typing import cast
|
||||||
from .xform import Rtl, XForm # noqa
|
from .xform import Rtl, XForm # noqa
|
||||||
from .ast import Expr # noqa
|
from .ast import Expr # noqa
|
||||||
|
from .typevar import TypeSet # noqa
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Constraint = Tuple[TypeVar, TypeVar]
|
|
||||||
ConstraintList = List[Constraint]
|
|
||||||
TypeMap = Dict[TypeVar, TypeVar]
|
TypeMap = Dict[TypeVar, TypeVar]
|
||||||
VarMap = Dict[Var, TypeVar]
|
VarMap = Dict[Var, TypeVar]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -22,6 +21,122 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TypeConstraint(object):
|
||||||
|
"""
|
||||||
|
Base class for all runtime-emittable type constraints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ConstrainTVsEqual(TypeConstraint):
|
||||||
|
"""
|
||||||
|
Constraint specifying that two derived type vars must have the same runtime
|
||||||
|
type.
|
||||||
|
"""
|
||||||
|
def __init__(self, tv1, tv2):
|
||||||
|
# type: (TypeVar, TypeVar) -> None
|
||||||
|
assert tv1.is_derived and tv2.is_derived
|
||||||
|
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||||
|
|
||||||
|
def is_trivial(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Return true if this constrain is statically decidable.
|
||||||
|
"""
|
||||||
|
return self.tv1 == self.tv2 or \
|
||||||
|
(self.tv1.singleton_type() is not None and
|
||||||
|
self.tv2.singleton_type() is not None)
|
||||||
|
|
||||||
|
def translate(self, m):
|
||||||
|
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
|
||||||
|
"""
|
||||||
|
Translate any TypeVars in the constraint according to the map m
|
||||||
|
"""
|
||||||
|
if isinstance(m, TypeEnv):
|
||||||
|
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
|
||||||
|
else:
|
||||||
|
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
# type: (object) -> bool
|
||||||
|
if (not isinstance(other, ConstrainTVsEqual)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# type: () -> int
|
||||||
|
return hash((self.tv1, self.tv2))
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Evaluate this constraint. Should only be called when the constraint has
|
||||||
|
been translated to concrete types.
|
||||||
|
"""
|
||||||
|
assert self.tv1.singleton_type() is not None and \
|
||||||
|
self.tv2.singleton_type() is not None
|
||||||
|
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||||
|
|
||||||
|
|
||||||
|
class ConstrainTVInTypeset(TypeConstraint):
|
||||||
|
"""
|
||||||
|
Constraint specifying that a type var must belong to some typeset.
|
||||||
|
"""
|
||||||
|
def __init__(self, tv, ts):
|
||||||
|
# type: (TypeVar, TypeSet) -> None
|
||||||
|
assert not tv.is_derived and tv.name.startswith("typeof_")
|
||||||
|
self.tv = tv
|
||||||
|
self.ts = ts
|
||||||
|
|
||||||
|
def is_trivial(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Return true if this constrain is statically decidable.
|
||||||
|
"""
|
||||||
|
tv_ts = self.tv.get_typeset().copy()
|
||||||
|
|
||||||
|
# Trivially True
|
||||||
|
if (tv_ts.issubset(self.ts)):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Trivially false
|
||||||
|
tv_ts &= self.ts
|
||||||
|
if (tv_ts.size() == 0):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def translate(self, m):
|
||||||
|
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
|
||||||
|
"""
|
||||||
|
Translate any TypeVars in the constraint according to the map m
|
||||||
|
"""
|
||||||
|
if isinstance(m, TypeEnv):
|
||||||
|
return ConstrainTVInTypeset(m[self.tv], self.ts)
|
||||||
|
else:
|
||||||
|
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
# type: (object) -> bool
|
||||||
|
if (not isinstance(other, ConstrainTVInTypeset)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (self.tv, self.ts) == (other.tv, other.ts)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# type: () -> int
|
||||||
|
return hash((self.tv, self.ts))
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
# type: () -> bool
|
||||||
|
"""
|
||||||
|
Evaluate this constraint. Should only be called when the constraint has
|
||||||
|
been translated to concrete types.
|
||||||
|
"""
|
||||||
|
assert self.tv.singleton_type() is not None
|
||||||
|
return self.tv.get_typeset().issubset(self.ts)
|
||||||
|
|
||||||
|
|
||||||
class TypeEnv(object):
|
class TypeEnv(object):
|
||||||
"""
|
"""
|
||||||
Class encapsulating the neccessary book keeping for type inference.
|
Class encapsulating the neccessary book keeping for type inference.
|
||||||
@@ -43,13 +158,13 @@ class TypeEnv(object):
|
|||||||
RANK_INTERNAL = 0
|
RANK_INTERNAL = 0
|
||||||
|
|
||||||
def __init__(self, arg=None):
|
def __init__(self, arg=None):
|
||||||
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
|
# type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
|
||||||
self.ranks = {} # type: Dict[TypeVar, int]
|
self.ranks = {} # type: Dict[TypeVar, int]
|
||||||
self.vars = set() # type: Set[Var]
|
self.vars = set() # type: Set[Var]
|
||||||
|
|
||||||
if arg is None:
|
if arg is None:
|
||||||
self.type_map = {} # type: TypeMap
|
self.type_map = {} # type: TypeMap
|
||||||
self.constraints = [] # type: ConstraintList
|
self.constraints = [] # type: List[TypeConstraint]
|
||||||
else:
|
else:
|
||||||
self.type_map, self.constraints = arg
|
self.type_map, self.constraints = arg
|
||||||
|
|
||||||
@@ -94,7 +209,9 @@ class TypeEnv(object):
|
|||||||
"""
|
"""
|
||||||
Add a new equivalence constraint between tv1 and tv2
|
Add a new equivalence constraint between tv1 and tv2
|
||||||
"""
|
"""
|
||||||
self.constraints.append((tv1, tv2))
|
constr = ConstrainTVsEqual(tv1, tv2)
|
||||||
|
if (constr not in self.constraints):
|
||||||
|
self.constraints.append(constr)
|
||||||
|
|
||||||
def get_uid(self):
|
def get_uid(self):
|
||||||
# type: () -> str
|
# type: () -> str
|
||||||
@@ -206,15 +323,24 @@ class TypeEnv(object):
|
|||||||
"""
|
"""
|
||||||
vars_tvs = set([v.get_typevar() for v in self.vars])
|
vars_tvs = set([v.get_typevar() for v in self.vars])
|
||||||
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
|
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
|
||||||
new_constraints = [(self[tv1], self[tv2])
|
|
||||||
for (tv1, tv2) in self.constraints]
|
|
||||||
|
|
||||||
# Sanity: new constraints and the new type_map should only contain
|
new_constraints = [] # type: List[TypeConstraint]
|
||||||
# tvs associated with real vars
|
for constr in self.constraints:
|
||||||
for (a, b) in new_constraints:
|
# Currently typeinference only generates ConstrainTVsEqual
|
||||||
assert a.free_typevar() in vars_tvs and\
|
# constraints
|
||||||
b.free_typevar() in vars_tvs
|
assert isinstance(constr, ConstrainTVsEqual)
|
||||||
|
constr = constr.translate(self)
|
||||||
|
|
||||||
|
if constr.is_trivial() or constr in new_constraints:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Sanity: translated constraints should refer to only real vars
|
||||||
|
assert constr.tv1.free_typevar() in vars_tvs and\
|
||||||
|
constr.tv2.free_typevar() in vars_tvs
|
||||||
|
|
||||||
|
new_constraints.append(constr)
|
||||||
|
|
||||||
|
# Sanity: translated typemap should refer to only real vars
|
||||||
for (k, v) in new_type_map.items():
|
for (k, v) in new_type_map.items():
|
||||||
assert k in vars_tvs
|
assert k in vars_tvs
|
||||||
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
|
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
|
||||||
@@ -245,13 +371,13 @@ class TypeEnv(object):
|
|||||||
|
|
||||||
# Check if constraints are satisfied for this typing
|
# Check if constraints are satisfied for this typing
|
||||||
failed = None
|
failed = None
|
||||||
for (tv1, tv2) in self.constraints:
|
for constr in self.constraints:
|
||||||
tv1 = subst(tv1, m)
|
# Currently typeinference only generates ConstrainTVsEqual
|
||||||
tv2 = subst(tv2, m)
|
# constraints
|
||||||
assert tv1.get_typeset().size() == 1 and\
|
assert isinstance(constr, ConstrainTVsEqual)
|
||||||
tv2.get_typeset().size() == 1
|
concrete_constr = constr.translate(m)
|
||||||
if (tv1.get_typeset() != tv2.get_typeset()):
|
if not concrete_constr.eval():
|
||||||
failed = (tv1, tv2)
|
failed = concrete_constr
|
||||||
break
|
break
|
||||||
|
|
||||||
if (failed is not None):
|
if (failed is not None):
|
||||||
@@ -287,9 +413,10 @@ class TypeEnv(object):
|
|||||||
edges.add((v, v.base, "solid", v.derived_func))
|
edges.add((v, v.base, "solid", v.derived_func))
|
||||||
v = v.base
|
v = v.base
|
||||||
|
|
||||||
for (a, b) in self.constraints:
|
for constr in self.constraints:
|
||||||
assert a in nodes and b in nodes
|
assert isinstance(constr, ConstrainTVsEqual)
|
||||||
edges.add((a, b, "dashed", None))
|
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||||
|
edges.add((constr.tv1, constr.tv2, "dashed", None))
|
||||||
|
|
||||||
root_nodes = set([x for x in nodes
|
root_nodes = set([x for x in nodes
|
||||||
if x not in self.type_map and not x.is_derived])
|
if x not in self.type_map and not x.is_derived])
|
||||||
|
|||||||
@@ -41,6 +41,14 @@ class Rtl(object):
|
|||||||
# type: (*DefApply) -> None
|
# type: (*DefApply) -> None
|
||||||
self.rtl = tuple(map(canonicalize_defapply, args))
|
self.rtl = tuple(map(canonicalize_defapply, args))
|
||||||
|
|
||||||
|
def copy(self, m):
|
||||||
|
# type: (Dict[Var, Var]) -> Rtl
|
||||||
|
"""
|
||||||
|
Return a copy of this rtl with all Vars substituted with copies or
|
||||||
|
according to m. Update m as neccessary.
|
||||||
|
"""
|
||||||
|
return Rtl(*[d.copy(m) for d in self.rtl])
|
||||||
|
|
||||||
|
|
||||||
class XForm(object):
|
class XForm(object):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -336,6 +336,25 @@ def get_constraint(op, ctrl_typevar, type_sets):
|
|||||||
return 'Same'
|
return 'Same'
|
||||||
|
|
||||||
|
|
||||||
|
# TypeSet indexes are encoded in 8 bits, with `0xff` reserved.
|
||||||
|
typeset_limit = 0xff
|
||||||
|
|
||||||
|
|
||||||
|
def gen_typesets_table(fmt, type_sets):
|
||||||
|
# type: (srcgen.Formatter, UniqueTable) -> None
|
||||||
|
"""
|
||||||
|
Generate the table of ValueTypeSets described by type_sets.
|
||||||
|
"""
|
||||||
|
fmt.comment('Table of value type sets.')
|
||||||
|
assert len(type_sets.table) <= typeset_limit, "Too many type sets"
|
||||||
|
with fmt.indented(
|
||||||
|
'const TYPE_SETS : [ValueTypeSet; {}] = ['
|
||||||
|
.format(len(type_sets.table)), '];'):
|
||||||
|
for ts in type_sets.table:
|
||||||
|
with fmt.indented('ValueTypeSet {', '},'):
|
||||||
|
ts.emit_fields(fmt)
|
||||||
|
|
||||||
|
|
||||||
def gen_type_constraints(fmt, instrs):
|
def gen_type_constraints(fmt, instrs):
|
||||||
# type: (srcgen.Formatter, Sequence[Instruction]) -> None
|
# type: (srcgen.Formatter, Sequence[Instruction]) -> None
|
||||||
"""
|
"""
|
||||||
@@ -360,9 +379,6 @@ def gen_type_constraints(fmt, instrs):
|
|||||||
# Preload table with constraints for typical binops.
|
# Preload table with constraints for typical binops.
|
||||||
operand_seqs.add(['Same'] * 3)
|
operand_seqs.add(['Same'] * 3)
|
||||||
|
|
||||||
# TypeSet indexes are encoded in 8 bits, with `0xff` reserved.
|
|
||||||
typeset_limit = 0xff
|
|
||||||
|
|
||||||
fmt.comment('Table of opcode constraints.')
|
fmt.comment('Table of opcode constraints.')
|
||||||
with fmt.indented(
|
with fmt.indented(
|
||||||
'const OPCODE_CONSTRAINTS : [OpcodeConstraints; {}] = ['
|
'const OPCODE_CONSTRAINTS : [OpcodeConstraints; {}] = ['
|
||||||
@@ -418,14 +434,7 @@ def gen_type_constraints(fmt, instrs):
|
|||||||
fmt.line('typeset_offset: {},'.format(ctrl_typeset))
|
fmt.line('typeset_offset: {},'.format(ctrl_typeset))
|
||||||
fmt.line('constraint_offset: {},'.format(offset))
|
fmt.line('constraint_offset: {},'.format(offset))
|
||||||
|
|
||||||
fmt.comment('Table of value type sets.')
|
gen_typesets_table(fmt, type_sets)
|
||||||
assert len(type_sets.table) <= typeset_limit, "Too many type sets"
|
|
||||||
with fmt.indented(
|
|
||||||
'const TYPE_SETS : [ValueTypeSet; {}] = ['
|
|
||||||
.format(len(type_sets.table)), '];'):
|
|
||||||
for ts in type_sets.table:
|
|
||||||
with fmt.indented('ValueTypeSet {', '},'):
|
|
||||||
ts.emit_fields(fmt)
|
|
||||||
|
|
||||||
fmt.comment('Table of operand constraint sequences.')
|
fmt.comment('Table of operand constraint sequences.')
|
||||||
with fmt.indented(
|
with fmt.indented(
|
||||||
|
|||||||
@@ -11,16 +11,116 @@ from __future__ import absolute_import
|
|||||||
from srcgen import Formatter
|
from srcgen import Formatter
|
||||||
from base import legalize, instructions
|
from base import legalize, instructions
|
||||||
from cdsl.ast import Var
|
from cdsl.ast import Var
|
||||||
|
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
|
||||||
|
ConstrainTVInTypeset
|
||||||
|
from unique_table import UniqueTable
|
||||||
|
from gen_instr import gen_typesets_table
|
||||||
|
from cdsl.typevar import TypeVar
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Sequence # noqa
|
from typing import Sequence, List, Dict # noqa
|
||||||
from cdsl.isa import TargetISA # noqa
|
from cdsl.isa import TargetISA # noqa
|
||||||
from cdsl.ast import Def # noqa
|
from cdsl.ast import Def # noqa
|
||||||
from cdsl.xform import XForm, XFormGroup # noqa
|
from cdsl.xform import XForm, XFormGroup # noqa
|
||||||
|
from cdsl.typevar import TypeSet # noqa
|
||||||
|
from cdsl.ti import TypeConstraint # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_runtime_typechecks(xform):
|
||||||
|
# type: (XForm) -> List[TypeConstraint]
|
||||||
|
"""
|
||||||
|
Given a XForm build a list of runtime type checks neccessary to determine
|
||||||
|
if it applies. We have 2 types of runtime checks:
|
||||||
|
1) typevar tv belongs to typeset T - needed for free tvs whose
|
||||||
|
typeset is constrainted by their use in the dst pattern
|
||||||
|
|
||||||
|
2) tv1 == tv2 where tv1 and tv2 are derived TVs - caused by unification
|
||||||
|
of non-bijective functions
|
||||||
|
"""
|
||||||
|
check_l = [] # type: List[TypeConstraint]
|
||||||
|
|
||||||
|
# 1) Perform ti only on the source RTL. Accumulate any free tvs that have a
|
||||||
|
# different inferred type in src, compared to the type inferred for both
|
||||||
|
# src and dst.
|
||||||
|
symtab = {} # type: Dict[Var, Var]
|
||||||
|
src_copy = xform.src.copy(symtab)
|
||||||
|
src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv()))
|
||||||
|
|
||||||
|
for v in xform.ti.vars:
|
||||||
|
if not v.has_free_typevar():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# In rust the local variable containing a free TV associated with var v
|
||||||
|
# has name typeof_v. We rely on the python TVs having the same name.
|
||||||
|
assert "typeof_{}".format(v) == xform.ti[v].name
|
||||||
|
|
||||||
|
if v not in symtab:
|
||||||
|
# We can have singleton vars defined only on dst. Ignore them
|
||||||
|
assert v.get_typevar().singleton_type() is not None
|
||||||
|
continue
|
||||||
|
|
||||||
|
src_ts = src_typenv[symtab[v]].get_typeset()
|
||||||
|
xform_ts = xform.ti[v].get_typeset()
|
||||||
|
|
||||||
|
assert xform_ts.issubset(src_ts)
|
||||||
|
if src_ts != xform_ts:
|
||||||
|
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
|
||||||
|
|
||||||
|
# 2,3) Add any constraints that appear in xform.ti
|
||||||
|
check_l.extend(xform.ti.constraints)
|
||||||
|
|
||||||
|
return check_l
|
||||||
|
|
||||||
|
|
||||||
|
def emit_runtime_typecheck(check, fmt, type_sets):
|
||||||
|
# type: (TypeConstraint, Formatter, UniqueTable) -> None
|
||||||
|
"""
|
||||||
|
Emit rust code for the given check.
|
||||||
|
"""
|
||||||
|
def build_derived_expr(tv):
|
||||||
|
# type: (TypeVar) -> str
|
||||||
|
if not tv.is_derived:
|
||||||
|
assert tv.name.startswith('typeof_')
|
||||||
|
return "Some({})".format(tv.name)
|
||||||
|
|
||||||
|
base_exp = build_derived_expr(tv.base)
|
||||||
|
if (tv.derived_func == TypeVar.LANEOF):
|
||||||
|
return "{}.map(|t: Type| -> t.lane_type())".format(base_exp)
|
||||||
|
elif (tv.derived_func == TypeVar.ASBOOL):
|
||||||
|
return "{}.map(|t: Type| -> t.as_bool())".format(base_exp)
|
||||||
|
elif (tv.derived_func == TypeVar.HALFWIDTH):
|
||||||
|
return "{}.and_then(|t: Type| -> t.half_width())".format(base_exp)
|
||||||
|
elif (tv.derived_func == TypeVar.DOUBLEWIDTH):
|
||||||
|
return "{}.and_then(|t: Type| -> t.double_width())"\
|
||||||
|
.format(base_exp)
|
||||||
|
elif (tv.derived_func == TypeVar.HALFVECTOR):
|
||||||
|
return "{}.and_then(|t: Type| -> t.half_vector())".format(base_exp)
|
||||||
|
elif (tv.derived_func == TypeVar.DOUBLEVECTOR):
|
||||||
|
return "{}.and_then(|t: Type| -> t.by(2))".format(base_exp)
|
||||||
|
else:
|
||||||
|
assert False, "Unknown derived function {}".format(tv.derived_func)
|
||||||
|
|
||||||
|
if (isinstance(check, ConstrainTVInTypeset)):
|
||||||
|
tv = check.tv.name
|
||||||
|
if check.ts not in type_sets.index:
|
||||||
|
type_sets.add(check.ts)
|
||||||
|
ts = type_sets.index[check.ts]
|
||||||
|
|
||||||
|
fmt.comment("{} must belong to {}".format(tv, check.ts))
|
||||||
|
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
|
||||||
|
'};'):
|
||||||
|
fmt.line('return false;')
|
||||||
|
elif (isinstance(check, ConstrainTVsEqual)):
|
||||||
|
tv1 = build_derived_expr(check.tv1)
|
||||||
|
tv2 = build_derived_expr(check.tv2)
|
||||||
|
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
|
||||||
|
fmt.line('return false;')
|
||||||
|
else:
|
||||||
|
assert False, "Unknown check {}".format(check)
|
||||||
|
|
||||||
|
|
||||||
def unwrap_inst(iref, node, fmt):
|
def unwrap_inst(iref, node, fmt):
|
||||||
# type: (str, Def, Formatter) -> bool
|
# type: (str, Def, Formatter) -> bool
|
||||||
"""
|
"""
|
||||||
@@ -183,8 +283,8 @@ def emit_dst_inst(node, fmt):
|
|||||||
fmt.line('pos.next_inst();')
|
fmt.line('pos.next_inst();')
|
||||||
|
|
||||||
|
|
||||||
def gen_xform(xform, fmt):
|
def gen_xform(xform, fmt, type_sets):
|
||||||
# type: (XForm, Formatter) -> None
|
# type: (XForm, Formatter, UniqueTable) -> None
|
||||||
"""
|
"""
|
||||||
Emit code for `xform`, assuming the the opcode of xform's root instruction
|
Emit code for `xform`, assuming the the opcode of xform's root instruction
|
||||||
has already been matched.
|
has already been matched.
|
||||||
@@ -203,6 +303,10 @@ def gen_xform(xform, fmt):
|
|||||||
instp = xform.src.rtl[0].expr.inst_predicate()
|
instp = xform.src.rtl[0].expr.inst_predicate()
|
||||||
assert instp is None, "Instruction predicates not supported in legalizer"
|
assert instp is None, "Instruction predicates not supported in legalizer"
|
||||||
|
|
||||||
|
# Emit any runtime checks.
|
||||||
|
for check in get_runtime_typechecks(xform):
|
||||||
|
emit_runtime_typecheck(check, fmt, type_sets)
|
||||||
|
|
||||||
# Emit the destination pattern.
|
# Emit the destination pattern.
|
||||||
for dst in xform.dst.rtl:
|
for dst in xform.dst.rtl:
|
||||||
emit_dst_inst(dst, fmt)
|
emit_dst_inst(dst, fmt)
|
||||||
@@ -213,8 +317,8 @@ def gen_xform(xform, fmt):
|
|||||||
fmt.line('assert_eq!(pos.remove_inst(), inst);')
|
fmt.line('assert_eq!(pos.remove_inst(), inst);')
|
||||||
|
|
||||||
|
|
||||||
def gen_xform_group(xgrp, fmt):
|
def gen_xform_group(xgrp, fmt, type_sets):
|
||||||
# type: (XFormGroup, Formatter) -> None
|
# type: (XFormGroup, Formatter, UniqueTable) -> None
|
||||||
fmt.doc_comment("Legalize the instruction pointed to by `pos`.")
|
fmt.doc_comment("Legalize the instruction pointed to by `pos`.")
|
||||||
fmt.line('#[allow(unused_variables,unused_assignments)]')
|
fmt.line('#[allow(unused_variables,unused_assignments)]')
|
||||||
with fmt.indented(
|
with fmt.indented(
|
||||||
@@ -231,7 +335,7 @@ def gen_xform_group(xgrp, fmt):
|
|||||||
inst = xform.src.rtl[0].expr.inst
|
inst = xform.src.rtl[0].expr.inst
|
||||||
with fmt.indented(
|
with fmt.indented(
|
||||||
'Opcode::{} => {{'.format(inst.camel_name), '}'):
|
'Opcode::{} => {{'.format(inst.camel_name), '}'):
|
||||||
gen_xform(xform, fmt)
|
gen_xform(xform, fmt, type_sets)
|
||||||
# We'll assume there are uncovered opcodes.
|
# We'll assume there are uncovered opcodes.
|
||||||
fmt.line('_ => return false,')
|
fmt.line('_ => return false,')
|
||||||
fmt.line('true')
|
fmt.line('true')
|
||||||
@@ -240,6 +344,11 @@ def gen_xform_group(xgrp, fmt):
|
|||||||
def generate(isas, out_dir):
|
def generate(isas, out_dir):
|
||||||
# type: (Sequence[TargetISA], str) -> None
|
# type: (Sequence[TargetISA], str) -> None
|
||||||
fmt = Formatter()
|
fmt = Formatter()
|
||||||
gen_xform_group(legalize.narrow, fmt)
|
# Table of TypeSet instances
|
||||||
gen_xform_group(legalize.expand, fmt)
|
type_sets = UniqueTable()
|
||||||
|
|
||||||
|
gen_xform_group(legalize.narrow, fmt, type_sets)
|
||||||
|
gen_xform_group(legalize.expand, fmt, type_sets)
|
||||||
|
|
||||||
|
gen_typesets_table(fmt, type_sets)
|
||||||
fmt.update_file('legalizer.rs', out_dir)
|
fmt.update_file('legalizer.rs', out_dir)
|
||||||
|
|||||||
145
lib/cretonne/meta/test_gen_legalizer.py
Normal file
145
lib/cretonne/meta/test_gen_legalizer.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import doctest
|
||||||
|
import gen_legalizer
|
||||||
|
from unittest import TestCase
|
||||||
|
from srcgen import Formatter
|
||||||
|
from gen_legalizer import get_runtime_typechecks, emit_runtime_typecheck
|
||||||
|
from base.instructions import vselect, vsplit, isplit, iconcat, vconcat, \
|
||||||
|
iconst, b1, icmp, copy # noqa
|
||||||
|
from base.legalize import narrow, expand # noqa
|
||||||
|
from base.immediates import intcc # noqa
|
||||||
|
from cdsl.typevar import TypeVar, TypeSet
|
||||||
|
from cdsl.ast import Var, Def # noqa
|
||||||
|
from cdsl.xform import Rtl, XForm # noqa
|
||||||
|
from cdsl.ti import ti_rtl, subst, TypeEnv, get_type_env # noqa
|
||||||
|
from unique_table import UniqueTable
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Callable, TYPE_CHECKING, Iterable, Any # noqa
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
CheckProducer = Callable[[UniqueTable], str]
|
||||||
|
except ImportError:
|
||||||
|
TYPE_CHECKING = False
|
||||||
|
|
||||||
|
|
||||||
|
def load_tests(loader, tests, ignore):
|
||||||
|
# type: (Any, Any, Any) -> Any
|
||||||
|
tests.addTests(doctest.DocTestSuite(gen_legalizer))
|
||||||
|
return tests
|
||||||
|
|
||||||
|
|
||||||
|
def format_check(typesets, s, *args):
|
||||||
|
# type: (...) -> str
|
||||||
|
def transform(x):
|
||||||
|
# type: (Any) -> str
|
||||||
|
if isinstance(x, TypeSet):
|
||||||
|
return str(typesets.index[x])
|
||||||
|
elif isinstance(x, TypeVar):
|
||||||
|
assert not x.is_derived
|
||||||
|
return x.name
|
||||||
|
else:
|
||||||
|
return str(x)
|
||||||
|
|
||||||
|
dummy_s = s # type: str
|
||||||
|
args = tuple(map(lambda x: transform(x), args))
|
||||||
|
return dummy_s.format(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def typeset_check(v, ts):
|
||||||
|
# type: (Var, TypeSet) -> CheckProducer
|
||||||
|
return lambda typesets: format_check(
|
||||||
|
typesets,
|
||||||
|
'if !TYPE_SETS[{}].contains(typeof_{}) ' +
|
||||||
|
'{{\n return false;\n}};\n', ts, v)
|
||||||
|
|
||||||
|
|
||||||
|
def equiv_check(tv1, tv2):
|
||||||
|
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||||
|
return lambda typesets: format_check(
|
||||||
|
typesets,
|
||||||
|
'if Some({}).map(|t: Type| -> t.as_bool()) != ' +
|
||||||
|
'Some({}).map(|t: Type| -> t.as_bool()) ' +
|
||||||
|
'{{\n return false;\n}};\n', tv1, tv2)
|
||||||
|
|
||||||
|
|
||||||
|
def sequence(*args):
|
||||||
|
# type: (...) -> CheckProducer
|
||||||
|
dummy = args # type: Iterable[CheckProducer]
|
||||||
|
|
||||||
|
def sequenceF(typesets):
|
||||||
|
# type: (UniqueTable) -> str
|
||||||
|
def strconcat(acc, el):
|
||||||
|
# type: (str, CheckProducer) -> str
|
||||||
|
return acc + el(typesets)
|
||||||
|
|
||||||
|
return reduce(strconcat, dummy, "")
|
||||||
|
return sequenceF
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuntimeChecks(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# type: () -> None
|
||||||
|
self.v0 = Var("v0")
|
||||||
|
self.v1 = Var("v1")
|
||||||
|
self.v2 = Var("v2")
|
||||||
|
self.v3 = Var("v3")
|
||||||
|
self.v4 = Var("v4")
|
||||||
|
self.v5 = Var("v5")
|
||||||
|
self.v6 = Var("v6")
|
||||||
|
self.v7 = Var("v7")
|
||||||
|
self.v8 = Var("v8")
|
||||||
|
self.v9 = Var("v9")
|
||||||
|
self.imm0 = Var("imm0")
|
||||||
|
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True,
|
||||||
|
scalars=False, simd=True)
|
||||||
|
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
|
||||||
|
scalars=False, simd=True)
|
||||||
|
self.b1 = TypeVar.singleton(b1)
|
||||||
|
|
||||||
|
def check_yo_check(self, xform, expected_f):
|
||||||
|
# type: (XForm, CheckProducer) -> None
|
||||||
|
fmt = Formatter()
|
||||||
|
type_sets = UniqueTable()
|
||||||
|
for check in get_runtime_typechecks(xform):
|
||||||
|
emit_runtime_typecheck(check, fmt, type_sets)
|
||||||
|
|
||||||
|
# Remove comments
|
||||||
|
got = "".join([l for l in fmt.lines if not l.strip().startswith("//")])
|
||||||
|
expected = expected_f(type_sets)
|
||||||
|
self.assertEqual(got, expected)
|
||||||
|
|
||||||
|
def test_width_check(self):
|
||||||
|
# type: () -> None
|
||||||
|
x = XForm(Rtl(self.v0 << copy(self.v1)),
|
||||||
|
Rtl((self.v2, self.v3) << isplit(self.v1),
|
||||||
|
self.v0 << iconcat(self.v2, self.v3)))
|
||||||
|
|
||||||
|
WideInt = TypeSet(lanes=(1, 256), ints=(16, 64))
|
||||||
|
self.check_yo_check(x, typeset_check(self.v1, WideInt))
|
||||||
|
|
||||||
|
def test_lanes_check(self):
|
||||||
|
# type: () -> None
|
||||||
|
x = XForm(Rtl(self.v0 << copy(self.v1)),
|
||||||
|
Rtl((self.v2, self.v3) << vsplit(self.v1),
|
||||||
|
self.v0 << vconcat(self.v2, self.v3)))
|
||||||
|
|
||||||
|
WideVec = TypeSet(lanes=(2, 256), ints=(8, 64), floats=(32, 64),
|
||||||
|
bools=(1, 64))
|
||||||
|
self.check_yo_check(x, typeset_check(self.v1, WideVec))
|
||||||
|
|
||||||
|
def test_vselect_imm(self):
|
||||||
|
# type: () -> None
|
||||||
|
ts = TypeSet(lanes=(2, 256), ints=(8, 64),
|
||||||
|
floats=(32, 64), bools=(8, 64))
|
||||||
|
r = Rtl(
|
||||||
|
self.v0 << iconst(self.imm0),
|
||||||
|
self.v1 << icmp(intcc.eq, self.v2, self.v0),
|
||||||
|
self.v5 << vselect(self.v1, self.v3, self.v4),
|
||||||
|
)
|
||||||
|
x = XForm(r, r)
|
||||||
|
|
||||||
|
self.check_yo_check(
|
||||||
|
x, sequence(typeset_check(self.v3, ts),
|
||||||
|
equiv_check(self.v2.get_typevar(),
|
||||||
|
self.v3.get_typevar())))
|
||||||
@@ -506,10 +506,14 @@ type BitSet16 = BitSet<u16>;
|
|||||||
/// A value type set describes the permitted set of types for a type variable.
|
/// A value type set describes the permitted set of types for a type variable.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
pub struct ValueTypeSet {
|
pub struct ValueTypeSet {
|
||||||
lanes: BitSet16,
|
/// Allowed lane sizes
|
||||||
ints: BitSet8,
|
pub lanes: BitSet16,
|
||||||
floats: BitSet8,
|
/// Allowed int widths
|
||||||
bools: BitSet8,
|
pub ints: BitSet8,
|
||||||
|
/// Allowed float widths
|
||||||
|
pub floats: BitSet8,
|
||||||
|
/// Allowed bool widths
|
||||||
|
pub bools: BitSet8,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ValueTypeSet {
|
impl ValueTypeSet {
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ use flowgraph::ControlFlowGraph;
|
|||||||
use ir::{Function, Cursor, DataFlowGraph, InstructionData, Opcode, InstBuilder};
|
use ir::{Function, Cursor, DataFlowGraph, InstructionData, Opcode, InstBuilder};
|
||||||
use ir::condcodes::IntCC;
|
use ir::condcodes::IntCC;
|
||||||
use isa::{TargetIsa, Legalize};
|
use isa::{TargetIsa, Legalize};
|
||||||
|
use bitset::BitSet;
|
||||||
|
use ir::instructions::ValueTypeSet;
|
||||||
|
|
||||||
mod boundary;
|
mod boundary;
|
||||||
mod split;
|
mod split;
|
||||||
|
|||||||
Reference in New Issue
Block a user