Add better type inference and encapsulate it in its own file (#110)

* Add more rigorous type inference and encapsulate the type inferece code in its own file (ti.py).

Add constraints accumulation during type inference, to represent constraints that cannot be expressed
using bijective derivation functions between typevars.

Add testing for new type inference code.

* Additional annotations to appease mypy
This commit is contained in:
d1m0
2017-07-05 09:16:44 -07:00
committed by Jakob Stoklund Olesen
parent f867ddbf0c
commit a5c96ef6bf
6 changed files with 1123 additions and 281 deletions

View File

@@ -3,6 +3,7 @@ Instruction transformations.
"""
from __future__ import absolute_import
from .ast import Def, Var, Apply
from .ti import ti_xform, TypeEnv, get_type_env
try:
from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa
@@ -83,6 +84,8 @@ class XForm(object):
self._rewrite_rtl(src, symtab, Var.SRCCTX)
num_src_inputs = len(self.inputs)
self._rewrite_rtl(dst, symtab, Var.DSTCTX)
# Needed for testing type inference on XForms
self.symtab = symtab
# Check for inconsistently used inputs.
for i in self.inputs:
@@ -96,9 +99,25 @@ class XForm(object):
"extra inputs in dst RTL: {}".format(
self.inputs[num_src_inputs:]))
self._infer_types(self.src)
self._infer_types(self.dst)
self._collect_typevars()
# Perform type inference and cleanup
raw_ti = get_type_env(ti_xform(self, TypeEnv()))
raw_ti.normalize()
self.ti = raw_ti.extract()
# Sanity: The set of inferred free typevars should be a subset of the
# TVs corresponding to Vars appearing in src
self.free_typevars = self.ti.free_typevars()
src_vars = set(self.inputs).union(
[x for x in self.defs if not x.is_temp()])
src_tvs = set([v.get_typevar() for v in src_vars])
if (not self.free_typevars.issubset(src_tvs)):
raise AssertionError(
"Some free vars don't appear in src - {}"
.format(self.free_typevars.difference(src_tvs)))
# Update the type vars for each Var to their inferred values
for v in self.inputs + self.defs:
v.set_typevar(self.ti[v.get_typevar()])
def __repr__(self):
# type: () -> str
@@ -202,63 +221,6 @@ class XForm(object):
raise AssertionError(
'{} not defined in dest pattern'.format(d))
def _infer_types(self, rtl):
# type: (Rtl) -> None
"""Assign type variables to all value variables used in `rtl`."""
for d in rtl.rtl:
inst = d.expr.inst
# Get the Var corresponding to the controlling type variable.
ctrl_var = None # type: Var
if inst.is_polymorphic:
if inst.use_typevar_operand:
# Should this be an assertion instead?
# Should all value operands be required to be Vars?
arg = d.expr.args[inst.format.typevar_operand]
if isinstance(arg, Var):
ctrl_var = arg
else:
ctrl_var = d.defs[inst.value_results[0]]
# Reconcile arguments with the requirements of `inst`.
for opnum in inst.value_opnums:
inst_tv = inst.ins[opnum].typevar
v = d.expr.args[opnum]
if isinstance(v, Var):
v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var)
# Reconcile results with the requirements of `inst`.
for resnum in inst.value_results:
inst_tv = inst.outs[resnum].typevar
v = d.defs[resnum]
v.constrain_typevar(inst_tv, inst.ctrl_typevar, ctrl_var)
def _collect_typevars(self):
# type: () -> None
"""
Collect a list of variables whose type can be used to infer the types
of all expressions.
This should be called after `_infer_types()` above has computed type
variables for all the used vars.
"""
fvars = list(v for v in self.inputs if v.has_free_typevar())
fvars += list(v for v in self.defs if v.has_free_typevar())
self.free_typevars = fvars
# When substituting a pattern, we know the types of all variables that
# appear on the source side: inut, output, and intermediate values.
# However, temporary values which appear only on the destination side
# must have their type computed somehow.
#
# Some variables have a fixed type which appears as a type variable
# with a singleton_type field set. That's allowed for temps too.
for v in fvars:
if v.is_temp() and not v.typevar.singleton_type():
raise AssertionError(
"Cannot determine type of temp '{}' in xform:\n{}"
.format(v, self))
class XFormGroup(object):
"""