Save a reference from a Var to its src and dst defs.

When a Var is used in an XForm, it can be defined in the src or dst or
both patterns, and it is classified accordingly. When a Var is defined,
it is also useful to be able to find the `Def` that defined it.

Add src_def and dst_def reference members to Var, and initialize them in
the private Var copies that XForm creates for itself.

These two members also replace the defctx bitmask.
This commit is contained in:
Jakob Stoklund Olesen
2016-11-03 11:24:04 -07:00
parent f2c7a1d57b
commit 026c899042
2 changed files with 54 additions and 27 deletions

View File

@@ -74,7 +74,7 @@ class Var(Expr):
"""
A free variable.
When variables are used in `XForms` with source ans destination patterns,
When variables are used in `XForms` with source and destination patterns,
they are classified as follows:
Input values
@@ -96,14 +96,10 @@ class Var(Expr):
def __init__(self, name):
# type: (str) -> None
self.name = name
# Bitmask of contexts where this variable is defined.
# See XForm._rewrite_defs().
self.defctx = 0
# Context bits for `defctx` indicating which pattern has defines of this
# var.
SRCCTX = 1
DSTCTX = 2
# The `Def` defining this variable in a source pattern.
self.src_def = None # type: Def
# The `Def` defining this variable in a destination pattern.
self.dst_def = None # type: Def
def __str__(self):
# type: () -> str
@@ -112,29 +108,60 @@ class Var(Expr):
def __repr__(self):
# type: () -> str
s = self.name
if self.defctx:
s += ", d={:02b}".format(self.defctx)
if self.src_def:
s += ", src"
if self.dst_def:
s += ", dst"
return "Var({})".format(s)
# Context bits for `set_def` indicating which pattern has defines of this
# var.
SRCCTX = 1
DSTCTX = 2
def set_def(self, context, d):
# type: (int, Def) -> None
"""
Set the `Def` that defines this variable in the given context.
The `context` must be one of `SRCCTX` or `DSTCTX`
"""
if context == self.SRCCTX:
self.src_def = d
else:
self.dst_def = d
def get_def(self, context):
# type: (int) -> Def
"""
Get the def of this variable in context.
The `context` must be one of `SRCCTX` or `DSTCTX`
"""
if context == self.SRCCTX:
return self.src_def
else:
return self.dst_def
def is_input(self):
# type: () -> bool
"""Is this an input value to the source pattern?"""
return self.defctx == 0
"""Is this an input value to the src pattern?"""
return not self.src_def and not self.dst_def
def is_output(self):
"""Is this an output value, defined in both src and dest patterns?"""
"""Is this an output value, defined in both src and dst patterns?"""
# type: () -> bool
return self.defctx == self.SRCCTX | self.DSTCTX
return self.src_def and self.dst_def
def is_intermediate(self):
"""Is this an intermediate value, defined only in the src pattern?"""
# type: () -> bool
return self.defctx == self.SRCCTX
return self.src_def and not self.dst_def
def is_temp(self):
"""Is this a temp value, defined only in the dest pattern?"""
"""Is this a temp value, defined only in the dst pattern?"""
# type: () -> bool
return self.defctx == self.DSTCTX
return not self.src_def and self.dst_def
class Apply(Expr):

View File

@@ -63,7 +63,7 @@ class XForm(object):
... Rtl(c << iconst(v),
... a << iadd(x, c)),
... Rtl(a << iadd_imm(x, v)))
XForm(inputs=[Var(v), Var(x)], defs=[Var(c, d=01), Var(a, d=11)],
XForm(inputs=[Var(v), Var(x)], defs=[Var(c, src), Var(a, src, dst)],
c << iconst(v)
a << iadd(x, c)
=>
@@ -112,7 +112,7 @@ class XForm(object):
for line in rtl.rtl:
if isinstance(line, Def):
line.defs = tuple(
self._rewrite_defs(line.defs, symtab, context))
self._rewrite_defs(line, symtab, context))
expr = line.expr
else:
expr = line
@@ -132,23 +132,23 @@ class XForm(object):
expr.args = tuple(
self._rewrite_uses(expr, stack, symtab, context))
def _rewrite_defs(self, defs, symtab, context):
# type: (Sequence[Var], Dict[str, Var], int) -> Iterable[Var]
def _rewrite_defs(self, line, symtab, context):
# type: (Def, Dict[str, Var], int) -> Iterable[Var]
"""
Given a tuple of symbols defined in a Def, rewrite them to local
symbols. Yield the new locals.
"""
for sym in defs:
for sym in line.defs:
name = str(sym)
if name in symtab:
var = symtab[name]
if var.defctx & context:
if var.get_def(context):
raise AssertionError("'{}' multiply defined".format(name))
else:
var = Var(name)
symtab[name] = var
self.defs.append(var)
var.defctx |= context
var.set_def(context, line)
yield var
def _rewrite_uses(self, expr, stack, symtab, context):
@@ -173,8 +173,8 @@ class XForm(object):
name = str(arg)
if name in symtab:
var = symtab[name]
# The variable must be used consistenty as a def or input.
if var.defctx and (var.defctx & context) == 0:
# The variable must be used consistently as a def or input.
if not var.is_input() and not var.get_def(context):
raise AssertionError(
"'{}' used as both input and def"
.format(name))