Emit runtime type checks in legalizer.rs (#112)

* Emit runtime type checks in legalizer.rs
This commit is contained in:
d1m0
2017-07-10 15:28:32 -07:00
committed by Jakob Stoklund Olesen
parent 528e6ff3f5
commit fc11ae7b72
9 changed files with 494 additions and 69 deletions

View File

@@ -10,7 +10,7 @@ from .typevar import TypeVar
from .predicates import IsEqual, And
try:
from typing import Union, Tuple, Sequence, TYPE_CHECKING # noqa
from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa
if TYPE_CHECKING:
from .operands import ImmediateKind # noqa
from .predicates import PredNode # noqa
@@ -18,6 +18,19 @@ except ImportError:
pass
def replace_var(arg, m):
# type: (Expr, Dict[Var, Var]) -> Expr
"""
Given a var v return either m[v] or a new variable v' (and remember
m[v]=v'). Otherwise return the argument unchanged
"""
if isinstance(arg, Var):
new_arg = m.get(arg, Var(arg.name)) # type: Var
m[arg] = new_arg
return new_arg
return arg
class Def(object):
"""
An AST definition associates a set of variables with the values produced by
@@ -60,6 +73,21 @@ class Def(object):
return "({}) << {!s}".format(
', '.join(map(str, self.defs)), self.expr)
def copy(self, m):
# type: (Dict[Var, Var]) -> Def
"""
Return a copy of this Def with vars replaced with fresh variables,
in accordance with the map m. Update m as neccessary.
"""
new_expr = self.expr.copy(m)
new_defs = [] # type: List[Var]
for v in self.defs:
new_v = replace_var(v, m)
assert(isinstance(new_v, Var))
new_defs.append(new_v)
return Def(tuple(new_defs), new_expr)
class Expr(object):
"""
@@ -303,6 +331,15 @@ class Apply(Expr):
return pred
def copy(self, m):
# type: (Dict[Var, Var]) -> Apply
"""
Return a copy of this Expr with vars replaced with fresh variables,
in accordance with the map m. Update m as neccessary.
"""
return Apply(self.inst, tuple(map(lambda e: replace_var(e, m),
self.args)))
class Enumerator(Expr):
"""

View File

@@ -6,30 +6,17 @@ from base.immediates import intcc
from .typevar import TypeVar
from .ast import Var, Def
from .xform import Rtl, XForm
from .ti import ti_rtl, subst, TypeEnv, get_type_env
from .ti import ti_rtl, subst, TypeEnv, get_type_env, ConstrainTVsEqual
from unittest import TestCase
from functools import reduce
try:
from .ti import TypeMap, ConstraintList, VarMap, TypingOrError # noqa
from .ti import Constraint
from typing import List, Dict, Tuple, TYPE_CHECKING, cast # noqa
except ImportError:
TYPE_CHECKING = False
def sort_constr(c):
# type: (Constraint) -> Constraint
"""
Sort the 2 typevars in a constraint by name for comparison
"""
r = tuple(sorted(c, key=lambda y: y.name))
if TYPE_CHECKING:
return cast(Constraint, r)
else:
return r
def agree(me, other):
# type: (TypeEnv, TypeEnv) -> bool
"""
@@ -63,13 +50,10 @@ def agree(me, other):
return False
# Translate our constraints using m, and sort
me_equiv_constr = [(subst(a, m), subst(b, m)) for (a, b) in me.constraints]
me_equiv_constr = sorted([sort_constr(x) for x in me_equiv_constr])
me_equiv_constr = sorted([constr.translate(m)
for constr in me.constraints])
# Sort other's constraints
other_equiv_constr = sorted([sort_constr(x) for x in other.constraints],
key=lambda y: y[0].name)
other_equiv_constr = sorted(other.constraints)
return me_equiv_constr == other_equiv_constr
@@ -224,7 +208,7 @@ class TestRTL(TypeCheckingBaseTest):
self.v3: txn,
self.v4: txn,
self.v5: txn,
}, [(ixn.as_bool(), txn.as_bool())]))
}, [ConstrainTVsEqual(ixn.as_bool(), txn.as_bool())]))
def test_vselect_vsplits(self):
# type: () -> None

View File

@@ -8,13 +8,12 @@ from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
from typing import Iterable # noqa
from typing import cast, List
from typing import Iterable, List # noqa
from typing import cast
from .xform import Rtl, XForm # noqa
from .ast import Expr # noqa
from .typevar import TypeSet # noqa
if TYPE_CHECKING:
Constraint = Tuple[TypeVar, TypeVar]
ConstraintList = List[Constraint]
TypeMap = Dict[TypeVar, TypeVar]
VarMap = Dict[Var, TypeVar]
except ImportError:
@@ -22,6 +21,122 @@ except ImportError:
pass
class TypeConstraint(object):
"""
Base class for all runtime-emittable type constraints.
"""
class ConstrainTVsEqual(TypeConstraint):
"""
Constraint specifying that two derived type vars must have the same runtime
type.
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
return self.tv1 == self.tv2 or \
(self.tv1.singleton_type() is not None and
self.tv2.singleton_type() is not None)
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVsEqual
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVsEqual(m[self.tv1], m[self.tv2])
else:
return ConstrainTVsEqual(subst(self.tv1, m), subst(self.tv2, m))
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVsEqual)):
return False
return (self.tv1, self.tv2) == (other.tv1, other.tv2)
def __hash__(self):
# type: () -> int
return hash((self.tv1, self.tv2))
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv1.singleton_type() is not None and \
self.tv2.singleton_type() is not None
return self.tv1.singleton_type() == self.tv2.singleton_type()
class ConstrainTVInTypeset(TypeConstraint):
"""
Constraint specifying that a type var must belong to some typeset.
"""
def __init__(self, tv, ts):
# type: (TypeVar, TypeSet) -> None
assert not tv.is_derived and tv.name.startswith("typeof_")
self.tv = tv
self.ts = ts
def is_trivial(self):
# type: () -> bool
"""
Return true if this constrain is statically decidable.
"""
tv_ts = self.tv.get_typeset().copy()
# Trivially True
if (tv_ts.issubset(self.ts)):
return True
# Trivially false
tv_ts &= self.ts
if (tv_ts.size() == 0):
return True
return False
def translate(self, m):
# type: (Union[TypeEnv, TypeMap]) -> ConstrainTVInTypeset
"""
Translate any TypeVars in the constraint according to the map m
"""
if isinstance(m, TypeEnv):
return ConstrainTVInTypeset(m[self.tv], self.ts)
else:
return ConstrainTVInTypeset(subst(self.tv, m), self.ts)
def __eq__(self, other):
# type: (object) -> bool
if (not isinstance(other, ConstrainTVInTypeset)):
return False
return (self.tv, self.ts) == (other.tv, other.ts)
def __hash__(self):
# type: () -> int
return hash((self.tv, self.ts))
def eval(self):
# type: () -> bool
"""
Evaluate this constraint. Should only be called when the constraint has
been translated to concrete types.
"""
assert self.tv.singleton_type() is not None
return self.tv.get_typeset().issubset(self.ts)
class TypeEnv(object):
"""
Class encapsulating the neccessary book keeping for type inference.
@@ -43,13 +158,13 @@ class TypeEnv(object):
RANK_INTERNAL = 0
def __init__(self, arg=None):
# type: (Optional[Tuple[TypeMap, ConstraintList]]) -> None
# type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
self.ranks = {} # type: Dict[TypeVar, int]
self.vars = set() # type: Set[Var]
if arg is None:
self.type_map = {} # type: TypeMap
self.constraints = [] # type: ConstraintList
self.constraints = [] # type: List[TypeConstraint]
else:
self.type_map, self.constraints = arg
@@ -94,7 +209,9 @@ class TypeEnv(object):
"""
Add a new equivalence constraint between tv1 and tv2
"""
self.constraints.append((tv1, tv2))
constr = ConstrainTVsEqual(tv1, tv2)
if (constr not in self.constraints):
self.constraints.append(constr)
def get_uid(self):
# type: () -> str
@@ -206,15 +323,24 @@ class TypeEnv(object):
"""
vars_tvs = set([v.get_typevar() for v in self.vars])
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
new_constraints = [(self[tv1], self[tv2])
for (tv1, tv2) in self.constraints]
# Sanity: new constraints and the new type_map should only contain
# tvs associated with real vars
for (a, b) in new_constraints:
assert a.free_typevar() in vars_tvs and\
b.free_typevar() in vars_tvs
new_constraints = [] # type: List[TypeConstraint]
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
constr = constr.translate(self)
if constr.is_trivial() or constr in new_constraints:
continue
# Sanity: translated constraints should refer to only real vars
assert constr.tv1.free_typevar() in vars_tvs and\
constr.tv2.free_typevar() in vars_tvs
new_constraints.append(constr)
# Sanity: translated typemap should refer to only real vars
for (k, v) in new_type_map.items():
assert k in vars_tvs
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
@@ -245,13 +371,13 @@ class TypeEnv(object):
# Check if constraints are satisfied for this typing
failed = None
for (tv1, tv2) in self.constraints:
tv1 = subst(tv1, m)
tv2 = subst(tv2, m)
assert tv1.get_typeset().size() == 1 and\
tv2.get_typeset().size() == 1
if (tv1.get_typeset() != tv2.get_typeset()):
failed = (tv1, tv2)
for constr in self.constraints:
# Currently typeinference only generates ConstrainTVsEqual
# constraints
assert isinstance(constr, ConstrainTVsEqual)
concrete_constr = constr.translate(m)
if not concrete_constr.eval():
failed = concrete_constr
break
if (failed is not None):
@@ -287,9 +413,10 @@ class TypeEnv(object):
edges.add((v, v.base, "solid", v.derived_func))
v = v.base
for (a, b) in self.constraints:
assert a in nodes and b in nodes
edges.add((a, b, "dashed", None))
for constr in self.constraints:
assert isinstance(constr, ConstrainTVsEqual)
assert constr.tv1 in nodes and constr.tv2 in nodes
edges.add((constr.tv1, constr.tv2, "dashed", None))
root_nodes = set([x for x in nodes
if x not in self.type_map and not x.is_derived])

View File

@@ -41,6 +41,14 @@ class Rtl(object):
# type: (*DefApply) -> None
self.rtl = tuple(map(canonicalize_defapply, args))
def copy(self, m):
# type: (Dict[Var, Var]) -> Rtl
"""
Return a copy of this rtl with all Vars substituted with copies or
according to m. Update m as neccessary.
"""
return Rtl(*[d.copy(m) for d in self.rtl])
class XForm(object):
"""