moved crates in lib/ to src/, renamed crates, modified some files' text (#660)
moved crates in lib/ to src/, renamed crates, modified some files' text (#660)
This commit is contained in:
894
cranelift/codegen/meta-python/cdsl/ti.py
Normal file
894
cranelift/codegen/meta-python/cdsl/ti.py
Normal file
@@ -0,0 +1,894 @@
|
||||
"""
|
||||
Type Inference
|
||||
"""
|
||||
from .typevar import TypeVar
|
||||
from .ast import Def, Var
|
||||
from copy import copy
|
||||
from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable, List, Any, TypeVar as MTypeVar # noqa
|
||||
from typing import cast
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
from .typevar import TypeSet # noqa
|
||||
if TYPE_CHECKING:
|
||||
T = MTypeVar('T')
|
||||
TypeMap = Dict[TypeVar, TypeVar]
|
||||
VarTyping = Dict[Var, TypeVar]
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
pass
|
||||
|
||||
|
||||
class TypeConstraint(object):
|
||||
"""
|
||||
Base class for all runtime-emittable type constraints.
|
||||
"""
|
||||
|
||||
def __init__(self, tv, tc):
|
||||
# type: (TypeVar, Union[TypeVar, TypeSet]) -> None
|
||||
"""
|
||||
Abstract "constructor" for linters
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def translate(self, m):
|
||||
# type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
|
||||
"""
|
||||
Translate any TypeVars in the constraint according to the map or
|
||||
TypeEnv m
|
||||
"""
|
||||
def translate_one(a):
|
||||
# type: (Any) -> Any
|
||||
if (isinstance(a, TypeVar)):
|
||||
return m[a] if isinstance(m, TypeEnv) else subst(a, m)
|
||||
return a
|
||||
|
||||
res = None # type: TypeConstraint
|
||||
res = self.__class__(*tuple(map(translate_one, self._args())))
|
||||
return res
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if (not isinstance(other, self.__class__)):
|
||||
return False
|
||||
|
||||
assert isinstance(other, TypeConstraint) # help MyPy figure out other
|
||||
return self._args() == other._args()
|
||||
|
||||
def is_concrete(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true iff all typevars in the constraint are singletons.
|
||||
"""
|
||||
return [] == list(filter(lambda x: x.singleton_type() is None,
|
||||
self.tvs()))
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash(self._args())
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
"""
|
||||
Return a tuple with the exact arguments passed to __init__ to create
|
||||
this object.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def tvs(self):
|
||||
# type: () -> Iterable[TypeVar]
|
||||
"""
|
||||
Return the typevars contained in this constraint.
|
||||
"""
|
||||
return list(filter(lambda x: isinstance(x, TypeVar), self._args()))
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return true if this constrain is statically decidable.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Evaluate this constraint. Should only be called when the constraint has
|
||||
been translated to concrete types.
|
||||
"""
|
||||
assert False, "Abstract"
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return (self.__class__.__name__ + '(' +
|
||||
', '.join(map(str, self._args())) + ')')
|
||||
|
||||
|
||||
class TypesEqual(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that two derived type vars must have the same runtime
|
||||
type.
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
return self.tv1 == self.tv2 or self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
return self.tv1.singleton_type() == self.tv2.singleton_type()
|
||||
|
||||
|
||||
class InTypeset(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var must belong to some typeset.
|
||||
"""
|
||||
def __init__(self, tv, ts):
|
||||
# type: (TypeVar, TypeSet) -> None
|
||||
assert not tv.is_derived and tv.name.startswith("typeof_")
|
||||
self.tv = tv
|
||||
self.ts = ts
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv, self.ts)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
tv_ts = self.tv.get_typeset().copy()
|
||||
|
||||
# Trivially True
|
||||
if (tv_ts.issubset(self.ts)):
|
||||
return True
|
||||
|
||||
# Trivially false
|
||||
tv_ts &= self.ts
|
||||
if (tv_ts.size() == 0):
|
||||
return True
|
||||
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
return self.tv.get_typeset().issubset(self.ts)
|
||||
|
||||
|
||||
class WiderOrEq(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that a type var tv1 must be wider than or equal to
|
||||
type var tv2 at runtime. This requires that:
|
||||
1) They have the same number of lanes
|
||||
2) In a lane tv1 has at least as many bits as tv2.
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
self.tv1 = tv1
|
||||
self.tv2 = tv2
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
# Trivially true
|
||||
if (self.tv1 == self.tv2):
|
||||
return True
|
||||
|
||||
ts1 = self.tv1.get_typeset()
|
||||
ts2 = self.tv2.get_typeset()
|
||||
|
||||
def set_wider_or_equal(s1, s2):
|
||||
# type: (Set[int], Set[int]) -> bool
|
||||
return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)
|
||||
|
||||
# Trivially True
|
||||
if set_wider_or_equal(ts1.ints, ts2.ints) and\
|
||||
set_wider_or_equal(ts1.floats, ts2.floats) and\
|
||||
set_wider_or_equal(ts1.bools, ts2.bools):
|
||||
return True
|
||||
|
||||
def set_narrower(s1, s2):
|
||||
# type: (Set[int], Set[int]) -> bool
|
||||
return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)
|
||||
|
||||
# Trivially False
|
||||
if set_narrower(ts1.ints, ts2.ints) and\
|
||||
set_narrower(ts1.floats, ts2.floats) and\
|
||||
set_narrower(ts1.bools, ts2.bools):
|
||||
return True
|
||||
|
||||
# Trivially False
|
||||
if len(ts1.lanes.intersection(ts2.lanes)) == 0:
|
||||
return True
|
||||
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
typ1 = self.tv1.singleton_type()
|
||||
typ2 = self.tv2.singleton_type()
|
||||
|
||||
return typ1.wider_or_equal(typ2)
|
||||
|
||||
|
||||
class SameWidth(TypeConstraint):
|
||||
"""
|
||||
Constraint specifying that two types have the same width. E.g. i32x2 has
|
||||
the same width as i64x1, i16x4, f32x2, f64, b1x64 etc.
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
self.tv1 = tv1
|
||||
self.tv2 = tv2
|
||||
|
||||
def _args(self):
|
||||
# type: () -> Tuple[Any,...]
|
||||
""" See TypeConstraint._args() """
|
||||
return (self.tv1, self.tv2)
|
||||
|
||||
def is_trivial(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.is_trivial() """
|
||||
# Trivially true
|
||||
if (self.tv1 == self.tv2):
|
||||
return True
|
||||
|
||||
ts1 = self.tv1.get_typeset()
|
||||
ts2 = self.tv2.get_typeset()
|
||||
|
||||
# Trivially False
|
||||
if len(ts1.widths().intersection(ts2.widths())) == 0:
|
||||
return True
|
||||
|
||||
return self.is_concrete()
|
||||
|
||||
def eval(self):
|
||||
# type: () -> bool
|
||||
""" See TypeConstraint.eval() """
|
||||
assert self.is_concrete()
|
||||
typ1 = self.tv1.singleton_type()
|
||||
typ2 = self.tv2.singleton_type()
|
||||
|
||||
return (typ1.width() == typ2.width())
|
||||
|
||||
|
||||
class TypeEnv(object):
|
||||
"""
|
||||
Class encapsulating the necessary book keeping for type inference.
|
||||
:attribute type_map: dict holding the equivalence relations between tvs
|
||||
:attribute constraints: a list of accumulated constraints - tuples
|
||||
(tv1, tv2)) where tv1 and tv2 are equal
|
||||
:attribute ranks: dictionary recording the (optional) ranks for tvs.
|
||||
'rank' is a partial ordering on TVs based on their
|
||||
origin. See comments in rank() and register().
|
||||
:attribute vars: a set containing all known Vars
|
||||
:attribute idx: counter used to get fresh ids
|
||||
"""
|
||||
|
||||
RANK_SINGLETON = 5
|
||||
RANK_INPUT = 4
|
||||
RANK_INTERMEDIATE = 3
|
||||
RANK_OUTPUT = 2
|
||||
RANK_TEMP = 1
|
||||
RANK_INTERNAL = 0
|
||||
|
||||
def __init__(self, arg=None):
|
||||
# type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
|
||||
self.ranks = {} # type: Dict[TypeVar, int]
|
||||
self.vars = set() # type: Set[Var]
|
||||
|
||||
if arg is None:
|
||||
self.type_map = {} # type: TypeMap
|
||||
self.constraints = [] # type: List[TypeConstraint]
|
||||
else:
|
||||
self.type_map, self.constraints = arg
|
||||
|
||||
self.idx = 0
|
||||
|
||||
def __getitem__(self, arg):
|
||||
# type: (Union[TypeVar, Var]) -> TypeVar
|
||||
"""
|
||||
Lookup the canonical representative for a Var/TypeVar.
|
||||
"""
|
||||
if (isinstance(arg, Var)):
|
||||
assert arg in self.vars
|
||||
tv = arg.get_typevar()
|
||||
else:
|
||||
assert (isinstance(arg, TypeVar))
|
||||
tv = arg
|
||||
|
||||
while tv in self.type_map:
|
||||
tv = self.type_map[tv]
|
||||
|
||||
if tv.is_derived:
|
||||
tv = TypeVar.derived(self[tv.base], tv.derived_func)
|
||||
return tv
|
||||
|
||||
def equivalent(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
"""
|
||||
Record a that the free tv1 is part of the same equivalence class as
|
||||
tv2. The canonical representative of the merged class is tv2's
|
||||
canonical representative.
|
||||
"""
|
||||
assert not tv1.is_derived
|
||||
assert self[tv1] == tv1
|
||||
|
||||
# Make sure we don't create cycles
|
||||
if tv2.is_derived:
|
||||
assert self[tv2.base] != tv1
|
||||
|
||||
self.type_map[tv1] = tv2
|
||||
|
||||
def add_constraint(self, constr):
|
||||
# type: (TypeConstraint) -> None
|
||||
"""
|
||||
Add a new constraint
|
||||
"""
|
||||
if (constr in self.constraints):
|
||||
return
|
||||
|
||||
# InTypeset constraints can be expressed by constraining the typeset of
|
||||
# a variable. No need to add them to self.constraints
|
||||
if (isinstance(constr, InTypeset)):
|
||||
self[constr.tv].constrain_types_by_ts(constr.ts)
|
||||
return
|
||||
|
||||
self.constraints.append(constr)
|
||||
|
||||
def get_uid(self):
|
||||
# type: () -> str
|
||||
r = str(self.idx)
|
||||
self.idx += 1
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return self.dot()
|
||||
|
||||
def rank(self, tv):
|
||||
# type: (TypeVar) -> int
|
||||
"""
|
||||
Get the rank of tv in the partial order. TVs directly associated with a
|
||||
Var get their rank from the Var (see register()). Internally generated
|
||||
non-derived TVs implicitly get the lowest rank (0). Derived variables
|
||||
get their rank from their free typevar. Singletons have the highest
|
||||
rank. TVs associated with vars in a source pattern have a higher rank
|
||||
than TVs associated with temporary vars.
|
||||
"""
|
||||
default_rank = TypeEnv.RANK_INTERNAL if tv.singleton_type() is None \
|
||||
else TypeEnv.RANK_SINGLETON
|
||||
|
||||
if tv.is_derived:
|
||||
tv = tv.free_typevar()
|
||||
|
||||
return self.ranks.get(tv, default_rank)
|
||||
|
||||
def register(self, v):
|
||||
# type: (Var) -> None
|
||||
"""
|
||||
Register a new Var v. This computes a rank for the associated TypeVar
|
||||
for v, which is used to impose a partial order on type variables.
|
||||
"""
|
||||
self.vars.add(v)
|
||||
|
||||
if v.is_input():
|
||||
r = TypeEnv.RANK_INPUT
|
||||
elif v.is_intermediate():
|
||||
r = TypeEnv.RANK_INTERMEDIATE
|
||||
elif v.is_output():
|
||||
r = TypeEnv.RANK_OUTPUT
|
||||
else:
|
||||
assert(v.is_temp())
|
||||
r = TypeEnv.RANK_TEMP
|
||||
|
||||
self.ranks[v.get_typevar()] = r
|
||||
|
||||
def free_typevars(self):
|
||||
# type: () -> List[TypeVar]
|
||||
"""
|
||||
Get the free typevars in the current type env.
|
||||
"""
|
||||
tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
|
||||
tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
|
||||
# Filter out None here due to singleton type vars
|
||||
return sorted(filter(lambda x: x is not None, tvs),
|
||||
key=lambda x: x.name)
|
||||
|
||||
def normalize(self):
|
||||
# type: () -> None
|
||||
"""
|
||||
Normalize by:
|
||||
- collapsing any roots that don't correspond to a concrete TV AND
|
||||
have a single TV derived from them or equivalent to them
|
||||
|
||||
E.g. if we have a root of the tree that looks like:
|
||||
|
||||
typeof_a typeof_b
|
||||
\\ /
|
||||
typeof_x
|
||||
|
|
||||
half_width(1)
|
||||
|
|
||||
1
|
||||
|
||||
we want to collapse the linear path between 1 and typeof_x. The
|
||||
resulting graph is:
|
||||
|
||||
typeof_a typeof_b
|
||||
\\ /
|
||||
typeof_x
|
||||
"""
|
||||
source_tvs = set([v.get_typevar() for v in self.vars])
|
||||
children = {} # type: Dict[TypeVar, Set[TypeVar]]
|
||||
for v in self.type_map.values():
|
||||
if not v.is_derived:
|
||||
continue
|
||||
|
||||
t = v.free_typevar()
|
||||
s = children.get(t, set())
|
||||
s.add(v)
|
||||
children[t] = s
|
||||
|
||||
for (a, b) in self.type_map.items():
|
||||
s = children.get(b, set())
|
||||
s.add(a)
|
||||
children[b] = s
|
||||
|
||||
for r in self.free_typevars():
|
||||
while (r not in source_tvs and r in children and
|
||||
len(children[r]) == 1):
|
||||
child = list(children[r])[0]
|
||||
if child in self.type_map:
|
||||
assert self.type_map[child] == r
|
||||
del self.type_map[child]
|
||||
|
||||
r = child
|
||||
|
||||
def extract(self):
|
||||
# type: () -> TypeEnv
|
||||
"""
|
||||
Extract a clean type environment from self, that only mentions
|
||||
TVs associated with real variables
|
||||
"""
|
||||
vars_tvs = set([v.get_typevar() for v in self.vars])
|
||||
new_type_map = {tv: self[tv] for tv in vars_tvs if tv != self[tv]}
|
||||
|
||||
new_constraints = [] # type: List[TypeConstraint]
|
||||
for constr in self.constraints:
|
||||
constr = constr.translate(self)
|
||||
|
||||
if constr.is_trivial() or constr in new_constraints:
|
||||
continue
|
||||
|
||||
# Sanity: translated constraints should refer to only real vars
|
||||
for arg in constr._args():
|
||||
if (not isinstance(arg, TypeVar)):
|
||||
continue
|
||||
|
||||
arg_free_tv = arg.free_typevar()
|
||||
assert arg_free_tv is None or arg_free_tv in vars_tvs
|
||||
|
||||
new_constraints.append(constr)
|
||||
|
||||
# Sanity: translated typemap should refer to only real vars
|
||||
for (k, v) in new_type_map.items():
|
||||
assert k in vars_tvs
|
||||
assert v.free_typevar() is None or v.free_typevar() in vars_tvs
|
||||
|
||||
t = TypeEnv()
|
||||
t.type_map = new_type_map
|
||||
t.constraints = new_constraints
|
||||
# ranks and vars contain only TVs associated with real vars
|
||||
t.ranks = copy(self.ranks)
|
||||
t.vars = copy(self.vars)
|
||||
return t
|
||||
|
||||
def concrete_typings(self):
|
||||
# type: () -> Iterable[VarTyping]
|
||||
"""
|
||||
Return an iterable over all possible concrete typings permitted by this
|
||||
TypeEnv.
|
||||
"""
|
||||
free_tvs = self.free_typevars()
|
||||
free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs]
|
||||
for concrete_types in product(*free_tv_iters):
|
||||
# Build type substitutions for all free vars
|
||||
m = {tv: TypeVar.singleton(typ)
|
||||
for (tv, typ) in zip(free_tvs, concrete_types)}
|
||||
|
||||
concrete_var_map = {v: subst(self[v.get_typevar()], m)
|
||||
for v in self.vars}
|
||||
|
||||
# Check if constraints are satisfied for this typing
|
||||
failed = None
|
||||
for constr in self.constraints:
|
||||
concrete_constr = constr.translate(m)
|
||||
if not concrete_constr.eval():
|
||||
failed = concrete_constr
|
||||
break
|
||||
|
||||
if (failed is not None):
|
||||
continue
|
||||
|
||||
yield concrete_var_map
|
||||
|
||||
def permits(self, concrete_typing):
|
||||
# type: (VarTyping) -> bool
|
||||
"""
|
||||
Return true iff this TypeEnv permits the (possibly partial) concrete
|
||||
variable type mapping concrete_typing.
|
||||
"""
|
||||
# Each variable has a concrete type, that is a subset of its inferred
|
||||
# typeset.
|
||||
for (v, typ) in concrete_typing.items():
|
||||
assert typ.singleton_type() is not None
|
||||
if not typ.get_typeset().issubset(self[v].get_typeset()):
|
||||
return False
|
||||
|
||||
m = {self[v]: typ for (v, typ) in concrete_typing.items()}
|
||||
|
||||
# Constraints involving vars in concrete_typing are satisfied
|
||||
for constr in self.constraints:
|
||||
try:
|
||||
# If the constraint includes only vars in concrete_typing, we
|
||||
# can translate it using m. Otherwise we encounter a KeyError
|
||||
# and ignore it
|
||||
constr = constr.translate(m)
|
||||
if not constr.eval():
|
||||
return False
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def dot(self):
|
||||
# type: () -> str
|
||||
"""
|
||||
Return a representation of self as a graph in dot format.
|
||||
Nodes correspond to TypeVariables.
|
||||
Dotted edges correspond to equivalences between TVS
|
||||
Solid edges correspond to derivation relations between TVs.
|
||||
Dashed edges correspond to equivalence constraints.
|
||||
"""
|
||||
def label(s):
|
||||
# type: (TypeVar) -> str
|
||||
return "\"" + str(s) + "\""
|
||||
|
||||
# Add all registered TVs (as some of them may be singleton nodes not
|
||||
# appearing in the graph
|
||||
nodes = set() # type: Set[TypeVar]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
|
||||
|
||||
def add_nodes(*args):
|
||||
# type: (*TypeVar) -> None
|
||||
for tv in args:
|
||||
nodes.add(tv)
|
||||
while (tv.is_derived):
|
||||
nodes.add(tv.base)
|
||||
edges.add((tv, tv.base, "solid", "forward",
|
||||
tv.derived_func))
|
||||
tv = tv.base
|
||||
|
||||
for v in self.vars:
|
||||
add_nodes(v.get_typevar())
|
||||
|
||||
for (tv1, tv2) in self.type_map.items():
|
||||
# Add all intermediate TVs appearing in edges
|
||||
add_nodes(tv1, tv2)
|
||||
edges.add((tv1, tv2, "dotted", "forward", None))
|
||||
|
||||
for constr in self.constraints:
|
||||
if isinstance(constr, TypesEqual):
|
||||
add_nodes(constr.tv1, constr.tv2)
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
|
||||
elif isinstance(constr, WiderOrEq):
|
||||
add_nodes(constr.tv1, constr.tv2)
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
|
||||
elif isinstance(constr, SameWidth):
|
||||
add_nodes(constr.tv1, constr.tv2)
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "none",
|
||||
"same_width"))
|
||||
else:
|
||||
assert False, "Can't display constraint {}".format(constr)
|
||||
|
||||
root_nodes = set([x for x in nodes
|
||||
if x not in self.type_map and not x.is_derived])
|
||||
|
||||
r = "digraph {\n"
|
||||
for n in nodes:
|
||||
r += label(n)
|
||||
if n in root_nodes:
|
||||
r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
|
||||
r += ";\n"
|
||||
|
||||
for (n1, n2, style, direction, elabel) in edges:
|
||||
e = label(n1) + "->" + label(n2)
|
||||
e += "[style={},dir={}".format(style, direction)
|
||||
|
||||
if elabel is not None:
|
||||
e += ",label=\"{}\"".format(elabel)
|
||||
e += "];\n"
|
||||
|
||||
r += e
|
||||
r += "}"
|
||||
|
||||
return r
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
TypingError = str
|
||||
TypingOrError = Union[TypeEnv, TypingError]
|
||||
|
||||
|
||||
def get_error(typing_or_err):
|
||||
# type: (TypingOrError) -> Optional[TypingError]
|
||||
"""
|
||||
Helper function to appease mypy when checking the result of typing.
|
||||
"""
|
||||
if isinstance(typing_or_err, str):
|
||||
if (TYPE_CHECKING):
|
||||
return cast(TypingError, typing_or_err)
|
||||
else:
|
||||
return typing_or_err
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_type_env(typing_or_err):
|
||||
# type: (TypingOrError) -> TypeEnv
|
||||
"""
|
||||
Helper function to appease mypy when checking the result of typing.
|
||||
"""
|
||||
assert isinstance(typing_or_err, TypeEnv), \
|
||||
"Unexpected error: {}".format(typing_or_err)
|
||||
|
||||
if (TYPE_CHECKING):
|
||||
return cast(TypeEnv, typing_or_err)
|
||||
else:
|
||||
return typing_or_err
|
||||
|
||||
|
||||
def subst(tv, tv_map):
|
||||
# type: (TypeVar, TypeMap) -> TypeVar
|
||||
"""
|
||||
Perform substition on the input tv using the TypeMap tv_map.
|
||||
"""
|
||||
if tv in tv_map:
|
||||
return tv_map[tv]
|
||||
|
||||
if tv.is_derived:
|
||||
return TypeVar.derived(subst(tv.base, tv_map), tv.derived_func)
|
||||
|
||||
return tv
|
||||
|
||||
|
||||
def normalize_tv(tv):
|
||||
# type: (TypeVar) -> TypeVar
|
||||
"""
|
||||
Normalize a (potentially derived) TV using the following rules:
|
||||
- vector and width derived functions commute
|
||||
{HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
|
||||
{HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))
|
||||
|
||||
- half/double pairs collapse
|
||||
{HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
|
||||
{HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base
|
||||
"""
|
||||
vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR]
|
||||
width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]
|
||||
|
||||
if not tv.is_derived:
|
||||
return tv
|
||||
|
||||
df = tv.derived_func
|
||||
|
||||
if (tv.base.is_derived):
|
||||
base_df = tv.base.derived_func
|
||||
|
||||
# Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR,
|
||||
# DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR
|
||||
if df in vector_derives and base_df in width_derives:
|
||||
return normalize_tv(
|
||||
TypeVar.derived(
|
||||
TypeVar.derived(tv.base.base, df), base_df))
|
||||
|
||||
# Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
|
||||
# cancel each other. Note: This doesn't hide any over/underflows,
|
||||
# since we 1) assert the safety of each TV in the chain upon its
|
||||
# creation, and 2) the base typeset is only allowed to shrink.
|
||||
|
||||
if (df, base_df) in \
|
||||
[(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
|
||||
(TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
|
||||
(TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
|
||||
(TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH)]:
|
||||
return normalize_tv(tv.base.base)
|
||||
|
||||
return TypeVar.derived(normalize_tv(tv.base), df)
|
||||
|
||||
|
||||
def constrain_fixpoint(tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
"""
|
||||
Given typevars tv1 and tv2 (which could be derived from one another)
|
||||
constrain their typesets to be the same. When one is derived from the
|
||||
other, repeat the constrain process until fixpoint.
|
||||
"""
|
||||
# Constrain tv2's typeset as long as tv1's typeset is changing.
|
||||
while True:
|
||||
old_tv1_ts = tv1.get_typeset().copy()
|
||||
tv2.constrain_types(tv1)
|
||||
if tv1.get_typeset() == old_tv1_ts:
|
||||
break
|
||||
|
||||
old_tv2_ts = tv2.get_typeset().copy()
|
||||
tv1.constrain_types(tv2)
|
||||
assert old_tv2_ts == tv2.get_typeset()
|
||||
|
||||
|
||||
def unify(tv1, tv2, typ):
|
||||
# type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Unify tv1 and tv2 in the current type environment typ, and return an
|
||||
updated type environment or error.
|
||||
"""
|
||||
tv1 = normalize_tv(typ[tv1])
|
||||
tv2 = normalize_tv(typ[tv2])
|
||||
|
||||
# Already unified
|
||||
if tv1 == tv2:
|
||||
return typ
|
||||
|
||||
if typ.rank(tv2) < typ.rank(tv1):
|
||||
return unify(tv2, tv1, typ)
|
||||
|
||||
constrain_fixpoint(tv1, tv2)
|
||||
|
||||
if (tv1.get_typeset().size() == 0 or tv2.get_typeset().size() == 0):
|
||||
return "Error: empty type created when unifying {} and {}"\
|
||||
.format(tv1, tv2)
|
||||
|
||||
# Free -> Derived(Free)
|
||||
if not tv1.is_derived:
|
||||
typ.equivalent(tv1, tv2)
|
||||
return typ
|
||||
|
||||
if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
|
||||
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
||||
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
||||
|
||||
typ.add_constraint(TypesEqual(tv1, tv2))
|
||||
return typ
|
||||
|
||||
|
||||
def move_first(l, i):
|
||||
# type: (List[T], int) -> List[T]
|
||||
return [l[i]] + l[:i] + l[i+1:]
|
||||
|
||||
|
||||
def ti_def(definition, typ):
|
||||
# type: (Def, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Perform type inference on one Def in the current type environment typ and
|
||||
return an updated type environment or error.
|
||||
|
||||
At a high level this works by creating fresh copies of each formal type var
|
||||
in the Def's instruction's signature, and unifying the formal tv with the
|
||||
corresponding actual tv.
|
||||
"""
|
||||
expr = definition.expr
|
||||
inst = expr.inst
|
||||
|
||||
# Create a dict m mapping each free typevar in the signature of definition
|
||||
# to a fresh copy of itself.
|
||||
free_formal_tvs = inst.all_typevars()
|
||||
m = {tv: tv.get_fresh_copy(str(typ.get_uid())) for tv in free_formal_tvs}
|
||||
|
||||
# Update m with any explicitly bound type vars
|
||||
for (idx, bound_typ) in enumerate(expr.typevars):
|
||||
m[free_formal_tvs[idx]] = TypeVar.singleton(bound_typ)
|
||||
|
||||
# Get fresh copies for each typevar in the signature (both free and
|
||||
# derived)
|
||||
fresh_formal_tvs = \
|
||||
[subst(inst.outs[i].typevar, m) for i in inst.value_results] +\
|
||||
[subst(inst.ins[i].typevar, m) for i in inst.value_opnums]
|
||||
|
||||
# Get the list of actual Vars
|
||||
actual_vars = [] # type: List[Expr]
|
||||
actual_vars += [definition.defs[i] for i in inst.value_results]
|
||||
actual_vars += [expr.args[i] for i in inst.value_opnums]
|
||||
|
||||
# Get the list of the actual TypeVars
|
||||
actual_tvs = []
|
||||
for v in actual_vars:
|
||||
assert(isinstance(v, Var))
|
||||
# Register with TypeEnv that this typevar corresponds ot variable v,
|
||||
# and thus has a given rank
|
||||
typ.register(v)
|
||||
actual_tvs.append(v.get_typevar())
|
||||
|
||||
# Make sure we unify the control typevar first.
|
||||
if inst.is_polymorphic:
|
||||
idx = fresh_formal_tvs.index(m[inst.ctrl_typevar])
|
||||
fresh_formal_tvs = move_first(fresh_formal_tvs, idx)
|
||||
actual_tvs = move_first(actual_tvs, idx)
|
||||
|
||||
# Unify each actual typevar with the corresponding fresh formal tv
|
||||
for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
|
||||
typ_or_err = unify(actual_tv, formal_tv, typ)
|
||||
err = get_error(typ_or_err)
|
||||
if (err):
|
||||
return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
# Add any instruction specific constraints
|
||||
for constr in inst.constraints:
|
||||
typ.add_constraint(constr.translate(m))
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
def ti_rtl(rtl, typ):
|
||||
# type: (Rtl, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Perform type inference on an Rtl in a starting type env typ. Return an
|
||||
updated type environment or error.
|
||||
"""
|
||||
for (i, d) in enumerate(rtl.rtl):
|
||||
assert (isinstance(d, Def))
|
||||
typ_or_err = ti_def(d, typ)
|
||||
err = get_error(typ_or_err) # type: Optional[TypingError]
|
||||
if (err):
|
||||
return "On line {}: ".format(i) + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
def ti_xform(xform, typ):
|
||||
# type: (XForm, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
Perform type inference on an Rtl in a starting type env typ. Return an
|
||||
updated type environment or error.
|
||||
"""
|
||||
typ_or_err = ti_rtl(xform.src, typ)
|
||||
err = get_error(typ_or_err) # type: Optional[TypingError]
|
||||
if (err):
|
||||
return "In src pattern: " + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
typ_or_err = ti_rtl(xform.dst, typ)
|
||||
err = get_error(typ_or_err)
|
||||
if (err):
|
||||
return "In dst pattern: " + err
|
||||
|
||||
typ = get_type_env(typ_or_err)
|
||||
|
||||
return get_type_env(typ_or_err)
|
||||
Reference in New Issue
Block a user