Save a reference from a Var to its src and dst defs.
When a Var is used in an XForm, it can be defined in the src or dst or both patterns, and it is classified accordingly. When a Var is defined, it is also useful to be able to find the `Def` that defined it. Add src_def and dst_def reference members to Var, and initialize them in the private Var copies that XForm creates for itself. These two members also replace the defctx bitmask.
This commit is contained in:
@@ -74,7 +74,7 @@ class Var(Expr):
|
||||
"""
|
||||
A free variable.
|
||||
|
||||
When variables are used in `XForms` with source ans destination patterns,
|
||||
When variables are used in `XForms` with source and destination patterns,
|
||||
they are classified as follows:
|
||||
|
||||
Input values
|
||||
@@ -96,14 +96,10 @@ class Var(Expr):
|
||||
def __init__(self, name):
|
||||
# type: (str) -> None
|
||||
self.name = name
|
||||
# Bitmask of contexts where this variable is defined.
|
||||
# See XForm._rewrite_defs().
|
||||
self.defctx = 0
|
||||
|
||||
# Context bits for `defctx` indicating which pattern has defines of this
|
||||
# var.
|
||||
SRCCTX = 1
|
||||
DSTCTX = 2
|
||||
# The `Def` defining this variable in a source pattern.
|
||||
self.src_def = None # type: Def
|
||||
# The `Def` defining this variable in a destination pattern.
|
||||
self.dst_def = None # type: Def
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
@@ -112,29 +108,60 @@ class Var(Expr):
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
s = self.name
|
||||
if self.defctx:
|
||||
s += ", d={:02b}".format(self.defctx)
|
||||
if self.src_def:
|
||||
s += ", src"
|
||||
if self.dst_def:
|
||||
s += ", dst"
|
||||
return "Var({})".format(s)
|
||||
|
||||
# Context bits for `set_def` indicating which pattern has defines of this
|
||||
# var.
|
||||
SRCCTX = 1
|
||||
DSTCTX = 2
|
||||
|
||||
def set_def(self, context, d):
|
||||
# type: (int, Def) -> None
|
||||
"""
|
||||
Set the `Def` that defines this variable in the given context.
|
||||
|
||||
The `context` must be one of `SRCCTX` or `DSTCTX`
|
||||
"""
|
||||
if context == self.SRCCTX:
|
||||
self.src_def = d
|
||||
else:
|
||||
self.dst_def = d
|
||||
|
||||
def get_def(self, context):
|
||||
# type: (int) -> Def
|
||||
"""
|
||||
Get the def of this variable in context.
|
||||
|
||||
The `context` must be one of `SRCCTX` or `DSTCTX`
|
||||
"""
|
||||
if context == self.SRCCTX:
|
||||
return self.src_def
|
||||
else:
|
||||
return self.dst_def
|
||||
|
||||
def is_input(self):
|
||||
# type: () -> bool
|
||||
"""Is this an input value to the source pattern?"""
|
||||
return self.defctx == 0
|
||||
"""Is this an input value to the src pattern?"""
|
||||
return not self.src_def and not self.dst_def
|
||||
|
||||
def is_output(self):
|
||||
"""Is this an output value, defined in both src and dest patterns?"""
|
||||
"""Is this an output value, defined in both src and dst patterns?"""
|
||||
# type: () -> bool
|
||||
return self.defctx == self.SRCCTX | self.DSTCTX
|
||||
return self.src_def and self.dst_def
|
||||
|
||||
def is_intermediate(self):
|
||||
"""Is this an intermediate value, defined only in the src pattern?"""
|
||||
# type: () -> bool
|
||||
return self.defctx == self.SRCCTX
|
||||
return self.src_def and not self.dst_def
|
||||
|
||||
def is_temp(self):
|
||||
"""Is this a temp value, defined only in the dest pattern?"""
|
||||
"""Is this a temp value, defined only in the dst pattern?"""
|
||||
# type: () -> bool
|
||||
return self.defctx == self.DSTCTX
|
||||
return not self.src_def and self.dst_def
|
||||
|
||||
|
||||
class Apply(Expr):
|
||||
|
||||
@@ -63,7 +63,7 @@ class XForm(object):
|
||||
... Rtl(c << iconst(v),
|
||||
... a << iadd(x, c)),
|
||||
... Rtl(a << iadd_imm(x, v)))
|
||||
XForm(inputs=[Var(v), Var(x)], defs=[Var(c, d=01), Var(a, d=11)],
|
||||
XForm(inputs=[Var(v), Var(x)], defs=[Var(c, src), Var(a, src, dst)],
|
||||
c << iconst(v)
|
||||
a << iadd(x, c)
|
||||
=>
|
||||
@@ -112,7 +112,7 @@ class XForm(object):
|
||||
for line in rtl.rtl:
|
||||
if isinstance(line, Def):
|
||||
line.defs = tuple(
|
||||
self._rewrite_defs(line.defs, symtab, context))
|
||||
self._rewrite_defs(line, symtab, context))
|
||||
expr = line.expr
|
||||
else:
|
||||
expr = line
|
||||
@@ -132,23 +132,23 @@ class XForm(object):
|
||||
expr.args = tuple(
|
||||
self._rewrite_uses(expr, stack, symtab, context))
|
||||
|
||||
def _rewrite_defs(self, defs, symtab, context):
|
||||
# type: (Sequence[Var], Dict[str, Var], int) -> Iterable[Var]
|
||||
def _rewrite_defs(self, line, symtab, context):
|
||||
# type: (Def, Dict[str, Var], int) -> Iterable[Var]
|
||||
"""
|
||||
Given a tuple of symbols defined in a Def, rewrite them to local
|
||||
symbols. Yield the new locals.
|
||||
"""
|
||||
for sym in defs:
|
||||
for sym in line.defs:
|
||||
name = str(sym)
|
||||
if name in symtab:
|
||||
var = symtab[name]
|
||||
if var.defctx & context:
|
||||
if var.get_def(context):
|
||||
raise AssertionError("'{}' multiply defined".format(name))
|
||||
else:
|
||||
var = Var(name)
|
||||
symtab[name] = var
|
||||
self.defs.append(var)
|
||||
var.defctx |= context
|
||||
var.set_def(context, line)
|
||||
yield var
|
||||
|
||||
def _rewrite_uses(self, expr, stack, symtab, context):
|
||||
@@ -173,8 +173,8 @@ class XForm(object):
|
||||
name = str(arg)
|
||||
if name in symtab:
|
||||
var = symtab[name]
|
||||
# The variable must be used consistenty as a def or input.
|
||||
if var.defctx and (var.defctx & context) == 0:
|
||||
# The variable must be used consistently as a def or input.
|
||||
if not var.is_input() and not var.get_def(context):
|
||||
raise AssertionError(
|
||||
"'{}' used as both input and def"
|
||||
.format(name))
|
||||
|
||||
Reference in New Issue
Block a user