From 9e3f4e9195130e66c040e78e55021367fadfde6d Mon Sep 17 00:00:00 2001 From: d1m0 Date: Mon, 31 Jul 2017 16:02:27 -0700 Subject: [PATCH] Cleanup for PR #123 (#129) * Fix bextend semantics; Change smtlib.py to use z3 python bindings for query building instead of raw strings * Forgot the mypy stubs for z3 --- lib/cretonne/meta/base/semantics.py | 4 +- lib/cretonne/meta/semantics/primitives.py | 4 + lib/cretonne/meta/semantics/smtlib.py | 114 ++++++++-------- lib/cretonne/meta/stubs/z3/__init__.pyi | 151 ++++++++++++++++++++++ lib/cretonne/meta/stubs/z3/z3core.pyi | 3 + lib/cretonne/meta/stubs/z3/z3types.pyi | 12 ++ 6 files changed, 234 insertions(+), 54 deletions(-) create mode 100644 lib/cretonne/meta/stubs/z3/__init__.pyi create mode 100644 lib/cretonne/meta/stubs/z3/z3core.pyi create mode 100644 lib/cretonne/meta/stubs/z3/z3types.pyi diff --git a/lib/cretonne/meta/base/semantics.py b/lib/cretonne/meta/base/semantics.py index 582fdc0889..edf4c5f82e 100644 --- a/lib/cretonne/meta/base/semantics.py +++ b/lib/cretonne/meta/base/semantics.py @@ -1,6 +1,6 @@ from __future__ import absolute_import from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\ - bvadd, bvult, bvzeroext + bvadd, bvult, bvzeroext, bvsignext from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \ isplit, iconcat, iadd_cin, iadd_carry from .immediates import intcc @@ -116,7 +116,7 @@ bextend.set_semantics( a << bextend(x), (Rtl( bvx << prim_to_bv(x), - bvy << bvzeroext(bvx), + bvy << bvsignext(bvx), a << prim_from_bv(bvy) ), [InTypeset(x.get_typevar(), ScalarTS)]), Rtl( diff --git a/lib/cretonne/meta/semantics/primitives.py b/lib/cretonne/meta/semantics/primitives.py index 62d936bc31..0a727c1cf9 100644 --- a/lib/cretonne/meta/semantics/primitives.py +++ b/lib/cretonne/meta/semantics/primitives.py @@ -82,4 +82,8 @@ bvzeroext = Instruction( 'bvzeroext', r"""Unsigned bitvector extension""", ins=x, outs=x1, constraints=WiderOrEq(ToBV, BV)) +bvsignext = Instruction( + 'bvsignext', r"""Signed bitvector extension""", + ins=x, outs=x1, constraints=WiderOrEq(ToBV, BV)) + GROUP.close() diff --git a/lib/cretonne/meta/semantics/smtlib.py b/lib/cretonne/meta/semantics/smtlib.py index f84176dc3c..c1b2526832 100644 --- a/lib/cretonne/meta/semantics/smtlib.py +++ b/lib/cretonne/meta/semantics/smtlib.py @@ -3,64 +3,74 @@ Tools to emit SMTLIB bitvector queries encoding concrete RTLs containing only primitive instructions. """ from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\ - bvult, bvzeroext, bvsplit, bvconcat + bvult, bvzeroext, bvsplit, bvconcat, bvsignext from cdsl.ast import Var from cdsl.types import BVType from .elaborate import elaborate +from z3 import BitVec, ZeroExt, SignExt, And, Extract, Concat, Not, Solver,\ + unsat, BoolRef, BitVecVal, If +from z3.z3core import Z3_mk_eq try: - from typing import TYPE_CHECKING, Tuple # noqa + from typing import TYPE_CHECKING, Tuple, Dict, List # noqa from cdsl.xform import Rtl, XForm # noqa from cdsl.ast import VarMap # noqa from cdsl.ti import VarTyping # noqa + if TYPE_CHECKING: + from z3 import ExprRef, BitVecRef # noqa + Z3VarMap = Dict[Var, BitVecRef] except ImportError: TYPE_CHECKING = False -def bvtype_to_sort(typ): - # type: (BVType) -> str - """Return the BitVec sort corresponding to a BVType""" - return "(_ BitVec {})".format(typ.bits) +# Use this for constructing a == b instead of == since MyPy doesn't +# accept overloading of __eq__ that doesn't return bool +def mk_eq(e1, e2): + # type: (ExprRef, ExprRef) -> ExprRef + """Return a z3 expression equivalent to e1 == e2""" + return BoolRef(Z3_mk_eq(e1.ctx_ref(), e1.as_ast(), e2.as_ast()), e1.ctx) def to_smt(r): - # type: (Rtl) -> Tuple[str, VarMap] + # type: (Rtl) -> Tuple[List[ExprRef], Z3VarMap] """ - Encode a concrete primitive Rtl r sa SMTLIB 2.0 query. + Encode a concrete primitive Rtl r sa z3 query. Returns a tuple (query, var_m) where: - - query is the resulting query. - - var_m is a map from Vars v with non-BVType to their Vars v' with - BVType s.t. v' holds the flattend bitvector value of v. + - query is a list of z3 expressions + - var_m is a map from Vars v with non-BVType to their correspodning z3 + bitvector variable. """ assert r.is_concrete() # Should contain only primitives primitives = set(PRIMITIVES.instructions) assert set(d.expr.inst for d in r.rtl).issubset(primitives) - q = "" - m = {} # type: VarMap + q = [] # type: List[ExprRef] + m = {} # type: Z3VarMap # Build declarations for any bitvector Vars + var_to_bv = {} # type: Z3VarMap for v in r.vars(): typ = v.get_typevar().singleton_type() if not isinstance(typ, BVType): continue - q += "(declare-fun {} () {})\n".format(v.name, bvtype_to_sort(typ)) + var_to_bv[v] = BitVec(v.name, typ.bits) # Encode each instruction as a equality assertion for d in r.rtl: inst = d.expr.inst + exp = None # type: ExprRef # For prim_to_bv/prim_from_bv just update var_m. No assertion needed if inst == prim_to_bv: assert isinstance(d.expr.args[0], Var) - m[d.expr.args[0]] = d.defs[0] + m[d.expr.args[0]] = var_to_bv[d.defs[0]] continue if inst == prim_from_bv: assert isinstance(d.expr.args[0], Var) - m[d.defs[0]] = d.expr.args[0] + m[d.defs[0]] = var_to_bv[d.expr.args[0]] continue if inst in [bvadd, bvult]: # Binary instructions @@ -70,12 +80,15 @@ def to_smt(r): df = d.defs[0] assert isinstance(lhs, Var) and isinstance(rhs, Var) - if inst in [bvadd]: # Normal binary - output type same as args - exp = "(= {} ({} {} {}))".format(df, inst.name, lhs, rhs) + if inst == bvadd: # Normal binary - output type same as args + exp = (var_to_bv[lhs] + var_to_bv[rhs]) else: + assert inst == bvult + exp = (var_to_bv[lhs] < var_to_bv[rhs]) # Comparison binary - need to convert bool to BitVec 1 - exp = "(= {} (ite ({} {} {}) #b1 #b0))"\ - .format(df, inst.name, lhs, rhs) + exp = If(exp, BitVecVal(1, 1), BitVecVal(0, 1)) + + exp = mk_eq(var_to_bv[df], exp) elif inst == bvzeroext: arg = d.expr.args[0] df = d.defs[0] @@ -83,8 +96,15 @@ def to_smt(r): fromW = arg.get_typevar().singleton_type().width() toW = df.get_typevar().singleton_type().width() - exp = "(= {} ((_ zero_extend {}) {}))"\ - .format(df, toW-fromW, arg) + exp = mk_eq(var_to_bv[df], ZeroExt(toW-fromW, var_to_bv[arg])) + elif inst == bvsignext: + arg = d.expr.args[0] + df = d.defs[0] + assert isinstance(arg, Var) + fromW = arg.get_typevar().singleton_type().width() + toW = df.get_typevar().singleton_type().width() + + exp = mk_eq(var_to_bv[df], SignExt(toW-fromW, var_to_bv[arg])) elif inst == bvsplit: arg = d.expr.args[0] assert isinstance(arg, Var) @@ -95,12 +115,10 @@ def to_smt(r): lo = d.defs[0] hi = d.defs[1] - exp = "(and " - exp += "(= {} ((_ extract {} {}) {})) "\ - .format(lo, width//2-1, 0, arg) - exp += "(= {} ((_ extract {} {}) {}))"\ - .format(hi, width-1, width//2, arg) - exp += ")" + exp = And(mk_eq(var_to_bv[lo], + Extract(width//2-1, 0, var_to_bv[arg])), + mk_eq(var_to_bv[hi], + Extract(width-1, width//2, var_to_bv[arg]))) elif inst == bvconcat: assert isinstance(d.expr.args[0], Var) and \ isinstance(d.expr.args[1], Var) @@ -109,18 +127,17 @@ def to_smt(r): df = d.defs[0] # Z3 Concat expects hi bits first, then lo bits - exp = "(= {} (concat {} {}))"\ - .format(df, hi, lo) + exp = mk_eq(var_to_bv[df], Concat(var_to_bv[hi], var_to_bv[lo])) else: assert False, "Unknown primitive instruction {}".format(inst) - q += "(assert {})\n".format(exp) + q.append(exp) return (q, m) def equivalent(r1, r2, inp_m, out_m): - # type: (Rtl, Rtl, VarMap, VarMap) -> str + # type: (Rtl, Rtl, VarMap, VarMap) -> List[ExprRef] """ Given: - concrete source Rtl r1 @@ -156,36 +173,25 @@ def equivalent(r1, r2, inp_m, out_m): (q2, m2) = to_smt(r2) # Build an expression for the equality of real Cretone inputs of r1 and r2 - args_eq_exp = "(and \n" + args_eq_exp = [] # type: List[ExprRef] for v in r1.free_vars(): - args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]]) - args_eq_exp += ")" + args_eq_exp.append(mk_eq(m1[v], m2[inp_m[v]])) # Build an expression for the equality of real Cretone outputs of r1 and r2 - results_eq_exp = "(and \n" + results_eq_exp = [] # type: List[ExprRef] for (v1, v2) in out_m.items(): - results_eq_exp += "(= {} {})\n".format(m1[v1], m2[v2]) - results_eq_exp += ")" + results_eq_exp.append(mk_eq(m1[v1], m2[v2])) # Put the whole query toghether - q = '; Rtl 1 declarations and assertions\n' + q1 - q += '; Rtl 2 declarations and assertions\n' + q2 - - q += '; Assert that the inputs of Rtl1 and Rtl2 are equal\n' + \ - '(assert {})\n'.format(args_eq_exp) - - q += '; Assert that the outputs of Rtl1 and Rtl2 are not equal\n' + \ - '(assert (not {}))\n'.format(results_eq_exp) - - return q + return q1 + q2 + args_eq_exp + [Not(And(*results_eq_exp))] def xform_correct(x, typing): - # type: (XForm, VarTyping) -> str + # type: (XForm, VarTyping) -> bool """ - Given an XForm x and a concrete variable typing for x build the smtlib - query asserting that x is correct for the given typing. + Given an XForm x and a concrete variable typing for x check whether x is + semantically preserving for the concrete typing. """ assert x.ti.permits(typing) @@ -208,4 +214,8 @@ def xform_correct(x, typing): # Get the primitive semantic Rtls for src and dst prim_src = elaborate(src) prim_dst = elaborate(dst) - return equivalent(prim_src, prim_dst, inp_m, out_m) + asserts = equivalent(prim_src, prim_dst, inp_m, out_m) + + s = Solver() + s.add(*asserts) + return s.check() == unsat diff --git a/lib/cretonne/meta/stubs/z3/__init__.pyi b/lib/cretonne/meta/stubs/z3/__init__.pyi new file mode 100644 index 0000000000..2fd6c8341f --- /dev/null +++ b/lib/cretonne/meta/stubs/z3/__init__.pyi @@ -0,0 +1,151 @@ +from typing import overload, Tuple, Any, List, Iterable, Union, TypeVar +from .z3types import Ast, ContextObj + +TExprRef = TypeVar("TExprRef", bound="ExprRef") + +class Context: + ... + +class Z3PPObject: + ... + +class AstRef(Z3PPObject): + @overload + def __init__(self, ast: Ast, ctx: Context) -> None: + self.ast: Ast = ... + self.ctx: Context= ... + + @overload + def __init__(self, ast: Ast) -> None: + self.ast: Ast = ... + self.ctx: Context= ... + def ctx_ref(self) -> ContextObj: ... + def as_ast(self) -> Ast: ... + def children(self) -> List[AstRef]: ... + +class SortRef(AstRef): + ... + +class FuncDeclRef(AstRef): + def arity(self) -> int: ... + def name(self) -> str: ... + +class ExprRef(AstRef): + def eq(self, other: ExprRef) -> ExprRef: ... + def sort(self) -> SortRef: ... + def decl(self) -> FuncDeclRef: ... + +class BoolSortRef(SortRef): + ... + +class BoolRef(ExprRef): + ... + + +def is_true(a: BoolRef) -> bool: ... +def is_false(a: BoolRef) -> bool: ... +def is_int_value(a: AstRef) -> bool: ... +def substitute(a: AstRef, *m: Tuple[AstRef, AstRef]) -> AstRef: ... + + +class ArithSortRef(SortRef): + ... + +class ArithRef(ExprRef): + def __neg__(self) -> ExprRef: ... + def __le__(self, other: ArithRef) -> ArithRef: ... + def __lt__(self, other: ArithRef) -> ArithRef: ... + def __ge__(self, other: ArithRef) -> ArithRef: ... + def __gt__(self, other: ArithRef) -> ArithRef: ... + def __add__(self, other: ArithRef) -> ArithRef: ... + def __sub__(self, other: ArithRef) -> ArithRef: ... + def __mul__(self, other: ArithRef) -> ArithRef: ... + def __div__(self, other: ArithRef) -> ArithRef: ... + def __mod__(self, other: ArithRef) -> ArithRef: ... + +class IntNumRef(ArithRef): + def as_long(self) -> int: ... + +class BitVecRef(ExprRef): + def __neg__(self) -> ExprRef: ... + def __le__(self, other: BitVecRef) -> ExprRef: ... + def __lt__(self, other: BitVecRef) -> ExprRef: ... + def __ge__(self, other: BitVecRef) -> ExprRef: ... + def __gt__(self, other: BitVecRef) -> ExprRef: ... + def __add__(self, other: BitVecRef) -> BitVecRef: ... + def __sub__(self, other: BitVecRef) -> BitVecRef: ... + def __mul__(self, other: BitVecRef) -> BitVecRef: ... + def __div__(self, other: BitVecRef) -> BitVecRef: ... + def __mod__(self, other: BitVecRef) -> BitVecRef: ... + +class BitVecNumRef(BitVecRef): + def as_long(self) -> int: ... + +class CheckSatResult: ... + +class ModelRef(Z3PPObject): + def __getitem__(self, k: FuncDeclRef) -> IntNumRef: ... + def decls(self) -> Iterable[FuncDeclRef]: ... + +class Solver(Z3PPObject): + @overload + def __init__(self) -> None: + self.ctx: Context = ... + @overload + def __init__(self, ctx:Context) -> None: + self.ctx: Context = ... + + def add(self, e:ExprRef) -> None: ... + def to_smt2(self) -> str: ... + def check(self) -> CheckSatResult: ... + def push(self) -> None: ... + def pop(self) -> None: ... + def model(self) -> ModelRef: ... + +sat: CheckSatResult = ... +unsat: CheckSatResult = ... + +@overload +def Int(name: str) -> ArithRef: ... +@overload +def Int(name: str, ctx: Context) -> ArithRef: ... + +@overload +def Bool(name: str) -> BoolRef: ... +@overload +def Bool(name: str, ctx: Context) -> BoolRef: ... + +def BitVec(name: str, width: int) -> BitVecRef: ... + +@overload +def parse_smt2_string(s: str) -> ExprRef: ... +@overload +def parse_smt2_string(s: str, ctx: Context) -> ExprRef: ... + +# Can't give more precise types here since func signature is +# a vararg list of ExprRef optionally followed by a Context +def Or(*args: Union[ExprRef, Context]) -> ExprRef: ... +def And(*args: Union[ExprRef, Context]) -> ExprRef: ... +@overload +def Not(p: ExprRef) -> ExprRef: ... +@overload +def Not(p: ExprRef, ctx: Context) -> ExprRef: ... +def Implies(a: ExprRef, b: ExprRef, ctx:Context) -> ExprRef: ... +def If(a: ExprRef, b:TExprRef, c:TExprRef) -> TExprRef: ... + +def ZeroExt(width: int, expr: BitVecRef) -> BitVecRef: ... +def SignExt(width: int, expr: BitVecRef) -> BitVecRef: ... +def Extract(hi: int, lo: int, expr: BitVecRef) -> BitVecRef: ... +def Concat(expr1: BitVecRef, expr2: BitVecRef) -> BitVecRef: ... + +def Function(name: str, *sig: Tuple[SortRef,...]) -> FuncDeclRef: ... + +def IntVal(val: int, ctx: Context) -> IntNumRef: ... +@overload +def BoolVal(val: bool, ctx: Context) -> BoolRef: ... +@overload +def BoolVal(val: bool) -> BoolRef: ... +@overload +def BitVecVal(val: int, bits: int, ctx: Context) -> BitVecNumRef: ... +@overload +def BitVecVal(val: int, bits: int) -> BitVecNumRef: ... diff --git a/lib/cretonne/meta/stubs/z3/z3core.pyi b/lib/cretonne/meta/stubs/z3/z3core.pyi new file mode 100644 index 0000000000..36f1f88792 --- /dev/null +++ b/lib/cretonne/meta/stubs/z3/z3core.pyi @@ -0,0 +1,3 @@ +from .z3types import Ast, ContextObj +def Z3_mk_eq(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ... +def Z3_mk_div(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ... diff --git a/lib/cretonne/meta/stubs/z3/z3types.pyi b/lib/cretonne/meta/stubs/z3/z3types.pyi new file mode 100644 index 0000000000..fa8fc446d1 --- /dev/null +++ b/lib/cretonne/meta/stubs/z3/z3types.pyi @@ -0,0 +1,12 @@ +from typing import Any + +class Z3Exception(Exception): + def __init__(self, a: Any) -> None: + self.value = a + ... + +class ContextObj: + ... + +class Ast: + ...