From 66da171050abeea351a394035fa3451b84b0c772 Mon Sep 17 00:00:00 2001 From: d1m0 Date: Mon, 14 Aug 2017 20:19:47 -0700 Subject: [PATCH] Fix for #141 (#142) * Add Atom and Literal base classes to CDSL Ast. Change substitution() and copy() on Def/Apply/Rtl to support substituting Var->Union[Var, Literal]. Check in Apply() constructor kinds of passed in Literals respect instruction signature * Change verify_semantics to check all possible instantiations of enumerated immediates (needed to descrive icmp). Add all bitvector comparison primitives and bvite; Change set_semantics to optionally accept XForms; Add semantics for icmp; Fix typing errors in semantics/{smtlib, elaborate, __init__}.py after the change of VarMap->VarAtomMap * Forgot macros.py * Nit obscured by testing with mypy enabled present. * Typo --- lib/cretonne/meta/base/semantics.py | 67 +++++++--- lib/cretonne/meta/cdsl/ast.py | 120 +++++++++++------- lib/cretonne/meta/cdsl/instructions.py | 11 +- lib/cretonne/meta/cdsl/operands.py | 15 ++- lib/cretonne/meta/cdsl/test_xform.py | 43 ++++++- lib/cretonne/meta/cdsl/xform.py | 30 +++-- lib/cretonne/meta/gen_legalizer.py | 8 +- lib/cretonne/meta/semantics/__init__.py | 80 ++++++++---- lib/cretonne/meta/semantics/elaborate.py | 12 +- lib/cretonne/meta/semantics/macros.py | 45 +++++++ lib/cretonne/meta/semantics/primitives.py | 58 ++++++++- lib/cretonne/meta/semantics/smtlib.py | 50 +++++--- lib/cretonne/meta/semantics/test_elaborate.py | 13 +- 13 files changed, 415 insertions(+), 137 deletions(-) create mode 100644 lib/cretonne/meta/semantics/macros.py diff --git a/lib/cretonne/meta/base/semantics.py b/lib/cretonne/meta/base/semantics.py index edf4c5f82e..ec1852133e 100644 --- a/lib/cretonne/meta/base/semantics.py +++ b/lib/cretonne/meta/base/semantics.py @@ -1,20 +1,33 @@ from __future__ import absolute_import from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\ - bvadd, bvult, bvzeroext, bvsignext + bvadd, bvzeroext, bvsignext +from semantics.primitives import bveq, bvne, bvsge, bvsgt, bvsle, bvslt,\ + bvuge, bvugt, bvule, bvult +from semantics.macros import bool2bv from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \ isplit, iconcat, iadd_cin, iadd_carry from .immediates import intcc -from cdsl.xform import Rtl +from cdsl.xform import Rtl, XForm from cdsl.ast import Var from cdsl.typevar import TypeSet from cdsl.ti import InTypeset +try: + from typing import TYPE_CHECKING # noqa + if TYPE_CHECKING: + from cdsl.ast import Enumerator # noqa + from cdsl.instructions import Instruction # noqa +except ImportError: + TYPE_CHECKING = False + x = Var('x') y = Var('y') a = Var('a') b = Var('b') c_out = Var('c_out') c_in = Var('c_in') +CC = Var('CC') +bc_out = Var('bc_out') bvc_out = Var('bvc_out') bvc_in = Var('bvc_in') xhi = Var('xhi') @@ -93,7 +106,8 @@ iadd_cout.set_semantics( bvx << prim_to_bv(x), bvy << prim_to_bv(y), bva << bvadd(bvx, bvy), - bvc_out << bvult(bva, bvx), + bc_out << bvult(bva, bvx), + bvc_out << bool2bv(bc_out), a << prim_from_bv(bva), c_out << prim_from_bv(bvc_out) )) @@ -107,7 +121,8 @@ iadd_carry.set_semantics( bvs << bvzeroext(bvc_in), bvt << bvadd(bvx, bvy), bva << bvadd(bvt, bvs), - bvc_out << bvult(bva, bvx), + bc_out << bvult(bva, bvx), + bvc_out << bool2bv(bc_out), a << prim_from_bv(bva), c_out << prim_from_bv(bvc_out) )) @@ -126,23 +141,45 @@ bextend.set_semantics( a << vconcat(alo, ahi) )) + +def create_comp_xform(cc, bvcmp_func): + # type: (Enumerator, Instruction) -> XForm + ba = Var('ba') + return XForm( + Rtl( + a << icmp(cc, x, y) + ), + Rtl( + bvx << prim_to_bv(x), + bvy << prim_to_bv(y), + ba << bvcmp_func(bvx, bvy), + bva << bool2bv(ba), + bva_wide << bvzeroext(bva), + a << prim_from_bv(bva_wide), + ), + constraints=InTypeset(x.get_typevar(), ScalarTS)) + + icmp.set_semantics( - a << icmp(intcc.ult, x, y), - (Rtl( - bvx << prim_to_bv(x), - bvy << prim_to_bv(y), - bva << bvult(bvx, bvy), - bva_wide << bvzeroext(bva), - a << prim_from_bv(bva_wide), - ), [InTypeset(x.get_typevar(), ScalarTS)]), + a << icmp(CC, x, y), Rtl( (xlo, xhi) << vsplit(x), (ylo, yhi) << vsplit(y), - alo << icmp(intcc.ult, xlo, ylo), - ahi << icmp(intcc.ult, xhi, yhi), + alo << icmp(CC, xlo, ylo), + ahi << icmp(CC, xhi, yhi), b << vconcat(alo, ahi), a << bextend(b) - )) + ), + create_comp_xform(intcc.eq, bveq), + create_comp_xform(intcc.ne, bvne), + create_comp_xform(intcc.sge, bvsge), + create_comp_xform(intcc.sgt, bvsgt), + create_comp_xform(intcc.sle, bvsle), + create_comp_xform(intcc.slt, bvslt), + create_comp_xform(intcc.uge, bvuge), + create_comp_xform(intcc.ugt, bvugt), + create_comp_xform(intcc.ule, bvule), + create_comp_xform(intcc.ult, bvult)) # # Legalization helper instructions. diff --git a/lib/cretonne/meta/cdsl/ast.py b/lib/cretonne/meta/cdsl/ast.py index aa53e7b9b5..251fee3641 100644 --- a/lib/cretonne/meta/cdsl/ast.py +++ b/lib/cretonne/meta/cdsl/ast.py @@ -11,23 +11,23 @@ from .predicates import IsEqual, And, TypePredicate try: from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa - from typing import Optional, Set # noqa + from typing import Optional, Set, Any # noqa if TYPE_CHECKING: from .operands import ImmediateKind # noqa from .predicates import PredNode # noqa - VarMap = Dict["Var", "Var"] + VarAtomMap = Dict["Var", "Atom"] except ImportError: pass def replace_var(arg, m): - # type: (Expr, VarMap) -> Expr + # type: (Expr, VarAtomMap) -> Expr """ Given a var v return either m[v] or a new variable v' (and remember m[v]=v'). Otherwise return the argument unchanged """ if isinstance(arg, Var): - new_arg = m.get(arg, Var(arg.name)) # type: Var + new_arg = m.get(arg, Var(arg.name)) # type: Atom m[arg] = new_arg return new_arg return arg @@ -76,7 +76,7 @@ class Def(object): ', '.join(map(str, self.defs)), self.expr) def copy(self, m): - # type: (VarMap) -> Def + # type: (VarAtomMap) -> Def """ Return a copy of this Def with vars replaced with fresh variables, in accordance with the map m. Update m as neccessary. @@ -106,7 +106,7 @@ class Def(object): return self.definitions().union(self.uses()) def substitution(self, other, s): - # type: (Def, VarMap) -> Optional[VarMap] + # type: (Def, VarAtomMap) -> Optional[VarAtomMap] """ If the Defs self and other agree structurally, return a variable substitution to transform self to other. Otherwise return None. Two @@ -133,7 +133,13 @@ class Expr(object): """ -class Var(Expr): +class Atom(Expr): + """ + An Atom in the DSL is either a literal or a Var + """ + + +class Var(Atom): """ A free variable. @@ -304,6 +310,16 @@ class Apply(Expr): self.args = args assert len(self.inst.ins) == len(args) + # Check that the kinds of Literals arguments match the expected Operand + for op_idx in self.inst.imm_opnums: + arg = self.args[op_idx] + op = self.inst.ins[op_idx] + + if isinstance(arg, Literal): + assert arg.kind == op.kind, \ + "Passing literal {} to field of wrong kind {}."\ + .format(arg, op.kind) + def __rlshift__(self, other): # type: (Union[Var, Tuple[Var, ...]]) -> Def """ @@ -377,7 +393,7 @@ class Apply(Expr): return pred def copy(self, m): - # type: (VarMap) -> Apply + # type: (VarAtomMap) -> Apply """ Return a copy of this Expr with vars replaced with fresh variables, in accordance with the map m. Update m as neccessary. @@ -396,15 +412,12 @@ class Apply(Expr): return res def substitution(self, other, s): - # type: (Apply, VarMap) -> Optional[VarMap] + # type: (Apply, VarAtomMap) -> Optional[VarAtomMap] """ - If the application self and other agree structurally, return a variable - substitution to transform self to other. Otherwise return None. Two - applications agree structurally if: - 1) They are over the same instruction - 2) Every Var v in self, maps to a single Var w in other. I.e for - each use of v in self, w is used in the corresponding place in - other. + If there is a substituion from Var->Atom that converts self to other, + return it, otherwise return None. Note that this is strictly weaker + than unification (see TestXForm.test_subst_enum_bad_var_const for + example). """ if self.inst != other.inst: return None @@ -413,37 +426,62 @@ class Apply(Expr): assert (len(self.args) == len(other.args)) for (self_a, other_a) in zip(self.args, other.args): - if (isinstance(self_a, Var)): - if not isinstance(other_a, Var): - return None + assert isinstance(self_a, Atom) and isinstance(other_a, Atom) + if (isinstance(self_a, Var)): if (self_a not in s): s[self_a] = other_a else: if (s[self_a] != other_a): return None - elif isinstance(self_a, ConstantInt): - if not isinstance(other_a, ConstantInt): - return None - assert self_a.kind == other_a.kind - if (self_a.value != other_a.value): - return None + elif isinstance(other_a, Var): + assert isinstance(self_a, Literal) + if (other_a not in s): + s[other_a] = self_a + else: + if s[other_a] != self_a: + return None else: - assert isinstance(self_a, Enumerator) - - if not isinstance(other_a, Enumerator): - # Currently don't support substitutions Var->Enumerator - return None - + assert (isinstance(self_a, Literal) and + isinstance(other_a, Literal)) # Guaranteed by self.inst == other.inst assert self_a.kind == other_a.kind - if (self_a.value != other_a.value): return None + return s -class ConstantInt(Expr): +class Literal(Atom): + """ + Base Class for all literal expressions in the DSL. + """ + def __init__(self, kind, value): + # type: (ImmediateKind, Any) -> None + self.kind = kind + self.value = value + + def __eq__(self, other): + # type: (Any) -> bool + if not isinstance(other, Literal): + return False + + if self.kind != other.kind: + return False + + # Can't just compare value here, as comparison Any <> Any returns Any + return repr(self) == repr(other) + + def __ne__(self, other): + # type: (Any) -> bool + return not self.__eq__(other) + + def __repr__(self): + # type: () -> str + return '{}.{}'.format(self.kind, self.value) + + +class ConstantInt(Literal): """ A value of an integer immediate operand. @@ -454,8 +492,7 @@ class ConstantInt(Expr): def __init__(self, kind, value): # type: (ImmediateKind, int) -> None - self.kind = kind - self.value = value + super(ConstantInt, self).__init__(kind, value) def __str__(self): # type: () -> str @@ -464,12 +501,8 @@ class ConstantInt(Expr): """ return str(self.value) - def __repr__(self): - # type: () -> str - return '{}({})'.format(self.kind, self.value) - -class Enumerator(Expr): +class Enumerator(Literal): """ A value of an enumerated immediate operand. @@ -486,8 +519,7 @@ class Enumerator(Expr): def __init__(self, kind, value): # type: (ImmediateKind, str) -> None - self.kind = kind - self.value = value + super(Enumerator, self).__init__(kind, value) def __str__(self): # type: () -> str @@ -495,7 +527,3 @@ class Enumerator(Expr): Get the Rust expression form of this enumerator. """ return self.kind.rust_enumerator(self.value) - - def __repr__(self): - # type: () -> str - return '{}.{}'.format(self.kind, self.value) diff --git a/lib/cretonne/meta/cdsl/instructions.py b/lib/cretonne/meta/cdsl/instructions.py index 30c27c6649..d295f59b41 100644 --- a/lib/cretonne/meta/cdsl/instructions.py +++ b/lib/cretonne/meta/cdsl/instructions.py @@ -9,7 +9,7 @@ try: from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa from typing import Dict # noqa if TYPE_CHECKING: - from .ast import Expr, Apply, Var, Def # noqa + from .ast import Expr, Apply, Var, Def, VarAtomMap # noqa from .typevar import TypeVar # noqa from .ti import TypeConstraint # noqa from .xform import XForm, Rtl @@ -18,7 +18,7 @@ try: ConstrList = Union[Sequence[TypeConstraint], TypeConstraint] MaybeBoundInst = Union['Instruction', 'BoundInstruction'] InstructionSemantics = Sequence[XForm] - RtlCase = Union[Rtl, Tuple[Rtl, Sequence[TypeConstraint]]] + SemDefCase = Union[Rtl, Tuple[Rtl, Sequence[TypeConstraint]], XForm] except ImportError: pass @@ -349,7 +349,7 @@ class Instruction(object): return Apply(self, args) def set_semantics(self, src, *dsts): - # type: (Union[Def, Apply], *RtlCase) -> None + # type: (Union[Def, Apply], *SemDefCase) -> None """Set our semantics.""" from semantics import verify_semantics from .xform import XForm, Rtl @@ -358,6 +358,11 @@ class Instruction(object): for dst in dsts: if isinstance(dst, Rtl): sem.append(XForm(Rtl(src).copy({}), dst)) + elif isinstance(dst, XForm): + sem.append(XForm( + dst.src.copy({}), + dst.dst.copy({}), + dst.constraints)) else: assert isinstance(dst, tuple) sem.append(XForm(Rtl(src).copy({}), dst[0], diff --git a/lib/cretonne/meta/cdsl/operands.py b/lib/cretonne/meta/cdsl/operands.py index abf409a8c4..2ceb94b0fa 100644 --- a/lib/cretonne/meta/cdsl/operands.py +++ b/lib/cretonne/meta/cdsl/operands.py @@ -5,10 +5,10 @@ from .types import ValueType from .typevar import TypeVar try: - from typing import Union, Dict, TYPE_CHECKING # noqa + from typing import Union, Dict, TYPE_CHECKING, Iterable # noqa OperandSpec = Union['OperandKind', ValueType, TypeVar] if TYPE_CHECKING: - from .ast import Enumerator, ConstantInt # noqa + from .ast import Enumerator, ConstantInt, Literal # noqa except ImportError: pass @@ -128,6 +128,17 @@ class ImmediateKind(OperandKind): """ return '{}::{}'.format(self.rust_type, self.values[value]) + def is_enumerable(self): + # type: () -> bool + return self.values is not None + + def possible_values(self): + # type: () -> Iterable[Literal] + from cdsl.ast import Enumerator # noqa + assert self.is_enumerable() + for v in self.values.keys(): + yield Enumerator(self, v) + # Instances of entity reference operand types are provided in the # `cretonne.entities` module. diff --git a/lib/cretonne/meta/cdsl/test_xform.py b/lib/cretonne/meta/cdsl/test_xform.py index 952d8c90cb..424a7c824d 100644 --- a/lib/cretonne/meta/cdsl/test_xform.py +++ b/lib/cretonne/meta/cdsl/test_xform.py @@ -80,15 +80,52 @@ class TestXForm(TestCase): dst = Rtl(b << icmp(intcc.eq, z, u)) assert src.substitution(dst, {}) == {a: b, x: z, y: u} - def test_subst_enum_bad(self): + def test_subst_enum_var_const(self): src = Rtl(a << icmp(CC1, x, y)) dst = Rtl(b << icmp(intcc.eq, z, u)) - assert src.substitution(dst, {}) is None + assert src.substitution(dst, {}) == {CC1: intcc.eq, x: z, y: u, a: b},\ + "{} != {}".format(src.substitution(dst, {}), + {CC1: intcc.eq, x: z, y: u, a: b}) src = Rtl(a << icmp(intcc.eq, x, y)) dst = Rtl(b << icmp(CC1, z, u)) - assert src.substitution(dst, {}) is None + assert src.substitution(dst, {}) == {CC1: intcc.eq, x: z, y: u, a: b} + def test_subst_enum_bad(self): src = Rtl(a << icmp(intcc.eq, x, y)) dst = Rtl(b << icmp(intcc.sge, z, u)) assert src.substitution(dst, {}) is None + + def test_subst_enum_bad_var_const(self): + a1 = Var('a1') + x1 = Var('x1') + y1 = Var('y1') + + b1 = Var('b1') + z1 = Var('z1') + u1 = Var('u1') + + # Var mapping to 2 different constants + src = Rtl(a << icmp(CC1, x, y), + a1 << icmp(CC1, x1, y1)) + dst = Rtl(b << icmp(intcc.eq, z, u), + b1 << icmp(intcc.sge, z1, u1)) + + assert src.substitution(dst, {}) is None + + # 2 different constants mapping to the same var + src = Rtl(a << icmp(intcc.eq, x, y), + a1 << icmp(intcc.sge, x1, y1)) + dst = Rtl(b << icmp(CC1, z, u), + b1 << icmp(CC1, z1, u1)) + + assert src.substitution(dst, {}) is None + + # Var mapping to var and constant - note that full unification would + # have allowed this. + src = Rtl(a << icmp(CC1, x, y), + a1 << icmp(CC1, x1, y1)) + dst = Rtl(b << icmp(CC2, z, u), + b1 << icmp(intcc.sge, z1, u1)) + + assert src.substitution(dst, {}) is None diff --git a/lib/cretonne/meta/cdsl/xform.py b/lib/cretonne/meta/cdsl/xform.py index 991e429b18..261a70a4af 100644 --- a/lib/cretonne/meta/cdsl/xform.py +++ b/lib/cretonne/meta/cdsl/xform.py @@ -3,16 +3,16 @@ Instruction transformations. """ from __future__ import absolute_import from .ast import Def, Var, Apply -from .ti import ti_xform, TypeEnv, get_type_env +from .ti import ti_xform, TypeEnv, get_type_env, TypeConstraint from functools import reduce try: from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa from typing import Optional, Set # noqa - from .ast import Expr, VarMap # noqa + from .ast import Expr, VarAtomMap # noqa from .isa import TargetISA # noqa - from .ti import TypeConstraint # noqa from .typevar import TypeVar # noqa + from .instructions import ConstrList # noqa DefApply = Union[Def, Apply] except ImportError: pass @@ -47,7 +47,7 @@ class Rtl(object): self.rtl = tuple(map(canonicalize_defapply, args)) def copy(self, m): - # type: (VarMap) -> Rtl + # type: (VarAtomMap) -> Rtl """ Return a copy of this rtl with all Vars substituted with copies or according to m. Update m as neccessary. @@ -85,7 +85,7 @@ class Rtl(object): return reduce(flow_f, reversed(self.rtl), set([])) def substitution(self, other, s): - # type: (Rtl, VarMap) -> Optional[VarMap] + # type: (Rtl, VarAtomMap) -> Optional[VarAtomMap] """ If the Rtl self agrees structurally with the Rtl other, return a substitution to transform self to other. Two Rtls agree structurally if @@ -132,6 +132,10 @@ class Rtl(object): assert typing[v].singleton_type() is not None v.set_typevar(typing[v]) + def __str__(self): + # type: () -> str + return "\n".join(map(str, self.rtl)) + class XForm(object): """ @@ -162,7 +166,7 @@ class XForm(object): """ def __init__(self, src, dst, constraints=None): - # type: (Rtl, Rtl, Optional[Sequence[TypeConstraint]]) -> None + # type: (Rtl, Rtl, Optional[ConstrList]) -> None self.src = src self.dst = dst # Variables that are inputs to the source pattern. @@ -203,10 +207,18 @@ class XForm(object): return tv return symtab[tv.name[len("typeof_"):]].get_typevar() + self.constraints = [] # type: List[TypeConstraint] if constraints is not None: - for c in constraints: + if isinstance(constraints, TypeConstraint): + constr_list = [constraints] # type: Sequence[TypeConstraint] + else: + constr_list = constraints + + for c in constr_list: type_m = {tv: interp_tv(tv) for tv in c.tvs()} - self.ti.add_constraint(c.translate(type_m)) + inner_c = c.translate(type_m) + self.constraints.append(inner_c) + self.ti.add_constraint(inner_c) # Sanity: The set of inferred free typevars should be a subset of the # TVs corresponding to Vars appearing in src @@ -333,7 +345,7 @@ class XForm(object): defs are renamed with '.suffix' appended to their old name. """ assert r.is_concrete() - s = self.src.substitution(r, {}) # type: VarMap + s = self.src.substitution(r, {}) # type: VarAtomMap assert s is not None if (suffix is not None): diff --git a/lib/cretonne/meta/gen_legalizer.py b/lib/cretonne/meta/gen_legalizer.py index 8a76d15e98..7189695a71 100644 --- a/lib/cretonne/meta/gen_legalizer.py +++ b/lib/cretonne/meta/gen_legalizer.py @@ -21,7 +21,7 @@ from cdsl.typevar import TypeVar try: from typing import Sequence, List, Dict, Set, DefaultDict # noqa from cdsl.isa import TargetISA # noqa - from cdsl.ast import Def # noqa + from cdsl.ast import Def, VarAtomMap # noqa from cdsl.xform import XForm, XFormGroup # noqa from cdsl.typevar import TypeSet # noqa from cdsl.ti import TypeConstraint # noqa @@ -45,7 +45,7 @@ def get_runtime_typechecks(xform): # 1) Perform ti only on the source RTL. Accumulate any free tvs that have a # different inferred type in src, compared to the type inferred for both # src and dst. - symtab = {} # type: Dict[Var, Var] + symtab = {} # type: VarAtomMap src_copy = xform.src.copy(symtab) src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv())) @@ -62,7 +62,9 @@ def get_runtime_typechecks(xform): assert v.get_typevar().singleton_type() is not None continue - src_ts = src_typenv[symtab[v]].get_typeset() + inner_v = symtab[v] + assert isinstance(inner_v, Var) + src_ts = src_typenv[inner_v].get_typeset() xform_ts = xform.ti[v].get_typeset() assert xform_ts.issubset(src_ts) diff --git a/lib/cretonne/meta/semantics/__init__.py b/lib/cretonne/meta/semantics/__init__.py index 1c1fee9b9f..94e32b652c 100644 --- a/lib/cretonne/meta/semantics/__init__.py +++ b/lib/cretonne/meta/semantics/__init__.py @@ -1,9 +1,11 @@ """Definitions for the semantics segment of the Cretonne language.""" from cdsl.ti import TypeEnv, ti_rtl, get_type_env +from cdsl.operands import ImmediateKind +from cdsl.ast import Var try: from typing import List, Dict, Tuple # noqa - from cdsl.ast import Var # noqa + from cdsl.ast import VarAtomMap # noqa from cdsl.xform import XForm, Rtl # noqa from cdsl.ti import VarTyping # noqa from cdsl.instructions import Instruction, InstructionSemantics # noqa @@ -16,34 +18,60 @@ def verify_semantics(inst, src, xforms): """ Verify that the semantics transforms in xforms correctly describe the instruction described by the src Rtl. This involves checking that: - 1) For all XForms x \in xforms, there is a Var substitution form src to - x.src - 2) For any possible concrete typing of src there is exactly 1 XForm x - in xforms that applies. + 0) src is a single instance of inst + 1) For all x\in xforms x.src is a single instance of inst + 2) For any concrete values V of Literals in inst: + For all concrete typing T of inst: + Exists single x \in xforms that applies to src conretazied to V + and T """ - # 0) The source rtl is always a single instruction - assert len(src.rtl) == 1 + # 0) The source rtl is always a single instance of inst + assert len(src.rtl) == 1 and src.rtl[0].expr.inst == inst - # 1) For all XForms x, x.src is structurally equivalent to src + # 1) For all XForms x, x.src is a single instance of inst for x in xforms: - assert src.substitution(x.src, {}) is not None,\ - "XForm {} doesn't describe instruction {}.".format(x, src) + assert len(x.src.rtl) == 1 and x.src.rtl[0].expr.inst == inst - # 2) Any possible typing for the instruction should be covered by - # exactly ONE semantic XForm - src = src.copy({}) - typenv = get_type_env(ti_rtl(src, TypeEnv())) - typenv.normalize() - typenv = typenv.extract() + variants = [src] # type: List[Rtl] - for t in typenv.concrete_typings(): - matching_xforms = [] # type: List[XForm] - for x in xforms: - # Translate t using x.symtab - t = {x.symtab[str(v)]: tv for (v, tv) in t.items()} - if (x.ti.permits(t)): - matching_xforms.append(x) + # 2) For all enumerated immediates, compute all the possible + # versions of src with the concrete value filled in. + for i in inst.imm_opnums: + op = inst.ins[i] + if not (isinstance(op.kind, ImmediateKind) and + op.kind.is_enumerable()): + continue - assert len(matching_xforms) == 1,\ - ("Possible typing {} of {} not matched by exactly one case " + - ": {}").format(t, inst, matching_xforms) + new_variants = [] # type: List[Rtl] + for rtl_var in variants: + s = {v: v for v in rtl_var.vars()} # type: VarAtomMap + arg = rtl_var.rtl[0].expr.args[i] + assert isinstance(arg, Var) + for val in op.kind.possible_values(): + s[arg] = val + new_variants.append(rtl_var.copy(s)) + variants = new_variants + + # For any possible version of the src with concrete enumerated immediates + for src in variants: + # 2) Any possible typing should be covered by exactly ONE semantic + # XForm + src = src.copy({}) + typenv = get_type_env(ti_rtl(src, TypeEnv())) + typenv.normalize() + typenv = typenv.extract() + + for t in typenv.concrete_typings(): + matching_xforms = [] # type: List[XForm] + for x in xforms: + if src.substitution(x.src, {}) is None: + continue + + # Translate t using x.symtab + t = {x.symtab[str(v)]: tv for (v, tv) in t.items()} + if (x.ti.permits(t)): + matching_xforms.append(x) + + assert len(matching_xforms) == 1,\ + ("Possible typing {} of {} not matched by exactly one case " + + ": {}").format(t, src.rtl[0], matching_xforms) diff --git a/lib/cretonne/meta/semantics/elaborate.py b/lib/cretonne/meta/semantics/elaborate.py index fc0ca98cc4..8d2ecd7fd6 100644 --- a/lib/cretonne/meta/semantics/elaborate.py +++ b/lib/cretonne/meta/semantics/elaborate.py @@ -10,7 +10,7 @@ from cdsl.ast import Var try: from typing import TYPE_CHECKING, Dict, Union, List, Set, Tuple # noqa from cdsl.xform import XForm # noqa - from cdsl.ast import Def, VarMap # noqa + from cdsl.ast import Def, VarAtomMap # noqa from cdsl.ti import VarTyping # noqa except ImportError: TYPE_CHECKING = False @@ -34,7 +34,13 @@ def find_matching_xform(d): if (subst is None): continue - if x.ti.permits({subst[v]: tv for (v, tv) in typing.items()}): + inner_typing = {} # type: VarTyping + for (v, tv) in typing.items(): + inner_v = subst[v] + assert isinstance(inner_v, Var) + inner_typing[inner_v] = tv + + if x.ti.permits(inner_typing): res.append(x) assert len(res) == 1, "Couldn't find semantic transform for {}".format(d) @@ -60,7 +66,7 @@ def cleanup_semantics(r, outputs): ... """ new_defs = [] # type: List[Def] - subst_m = {v: v for v in r.vars()} # type: VarMap + subst_m = {v: v for v in r.vars()} # type: VarAtomMap definition = {} # type: Dict[Var, Def] prim_to_bv_map = {} # type: Dict[Var, Def] diff --git a/lib/cretonne/meta/semantics/macros.py b/lib/cretonne/meta/semantics/macros.py new file mode 100644 index 0000000000..566bf92eae --- /dev/null +++ b/lib/cretonne/meta/semantics/macros.py @@ -0,0 +1,45 @@ +""" +Useful semantics "macro" instructions built on top of +the primitives. +""" +from __future__ import absolute_import +from cdsl.operands import Operand +from cdsl.typevar import TypeVar +from cdsl.instructions import Instruction, InstructionGroup +from base.types import b1 +from base.immediates import imm64 +from cdsl.ast import Var +from cdsl.xform import Rtl +from semantics.primitives import bv_from_imm64, bvite +import base.formats # noqa + +GROUP = InstructionGroup("primitive_macros", "Semantic macros instruction set") +AnyBV = TypeVar('AnyBV', bitvecs=True, doc="") +x = Var('x') +y = Var('y') +imm = Var('imm') +a = Var('a') + +# +# Bool-to-bv1 +# +BV1 = TypeVar("BV1", bitvecs=(1, 1), doc="") +bv1_op = Operand('bv1_op', BV1, doc="") +cond_op = Operand("cond", b1, doc="") +bool2bv = Instruction( + 'bool2bv', r"""Convert a b1 value to a 1-bit BV""", + ins=cond_op, outs=bv1_op) + +v1 = Var('v1') +v2 = Var('v2') +bvone = Var('bvone') +bvzero = Var('bvzero') +bool2bv.set_semantics( + v1 << bool2bv(v2), + Rtl( + bvone << bv_from_imm64(imm64(1)), + bvzero << bv_from_imm64(imm64(0)), + v1 << bvite(v2, bvone, bvzero) + )) + +GROUP.close() diff --git a/lib/cretonne/meta/semantics/primitives.py b/lib/cretonne/meta/semantics/primitives.py index 0a727c1cf9..656db41538 100644 --- a/lib/cretonne/meta/semantics/primitives.py +++ b/lib/cretonne/meta/semantics/primitives.py @@ -10,6 +10,8 @@ from cdsl.operands import Operand from cdsl.typevar import TypeVar from cdsl.instructions import Instruction, InstructionGroup from cdsl.ti import WiderOrEq +from base.types import b1 +from base.immediates import imm64 import base.formats # noqa GROUP = InstructionGroup("primitive", "Primitive instruction set") @@ -22,26 +24,40 @@ Real = TypeVar('Real', 'Any real type.', ints=True, floats=True, x = Operand('x', BV, doc="A semantic value X") y = Operand('x', BV, doc="A semantic value Y (same width as X)") a = Operand('a', BV, doc="A semantic value A (same width as X)") +cond = Operand('b', TypeVar.singleton(b1), doc='A b1 value') real = Operand('real', Real, doc="A real cretonne value") fromReal = Operand('fromReal', Real.to_bitvec(), doc="A real cretonne value converted to a BV") +# +# BV Conversion/Materialization +# prim_to_bv = Instruction( 'prim_to_bv', r""" Convert an SSA Value to a flat bitvector """, ins=(real), outs=(fromReal)) -# Note that when converting from BV->real values, we use a constraint and not a -# derived function. This reflects that fact that to_bitvec() is not a -# bijection. prim_from_bv = Instruction( 'prim_from_bv', r""" Convert a flat bitvector to a real SSA Value. """, ins=(fromReal), outs=(real)) +N = Operand('N', imm64) +bv_from_imm64 = Instruction( + 'bv_from_imm64', r"""Materialize an imm64 as a bitvector.""", + ins=(N), outs=a) + +# +# Generics +# +bvite = Instruction( + 'bvite', r"""Bitvector ternary operator""", + ins=(cond, x, y), outs=a) + + xh = Operand('xh', BV.half_width(), doc="A semantic value representing the upper half of X") xl = Operand('xl', BV.half_width(), @@ -67,12 +83,40 @@ bvadd = Instruction( of the operands. """, ins=(x, y), outs=a) - +# # Bitvector comparisons -cmp_res = Operand('cmp_res', BV1, doc="Single bit boolean") +# + +bveq = Instruction( + 'bveq', r"""Unsigned bitvector equality""", + ins=(x, y), outs=cond) +bvne = Instruction( + 'bveq', r"""Unsigned bitvector inequality""", + ins=(x, y), outs=cond) +bvsge = Instruction( + 'bvsge', r"""Signed bitvector greater or equal""", + ins=(x, y), outs=cond) +bvsgt = Instruction( + 'bvsgt', r"""Signed bitvector greater than""", + ins=(x, y), outs=cond) +bvsle = Instruction( + 'bvsle', r"""Signed bitvector less than or equal""", + ins=(x, y), outs=cond) +bvslt = Instruction( + 'bvslt', r"""Signed bitvector less than""", + ins=(x, y), outs=cond) +bvuge = Instruction( + 'bvuge', r"""Unsigned bitvector greater or equal""", + ins=(x, y), outs=cond) +bvugt = Instruction( + 'bvugt', r"""Unsigned bitvector greater than""", + ins=(x, y), outs=cond) +bvule = Instruction( + 'bvule', r"""Unsigned bitvector less than or equal""", + ins=(x, y), outs=cond) bvult = Instruction( - 'bvult', r"""Unsigned bitvector comparison""", - ins=(x, y), outs=cmp_res) + 'bvult', r"""Unsigned bitvector less than""", + ins=(x, y), outs=cond) # Extensions ToBV = TypeVar('ToBV', 'A bitvector type.', bitvecs=True) diff --git a/lib/cretonne/meta/semantics/smtlib.py b/lib/cretonne/meta/semantics/smtlib.py index c1b2526832..3a2c819153 100644 --- a/lib/cretonne/meta/semantics/smtlib.py +++ b/lib/cretonne/meta/semantics/smtlib.py @@ -14,7 +14,7 @@ from z3.z3core import Z3_mk_eq try: from typing import TYPE_CHECKING, Tuple, Dict, List # noqa from cdsl.xform import Rtl, XForm # noqa - from cdsl.ast import VarMap # noqa + from cdsl.ast import VarAtomMap, Atom # noqa from cdsl.ti import VarTyping # noqa if TYPE_CHECKING: from z3 import ExprRef, BitVecRef # noqa @@ -137,13 +137,13 @@ def to_smt(r): def equivalent(r1, r2, inp_m, out_m): - # type: (Rtl, Rtl, VarMap, VarMap) -> List[ExprRef] + # type: (Rtl, Rtl, VarAtomMap, VarAtomMap) -> List[ExprRef] """ Given: - concrete source Rtl r1 - concrete dest Rtl r2 - - VarMap inp_m mapping r1's non-bitvector inputs to r2 - - VarMap out_m mapping r1's non-bitvector outputs to r2 + - VarAtomMap inp_m mapping r1's non-bitvector inputs to r2 + - VarAtomMap out_m mapping r1's non-bitvector outputs to r2 Build a query checking whether r1 and r2 are semantically equivalent. If the returned query is unsatisfiable, then r1 and r2 are equivalent. @@ -156,17 +156,31 @@ def equivalent(r1, r2, inp_m, out_m): assert set(r2.free_vars()) == set(inp_m.values()) # Note that the same rule is not expected to hold for out_m due to - # temporaries/intermediates. + # temporaries/intermediates. out_m specified which values are enough for + # equivalence. # Rename the vars in r1 and r2 with unique suffixes to avoid conflicts - src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} - dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} + src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} # type: VarAtomMap # noqa + dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} # type: VarAtomMap # noqa r1 = r1.copy(src_m) r2 = r2.copy(dst_m) + def _translate(m, k_m, v_m): + # type: (VarAtomMap, VarAtomMap, VarAtomMap) -> VarAtomMap + """Obtain a new map from m, by mapping m's keys with k_m and m's values + with v_m""" + res = {} # type: VarAtomMap + for (k, v) in m1.items(): + new_k = k_m[k] + new_v = v_m[v] + assert isinstance(new_k, Var) + res[new_k] = new_v + + return res + # Convert inp_m, out_m in terms of variables with the .a/.b suffixes - inp_m = {src_m[k]: dst_m[v] for (k, v) in inp_m.items()} - out_m = {src_m[k]: dst_m[v] for (k, v) in out_m.items()} + inp_m = _translate(inp_m, src_m, dst_m) + out_m = _translate(out_m, src_m, dst_m) # Encode r1 and r2 as SMT queries (q1, m1) = to_smt(r1) @@ -175,12 +189,14 @@ def equivalent(r1, r2, inp_m, out_m): # Build an expression for the equality of real Cretone inputs of r1 and r2 args_eq_exp = [] # type: List[ExprRef] - for v in r1.free_vars(): - args_eq_exp.append(mk_eq(m1[v], m2[inp_m[v]])) + for (v1, v2) in inp_m.items(): + assert isinstance(v2, Var) + args_eq_exp.append(mk_eq(m1[v1], m2[v2])) # Build an expression for the equality of real Cretone outputs of r1 and r2 results_eq_exp = [] # type: List[ExprRef] for (v1, v2) in out_m.items(): + assert isinstance(v2, Var) results_eq_exp.append(mk_eq(m1[v1], m2[v2])) # Put the whole query toghether @@ -196,20 +212,22 @@ def xform_correct(x, typing): assert x.ti.permits(typing) # Create copies of the x.src and x.dst with their concrete types - src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()} + src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()} # type: VarAtomMap # noqa src = x.src.copy(src_m) dst = x.apply(src) dst_m = x.dst.substitution(dst, {}) # Build maps for the inputs/outputs for src->dst - inp_m = {} - out_m = {} + inp_m = {} # type: VarAtomMap + out_m = {} # type: VarAtomMap for v in x.src.vars(): + src_v = src_m[v] + assert isinstance(src_v, Var) if v.is_input(): - inp_m[src_m[v]] = dst_m[v] + inp_m[src_v] = dst_m[v] elif v.is_output(): - out_m[src_m[v]] = dst_m[v] + out_m[src_v] = dst_m[v] # Get the primitive semantic Rtls for src and dst prim_src = elaborate(src) diff --git a/lib/cretonne/meta/semantics/test_elaborate.py b/lib/cretonne/meta/semantics/test_elaborate.py index cb798295b9..9ca938bc1f 100644 --- a/lib/cretonne/meta/semantics/test_elaborate.py +++ b/lib/cretonne/meta/semantics/test_elaborate.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint from base.instructions import b1, icmp, ireduce, iadd_cout -from base.immediates import intcc +from base.immediates import intcc, imm64 from base.types import i64, i8, b32, i32, i16, f32 from cdsl.typevar import TypeVar from cdsl.ast import Var @@ -9,7 +9,7 @@ from cdsl.xform import Rtl from unittest import TestCase from .elaborate import elaborate from .primitives import prim_to_bv, bvsplit, prim_from_bv, bvconcat, bvadd, \ - bvult + bvult, bv_from_imm64, bvite import base.semantics # noqa @@ -366,9 +366,12 @@ class TestElaborate(TestCase): a = Var('a') c_out = Var('c_out') bvc_out = Var('bvc_out') + bc_out = Var('bc_out') bvx = Var('bvx') bvy = Var('bvy') bva = Var('bva') + bvone = Var('bvone') + bvzero = Var('bvzero') r = Rtl( (a, c_out) << iadd_cout.i32(x, y), ) @@ -378,10 +381,12 @@ class TestElaborate(TestCase): bvx << prim_to_bv.i32(x), bvy << prim_to_bv.i32(y), bva << bvadd.bv32(bvx, bvy), - bvc_out << bvult.bv32(bva, bvx), + bc_out << bvult.bv32(bva, bvx), + bvone << bv_from_imm64(imm64(1)), + bvzero << bv_from_imm64(imm64(0)), + bvc_out << bvite(bc_out, bvone, bvzero), a << prim_from_bv.i32(bva), c_out << prim_from_bv.b1(bvc_out) ) exp.cleanup_concrete_rtl() - assert concrete_rtls_eq(sem, exp)