diff --git a/lib/cretonne/meta/cdsl/ast.py b/lib/cretonne/meta/cdsl/ast.py index 86e2805497..e954ab01e9 100644 --- a/lib/cretonne/meta/cdsl/ast.py +++ b/lib/cretonne/meta/cdsl/ast.py @@ -402,16 +402,30 @@ class Apply(Expr): if self.inst != other.inst: return None - # TODO: Should we check imm/cond codes here as well? - for i in self.inst.value_opnums: - self_a = self.args[i] - other_a = other.args[i] + # Guaranteed by self.inst == other.inst + assert (len(self.args) == len(other.args)) - assert isinstance(self_a, Var) and isinstance(other_a, Var) - if (self_a not in s): - s[self_a] = other_a + for (self_a, other_a) in zip(self.args, other.args): + if (isinstance(self_a, Var)): + if not isinstance(other_a, Var): + return None + + if (self_a not in s): + s[self_a] = other_a + else: + if (s[self_a] != other_a): + return None else: - if (s[self_a] != other_a): + assert isinstance(self_a, Enumerator) + + if not isinstance(other_a, Enumerator): + # Currently don't support substitutions Var->Enumerator + return None + + # Guaranteed by self.inst == other.inst + assert self_a.kind == other_a.kind + + if (self_a.value != other_a.value): return None return s diff --git a/lib/cretonne/meta/cdsl/test_xform.py b/lib/cretonne/meta/cdsl/test_xform.py index 1609bb6c5f..952d8c90cb 100644 --- a/lib/cretonne/meta/cdsl/test_xform.py +++ b/lib/cretonne/meta/cdsl/test_xform.py @@ -1,7 +1,8 @@ from __future__ import absolute_import from unittest import TestCase from doctest import DocTestSuite -from base.instructions import iadd, iadd_imm, iconst +from base.instructions import iadd, iadd_imm, iconst, icmp +from base.immediates import intcc from . import xform from .ast import Var from .xform import Rtl, XForm @@ -14,9 +15,15 @@ def load_tests(loader, tests, ignore): x = Var('x') y = Var('y') +z = Var('z') +u = Var('u') a = Var('a') +b = Var('b') c = Var('c') +CC1 = Var('CC1') +CC2 = Var('CC2') + class TestXForm(TestCase): def test_macro_pattern(self): @@ -57,3 +64,31 @@ class TestXForm(TestCase): dst = Rtl(a << iadd(x, y)) with self.assertRaisesRegexp(AssertionError, "'a' multiply defined"): XForm(src, dst) + + def test_subst_imm(self): + src = Rtl(a << iconst(x)) + dst = Rtl(c << iconst(y)) + assert src.substitution(dst, {}) == {a: c, x: y} + + def test_subst_enum_var(self): + src = Rtl(a << icmp(CC1, x, y)) + dst = Rtl(b << icmp(CC2, z, u)) + assert src.substitution(dst, {}) == {a: b, CC1: CC2, x: z, y: u} + + def test_subst_enum_const(self): + src = Rtl(a << icmp(intcc.eq, x, y)) + dst = Rtl(b << icmp(intcc.eq, z, u)) + assert src.substitution(dst, {}) == {a: b, x: z, y: u} + + def test_subst_enum_bad(self): + src = Rtl(a << icmp(CC1, x, y)) + dst = Rtl(b << icmp(intcc.eq, z, u)) + assert src.substitution(dst, {}) is None + + src = Rtl(a << icmp(intcc.eq, x, y)) + dst = Rtl(b << icmp(CC1, z, u)) + assert src.substitution(dst, {}) is None + + src = Rtl(a << icmp(intcc.eq, x, y)) + dst = Rtl(b << icmp(intcc.sge, z, u)) + assert src.substitution(dst, {}) is None