Save a reference from a Var to its src and dst defs.

When a Var is used in an XForm, it can be defined in the src or dst or
both patterns, and it is classified accordingly. When a Var is defined,
it is also useful to be able to find the `Def` that defined it.

Add src_def and dst_def reference members to Var, and initialize them in
the private Var copies that XForm creates for itself.

These two members also replace the defctx bitmask.
This commit is contained in:
Jakob Stoklund Olesen
2016-11-03 11:24:04 -07:00
parent f2c7a1d57b
commit 026c899042
2 changed files with 54 additions and 27 deletions

View File

@@ -74,7 +74,7 @@ class Var(Expr):
""" """
A free variable. A free variable.
When variables are used in `XForms` with source ans destination patterns, When variables are used in `XForms` with source and destination patterns,
they are classified as follows: they are classified as follows:
Input values Input values
@@ -96,14 +96,10 @@ class Var(Expr):
def __init__(self, name): def __init__(self, name):
# type: (str) -> None # type: (str) -> None
self.name = name self.name = name
# Bitmask of contexts where this variable is defined. # The `Def` defining this variable in a source pattern.
# See XForm._rewrite_defs(). self.src_def = None # type: Def
self.defctx = 0 # The `Def` defining this variable in a destination pattern.
self.dst_def = None # type: Def
# Context bits for `defctx` indicating which pattern has defines of this
# var.
SRCCTX = 1
DSTCTX = 2
def __str__(self): def __str__(self):
# type: () -> str # type: () -> str
@@ -112,29 +108,60 @@ class Var(Expr):
def __repr__(self): def __repr__(self):
# type: () -> str # type: () -> str
s = self.name s = self.name
if self.defctx: if self.src_def:
s += ", d={:02b}".format(self.defctx) s += ", src"
if self.dst_def:
s += ", dst"
return "Var({})".format(s) return "Var({})".format(s)
# Context bits for `set_def` indicating which pattern has defines of this
# var.
SRCCTX = 1
DSTCTX = 2
def set_def(self, context, d):
# type: (int, Def) -> None
"""
Set the `Def` that defines this variable in the given context.
The `context` must be one of `SRCCTX` or `DSTCTX`
"""
if context == self.SRCCTX:
self.src_def = d
else:
self.dst_def = d
def get_def(self, context):
# type: (int) -> Def
"""
Get the def of this variable in context.
The `context` must be one of `SRCCTX` or `DSTCTX`
"""
if context == self.SRCCTX:
return self.src_def
else:
return self.dst_def
def is_input(self): def is_input(self):
# type: () -> bool # type: () -> bool
"""Is this an input value to the source pattern?""" """Is this an input value to the src pattern?"""
return self.defctx == 0 return not self.src_def and not self.dst_def
def is_output(self): def is_output(self):
"""Is this an output value, defined in both src and dest patterns?""" """Is this an output value, defined in both src and dst patterns?"""
# type: () -> bool # type: () -> bool
return self.defctx == self.SRCCTX | self.DSTCTX return self.src_def and self.dst_def
def is_intermediate(self): def is_intermediate(self):
"""Is this an intermediate value, defined only in the src pattern?""" """Is this an intermediate value, defined only in the src pattern?"""
# type: () -> bool # type: () -> bool
return self.defctx == self.SRCCTX return self.src_def and not self.dst_def
def is_temp(self): def is_temp(self):
"""Is this a temp value, defined only in the dest pattern?""" """Is this a temp value, defined only in the dst pattern?"""
# type: () -> bool # type: () -> bool
return self.defctx == self.DSTCTX return not self.src_def and self.dst_def
class Apply(Expr): class Apply(Expr):

View File

@@ -63,7 +63,7 @@ class XForm(object):
... Rtl(c << iconst(v), ... Rtl(c << iconst(v),
... a << iadd(x, c)), ... a << iadd(x, c)),
... Rtl(a << iadd_imm(x, v))) ... Rtl(a << iadd_imm(x, v)))
XForm(inputs=[Var(v), Var(x)], defs=[Var(c, d=01), Var(a, d=11)], XForm(inputs=[Var(v), Var(x)], defs=[Var(c, src), Var(a, src, dst)],
c << iconst(v) c << iconst(v)
a << iadd(x, c) a << iadd(x, c)
=> =>
@@ -112,7 +112,7 @@ class XForm(object):
for line in rtl.rtl: for line in rtl.rtl:
if isinstance(line, Def): if isinstance(line, Def):
line.defs = tuple( line.defs = tuple(
self._rewrite_defs(line.defs, symtab, context)) self._rewrite_defs(line, symtab, context))
expr = line.expr expr = line.expr
else: else:
expr = line expr = line
@@ -132,23 +132,23 @@ class XForm(object):
expr.args = tuple( expr.args = tuple(
self._rewrite_uses(expr, stack, symtab, context)) self._rewrite_uses(expr, stack, symtab, context))
def _rewrite_defs(self, defs, symtab, context): def _rewrite_defs(self, line, symtab, context):
# type: (Sequence[Var], Dict[str, Var], int) -> Iterable[Var] # type: (Def, Dict[str, Var], int) -> Iterable[Var]
""" """
Given a tuple of symbols defined in a Def, rewrite them to local Given a tuple of symbols defined in a Def, rewrite them to local
symbols. Yield the new locals. symbols. Yield the new locals.
""" """
for sym in defs: for sym in line.defs:
name = str(sym) name = str(sym)
if name in symtab: if name in symtab:
var = symtab[name] var = symtab[name]
if var.defctx & context: if var.get_def(context):
raise AssertionError("'{}' multiply defined".format(name)) raise AssertionError("'{}' multiply defined".format(name))
else: else:
var = Var(name) var = Var(name)
symtab[name] = var symtab[name] = var
self.defs.append(var) self.defs.append(var)
var.defctx |= context var.set_def(context, line)
yield var yield var
def _rewrite_uses(self, expr, stack, symtab, context): def _rewrite_uses(self, expr, stack, symtab, context):
@@ -173,8 +173,8 @@ class XForm(object):
name = str(arg) name = str(arg)
if name in symtab: if name in symtab:
var = symtab[name] var = symtab[name]
# The variable must be used consistenty as a def or input. # The variable must be used consistently as a def or input.
if var.defctx and (var.defctx & context) == 0: if not var.is_input() and not var.get_def(context):
raise AssertionError( raise AssertionError(
"'{}' used as both input and def" "'{}' used as both input and def"
.format(name)) .format(name))