Emit runtime type checks in legalizer.rs (#112)
* Emit runtime type checks in legalizer.rs
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
528e6ff3f5
commit
fc11ae7b72
@@ -10,7 +10,7 @@ from .typevar import TypeVar
|
||||
from .predicates import IsEqual, And
|
||||
|
||||
try:
|
||||
from typing import Union, Tuple, Sequence, TYPE_CHECKING # noqa
|
||||
from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .operands import ImmediateKind # noqa
|
||||
from .predicates import PredNode # noqa
|
||||
@@ -18,6 +18,19 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def replace_var(arg, m):
|
||||
# type: (Expr, Dict[Var, Var]) -> Expr
|
||||
"""
|
||||
Given a var v return either m[v] or a new variable v' (and remember
|
||||
m[v]=v'). Otherwise return the argument unchanged
|
||||
"""
|
||||
if isinstance(arg, Var):
|
||||
new_arg = m.get(arg, Var(arg.name)) # type: Var
|
||||
m[arg] = new_arg
|
||||
return new_arg
|
||||
return arg
|
||||
|
||||
|
||||
class Def(object):
|
||||
"""
|
||||
An AST definition associates a set of variables with the values produced by
|
||||
@@ -60,6 +73,21 @@ class Def(object):
|
||||
return "({}) << {!s}".format(
|
||||
', '.join(map(str, self.defs)), self.expr)
|
||||
|
||||
def copy(self, m):
|
||||
# type: (Dict[Var, Var]) -> Def
|
||||
"""
|
||||
Return a copy of this Def with vars replaced with fresh variables,
|
||||
in accordance with the map m. Update m as neccessary.
|
||||
"""
|
||||
new_expr = self.expr.copy(m)
|
||||
new_defs = [] # type: List[Var]
|
||||
for v in self.defs:
|
||||
new_v = replace_var(v, m)
|
||||
assert(isinstance(new_v, Var))
|
||||
new_defs.append(new_v)
|
||||
|
||||
return Def(tuple(new_defs), new_expr)
|
||||
|
||||
|
||||
class Expr(object):
|
||||
"""
|
||||
@@ -303,6 +331,15 @@ class Apply(Expr):
|
||||
|
||||
return pred
|
||||
|
||||
def copy(self, m):
|
||||
# type: (Dict[Var, Var]) -> Apply
|
||||
"""
|
||||
Return a copy of this Expr with vars replaced with fresh variables,
|
||||
in accordance with the map m. Update m as neccessary.
|
||||
"""
|
||||
return Apply(self.inst, tuple(map(lambda e: replace_var(e, m),
|
||||
self.args)))
|
||||
|
||||
|
||||
class Enumerator(Expr):
|
||||
"""
|
||||
|
||||
@@ -6,30 +6,17 @@ from base.immediates import intcc
|
||||
from .typevar import TypeVar
|
||||
from .ast import Var, Def
|
||||
from .xform import Rtl, XForm
|
||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env
|
||||
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
|
||||
from unittest import TestCase
|
||||
from functools import reduce
|
||||
|
||||
try:
|
||||
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
|
||||
from .ti import Constraint
|
||||
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
|
||||
|
||||
def sort_constr(c):
|
||||
# type: (Constraint) -> Constraint
|
||||
"""
|
||||
Sort the 2 typevars in a constraint by name for comparison
|
||||
"""
|
||||
r = tuple(sorted(c, key=lambda y: y.name))
|
||||
if TYPE_CHECKING:
|
||||
return cast(Constraint, r)
|
||||
else:
|
||||
return r
|
||||
|
||||
|
||||
def agree(me, other):
|
||||
# type: (TypeEnv, TypeEnv) -> bool
|
||||
"""
|
||||
@@ -63,13 +50,10 @@ def agree(me, other):
|
||||
return False
|
||||
|
||||
# Translate our constraints using m, and sort
|
||||
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
|
||||
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
|
||||
|
||||
me_equiv_constr = sorted([constr.translate(m)
|
||||
for constr in me.constraints])
|
||||
# Sort other's constraints
|
||||
other_equiv_constr = sorted([sort_constr(x) for x in other.constraints],
|
||||
key=lambda y: y[0].name)
|
||||
|
||||
other_equiv_constr = sorted(other.constraints)
|
||||
return me_equiv_constr == other_equiv_constr
|
||||
|
||||
|
||||
@@ -224,7 +208,7 @@ class TestRTL(TypeCheckingBaseTest):
|
||||
self.v3: txn,
|
||||
self.v4: txn,
|
||||
self.v5: txn,
|
||||
}, [(ixn.as_bool(), txn.as_bool())]))
|
||||
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
|
||||
|
||||
def test_vselect_vsplits(self):
|
||||
# type: () -> None
|
||||
|
||||
@@ -8,13 +8,12 @@ from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable # noqa
|
||||
from typing import cast, List
|
||||
from typing import Iterable, List # noqa
|
||||
from typing import cast
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
from .typevar import TypeSet # noqa
|
||||
if TYPE_CHECKING:
|
||||
Constraint = Tuple[TypeVar, TypeVar]
|
||||
ConstraintList = List[Constraint]
|
||||
TypeMap = Dict[TypeVar, TypeVar]
|
||||
VarMap = Dict[Var, TypeVar]
|
||||
except ImportError:
|
||||
@@ -22,6 +21,122 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class TypeConstraint(object):
|
||||
"""
|
||||
Base class for all runtime-emittable type constraints.
|
||||
"""
|
||||
|
||||
|
||||
class ConstrainTVsEqual(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that two derived type vars must have the same runtime
|
||||
type.
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
assert tv1.is_derived and tv2.is_derived
|
||||
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
return self.tv1 == self.tv2 or \
|
||||
(self.tv1.singleton_type() is not None and
|
||||
self.tv2.singleton_type() is not None)
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map m
|
||||
"""
|
||||
if isinstance(m, TypeEnv):
|
||||
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
|
||||
else:
|
||||
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, ConstrainTVsEqual)):
|
||||
return False
|
||||
|
||||
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash((self.tv1, self.tv2))
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert self.tv1.singleton_type() is not None and \
|
||||
self.tv2.singleton_type() is not None
|
||||
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||
|
||||
|
||||
class ConstrainTVInTypeset(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var must belong to some typeset.
|
||||
"""
|
||||
def __init__(self, tv, ts):
|
||||
# type: (TypeVar, TypeSet) -> None
|
||||
assert not tv.is_derived and tv.name.startswith("typeof_")
|
||||
self.tv = tv
|
||||
self.ts = ts
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
tv_ts = self.tv.get_typeset().copy()
|
||||
|
||||
# Trivially True
|
||||
if (tv_ts.issubset(self.ts)):
|
||||
return True
|
||||
|
||||
# Trivially false
|
||||
tv_ts &= self.ts
|
||||
if (tv_ts.size() == 0):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map m
|
||||
"""
|
||||
if isinstance(m, TypeEnv):
|
||||
return ConstrainTVInTypeset(m[self.tv], self.ts)
|
||||
else:
|
||||
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, ConstrainTVInTypeset)):
|
||||
return False
|
||||
|
||||
return (self.tv, self.ts) == (other.tv, other.ts)
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash((self.tv, self.ts))
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert self.tv.singleton_type() is not None
|
||||
return self.tv.get_typeset().issubset(self.ts)
|
||||
|
||||
|
||||
class TypeEnv(object):
|
||||
"""
|
||||
Class encapsulating the neccessary book keeping for type inference.
|
||||
@@ -43,13 +158,13 @@ class TypeEnv(object):
|
||||
RANK_INTERNAL = 0
|
||||
|
||||
def __init__(self, arg=None):
|
||||
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
|
||||
# type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
|
||||
self.ranks = {} # type: Dict[TypeVar, int]
|
||||
self.vars = set() # type: Set[Var]
|
||||
|
||||
if arg is None:
|
||||
self.type_map = {} # type: TypeMap
|
||||
self.constraints = [] # type: ConstraintList
|
||||
self.constraints = [] # type: List[TypeConstraint]
|
||||
else:
|
||||
self.type_map, self.constraints = arg
|
||||
|
||||
@@ -94,7 +209,9 @@ class TypeEnv(object):
|
||||
"""
|
||||
Add a new equivalence constraint between tv1 and tv2
|
||||
"""
|
||||
self.constraints.append((tv1, tv2))
|
||||
constr = ConstrainTVsEqual(tv1, tv2)
|
||||
if (constr not in self.constraints):
|
||||
self.constraints.append(constr)
|
||||
|
||||
def get_uid(self):
|
||||
# type: () -> str
|
||||
@@ -206,15 +323,24 @@ class TypeEnv(object):
|
||||
"""
|
||||
vars_tvs = set([v.get_typevar() for v in self.vars])
|
||||
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
|
||||
new_constraints = [(self[tv1], self[tv2])
|
||||
for (tv1, tv2) in self.constraints]
|
||||
|
||||
# Sanity: new constraints and the new type_map should only contain
|
||||
# tvs associated with real vars
|
||||
for (a, b) in new_constraints:
|
||||
assert a.free_typevar() in vars_tvs and\
|
||||
b.free_typevar() in vars_tvs
|
||||
new_constraints = [] # type: List[TypeConstraint]
|
||||
for constr in self.constraints:
|
||||
# Currently typeinference only generates ConstrainTVsEqual
|
||||
# constraints
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
constr = constr.translate(self)
|
||||
|
||||
if constr.is_trivial() or constr in new_constraints:
|
||||
continue
|
||||
|
||||
# Sanity: translated constraints should refer to only real vars
|
||||
assert constr.tv1.free_typevar() in vars_tvs and\
|
||||
constr.tv2.free_typevar() in vars_tvs
|
||||
|
||||
new_constraints.append(constr)
|
||||
|
||||
# Sanity: translated typemap should refer to only real vars
|
||||
for (k, v) in new_type_map.items():
|
||||
assert k in vars_tvs
|
||||
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
|
||||
@@ -245,13 +371,13 @@ class TypeEnv(object):
|
||||
|
||||
# Check if constraints are satisfied for this typing
|
||||
failed = None
|
||||
for (tv1, tv2) in self.constraints:
|
||||
tv1 = subst(tv1, m)
|
||||
tv2 = subst(tv2, m)
|
||||
assert tv1.get_typeset().size() == 1 and\
|
||||
tv2.get_typeset().size() == 1
|
||||
if (tv1.get_typeset() != tv2.get_typeset()):
|
||||
failed = (tv1, tv2)
|
||||
for constr in self.constraints:
|
||||
# Currently typeinference only generates ConstrainTVsEqual
|
||||
# constraints
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
concrete_constr = constr.translate(m)
|
||||
if not concrete_constr.eval():
|
||||
failed = concrete_constr
|
||||
break
|
||||
|
||||
if (failed is not None):
|
||||
@@ -287,9 +413,10 @@ class TypeEnv(object):
|
||||
edges.add((v, v.base, "solid", v.derived_func))
|
||||
v = v.base
|
||||
|
||||
for (a, b) in self.constraints:
|
||||
assert a in nodes and b in nodes
|
||||
edges.add((a, b, "dashed", None))
|
||||
for constr in self.constraints:
|
||||
assert isinstance(constr, ConstrainTVsEqual)
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", None))
|
||||
|
||||
root_nodes = set([x for x in nodes
|
||||
if x not in self.type_map and not x.is_derived])
|
||||
|
||||
@@ -41,6 +41,14 @@ class Rtl(object):
|
||||
# type: (*DefApply) -> None
|
||||
self.rtl = tuple(map(canonicalize_defapply, args))
|
||||
|
||||
def copy(self, m):
|
||||
# type: (Dict[Var, Var]) -> Rtl
|
||||
"""
|
||||
Return a copy of this rtl with all Vars substituted with copies or
|
||||
according to m. Update m as neccessary.
|
||||
"""
|
||||
return Rtl(*[d.copy(m) for d in self.rtl])
|
||||
|
||||
|
||||
class XForm(object):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user