Cleanup for PR #123 (#129)

* Fix bextend semantics; Change smtlib.py to use z3 python bindings for query building instead of raw strings

* Forgot the mypy stubs for z3
This commit is contained in:
d1m0
2017-07-31 16:02:27 -07:00
committed by Jakob Stoklund Olesen
parent b74723cb68
commit 9e3f4e9195
6 changed files with 234 additions and 54 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import absolute_import
from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\
bvadd, bvult, bvzeroext
bvadd, bvult, bvzeroext, bvsignext
from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \
isplit, iconcat, iadd_cin, iadd_carry
from .immediates import intcc
@@ -116,7 +116,7 @@ bextend.set_semantics(
a << bextend(x),
(Rtl(
bvx << prim_to_bv(x),
bvy << bvzeroext(bvx),
bvy << bvsignext(bvx),
a << prim_from_bv(bvy)
), [InTypeset(x.get_typevar(), ScalarTS)]),
Rtl(

View File

@@ -82,4 +82,8 @@ bvzeroext = Instruction(
'bvzeroext', r"""Unsigned bitvector extension""",
ins=x, outs=x1, constraints=WiderOrEq(ToBV, BV))
bvsignext = Instruction(
'bvsignext', r"""Signed bitvector extension""",
ins=x, outs=x1, constraints=WiderOrEq(ToBV, BV))
GROUP.close()

View File

@@ -3,64 +3,74 @@ Tools to emit SMTLIB bitvector queries encoding concrete RTLs containing only
primitive instructions.
"""
from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\
bvult, bvzeroext, bvsplit, bvconcat
bvult, bvzeroext, bvsplit, bvconcat, bvsignext
from cdsl.ast import Var
from cdsl.types import BVType
from .elaborate import elaborate
from z3 import BitVec, ZeroExt, SignExt, And, Extract, Concat, Not, Solver,\
unsat, BoolRef, BitVecVal, If
from z3.z3core import Z3_mk_eq
try:
from typing import TYPE_CHECKING, Tuple # noqa
from typing import TYPE_CHECKING, Tuple, Dict, List # noqa
from cdsl.xform import Rtl, XForm # noqa
from cdsl.ast import VarMap # noqa
from cdsl.ti import VarTyping # noqa
if TYPE_CHECKING:
from z3 import ExprRef, BitVecRef # noqa
Z3VarMap = Dict[Var, BitVecRef]
except ImportError:
TYPE_CHECKING = False
def bvtype_to_sort(typ):
# type: (BVType) -> str
"""Return the BitVec sort corresponding to a BVType"""
return "(_ BitVec {})".format(typ.bits)
# Use this for constructing a == b instead of == since MyPy doesn't
# accept overloading of __eq__ that doesn't return bool
def mk_eq(e1, e2):
# type: (ExprRef, ExprRef) -> ExprRef
"""Return a z3 expression equivalent to e1 == e2"""
return BoolRef(Z3_mk_eq(e1.ctx_ref(), e1.as_ast(), e2.as_ast()), e1.ctx)
def to_smt(r):
# type: (Rtl) -> Tuple[str, VarMap]
# type: (Rtl) -> Tuple[List[ExprRef], Z3VarMap]
"""
Encode a concrete primitive Rtl r sa SMTLIB 2.0 query.
Encode a concrete primitive Rtl r sa z3 query.
Returns a tuple (query, var_m) where:
- query is the resulting query.
- var_m is a map from Vars v with non-BVType to their Vars v' with
BVType s.t. v' holds the flattend bitvector value of v.
- query is a list of z3 expressions
- var_m is a map from Vars v with non-BVType to their correspodning z3
bitvector variable.
"""
assert r.is_concrete()
# Should contain only primitives
primitives = set(PRIMITIVES.instructions)
assert set(d.expr.inst for d in r.rtl).issubset(primitives)
q = ""
m = {} # type: VarMap
q = [] # type: List[ExprRef]
m = {} # type: Z3VarMap
# Build declarations for any bitvector Vars
var_to_bv = {} # type: Z3VarMap
for v in r.vars():
typ = v.get_typevar().singleton_type()
if not isinstance(typ, BVType):
continue
q += "(declare-fun {} () {})\n".format(v.name, bvtype_to_sort(typ))
var_to_bv[v] = BitVec(v.name, typ.bits)
# Encode each instruction as a equality assertion
for d in r.rtl:
inst = d.expr.inst
exp = None # type: ExprRef
# For prim_to_bv/prim_from_bv just update var_m. No assertion needed
if inst == prim_to_bv:
assert isinstance(d.expr.args[0], Var)
m[d.expr.args[0]] = d.defs[0]
m[d.expr.args[0]] = var_to_bv[d.defs[0]]
continue
if inst == prim_from_bv:
assert isinstance(d.expr.args[0], Var)
m[d.defs[0]] = d.expr.args[0]
m[d.defs[0]] = var_to_bv[d.expr.args[0]]
continue
if inst in [bvadd, bvult]: # Binary instructions
@@ -70,12 +80,15 @@ def to_smt(r):
df = d.defs[0]
assert isinstance(lhs, Var) and isinstance(rhs, Var)
if inst in [bvadd]: # Normal binary - output type same as args
exp = "(= {} ({} {} {}))".format(df, inst.name, lhs, rhs)
if inst == bvadd: # Normal binary - output type same as args
exp = (var_to_bv[lhs] + var_to_bv[rhs])
else:
assert inst == bvult
exp = (var_to_bv[lhs] < var_to_bv[rhs])
# Comparison binary - need to convert bool to BitVec 1
exp = "(= {} (ite ({} {} {}) #b1 #b0))"\
.format(df, inst.name, lhs, rhs)
exp = If(exp, BitVecVal(1, 1), BitVecVal(0, 1))
exp = mk_eq(var_to_bv[df], exp)
elif inst == bvzeroext:
arg = d.expr.args[0]
df = d.defs[0]
@@ -83,8 +96,15 @@ def to_smt(r):
fromW = arg.get_typevar().singleton_type().width()
toW = df.get_typevar().singleton_type().width()
exp = "(= {} ((_ zero_extend {}) {}))"\
.format(df, toW-fromW, arg)
exp = mk_eq(var_to_bv[df], ZeroExt(toW-fromW, var_to_bv[arg]))
elif inst == bvsignext:
arg = d.expr.args[0]
df = d.defs[0]
assert isinstance(arg, Var)
fromW = arg.get_typevar().singleton_type().width()
toW = df.get_typevar().singleton_type().width()
exp = mk_eq(var_to_bv[df], SignExt(toW-fromW, var_to_bv[arg]))
elif inst == bvsplit:
arg = d.expr.args[0]
assert isinstance(arg, Var)
@@ -95,12 +115,10 @@ def to_smt(r):
lo = d.defs[0]
hi = d.defs[1]
exp = "(and "
exp += "(= {} ((_ extract {} {}) {})) "\
.format(lo, width//2-1, 0, arg)
exp += "(= {} ((_ extract {} {}) {}))"\
.format(hi, width-1, width//2, arg)
exp += ")"
exp = And(mk_eq(var_to_bv[lo],
Extract(width//2-1, 0, var_to_bv[arg])),
mk_eq(var_to_bv[hi],
Extract(width-1, width//2, var_to_bv[arg])))
elif inst == bvconcat:
assert isinstance(d.expr.args[0], Var) and \
isinstance(d.expr.args[1], Var)
@@ -109,18 +127,17 @@ def to_smt(r):
df = d.defs[0]
# Z3 Concat expects hi bits first, then lo bits
exp = "(= {} (concat {} {}))"\
.format(df, hi, lo)
exp = mk_eq(var_to_bv[df], Concat(var_to_bv[hi], var_to_bv[lo]))
else:
assert False, "Unknown primitive instruction {}".format(inst)
q += "(assert {})\n".format(exp)
q.append(exp)
return (q, m)
def equivalent(r1, r2, inp_m, out_m):
# type: (Rtl, Rtl, VarMap, VarMap) -> str
# type: (Rtl, Rtl, VarMap, VarMap) -> List[ExprRef]
"""
Given:
- concrete source Rtl r1
@@ -156,36 +173,25 @@ def equivalent(r1, r2, inp_m, out_m):
(q2, m2) = to_smt(r2)
# Build an expression for the equality of real Cretone inputs of r1 and r2
args_eq_exp = "(and \n"
args_eq_exp = [] # type: List[ExprRef]
for v in r1.free_vars():
args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]])
args_eq_exp += ")"
args_eq_exp.append(mk_eq(m1[v], m2[inp_m[v]]))
# Build an expression for the equality of real Cretone outputs of r1 and r2
results_eq_exp = "(and \n"
results_eq_exp = [] # type: List[ExprRef]
for (v1, v2) in out_m.items():
results_eq_exp += "(= {} {})\n".format(m1[v1], m2[v2])
results_eq_exp += ")"
results_eq_exp.append(mk_eq(m1[v1], m2[v2]))
# Put the whole query toghether
q = '; Rtl 1 declarations and assertions\n' + q1
q += '; Rtl 2 declarations and assertions\n' + q2
q += '; Assert that the inputs of Rtl1 and Rtl2 are equal\n' + \
'(assert {})\n'.format(args_eq_exp)
q += '; Assert that the outputs of Rtl1 and Rtl2 are not equal\n' + \
'(assert (not {}))\n'.format(results_eq_exp)
return q
return q1 + q2 + args_eq_exp + [Not(And(*results_eq_exp))]
def xform_correct(x, typing):
# type: (XForm, VarTyping) -> str
# type: (XForm, VarTyping) -> bool
"""
Given an XForm x and a concrete variable typing for x build the smtlib
query asserting that x is correct for the given typing.
Given an XForm x and a concrete variable typing for x check whether x is
semantically preserving for the concrete typing.
"""
assert x.ti.permits(typing)
@@ -208,4 +214,8 @@ def xform_correct(x, typing):
# Get the primitive semantic Rtls for src and dst
prim_src = elaborate(src)
prim_dst = elaborate(dst)
return equivalent(prim_src, prim_dst, inp_m, out_m)
asserts = equivalent(prim_src, prim_dst, inp_m, out_m)
s = Solver()
s.add(*asserts)
return s.check() == unsat

View File

@@ -0,0 +1,151 @@
from typing import overload, Tuple, Any, List, Iterable, Union, TypeVar
from .z3types import Ast, ContextObj
TExprRef = TypeVar("TExprRef", bound="ExprRef")
class Context:
...
class Z3PPObject:
...
class AstRef(Z3PPObject):
@overload
def __init__(self, ast: Ast, ctx: Context) -> None:
self.ast: Ast = ...
self.ctx: Context= ...
@overload
def __init__(self, ast: Ast) -> None:
self.ast: Ast = ...
self.ctx: Context= ...
def ctx_ref(self) -> ContextObj: ...
def as_ast(self) -> Ast: ...
def children(self) -> List[AstRef]: ...
class SortRef(AstRef):
...
class FuncDeclRef(AstRef):
def arity(self) -> int: ...
def name(self) -> str: ...
class ExprRef(AstRef):
def eq(self, other: ExprRef) -> ExprRef: ...
def sort(self) -> SortRef: ...
def decl(self) -> FuncDeclRef: ...
class BoolSortRef(SortRef):
...
class BoolRef(ExprRef):
...
def is_true(a: BoolRef) -> bool: ...
def is_false(a: BoolRef) -> bool: ...
def is_int_value(a: AstRef) -> bool: ...
def substitute(a: AstRef, *m: Tuple[AstRef, AstRef]) -> AstRef: ...
class ArithSortRef(SortRef):
...
class ArithRef(ExprRef):
def __neg__(self) -> ExprRef: ...
def __le__(self, other: ArithRef) -> ArithRef: ...
def __lt__(self, other: ArithRef) -> ArithRef: ...
def __ge__(self, other: ArithRef) -> ArithRef: ...
def __gt__(self, other: ArithRef) -> ArithRef: ...
def __add__(self, other: ArithRef) -> ArithRef: ...
def __sub__(self, other: ArithRef) -> ArithRef: ...
def __mul__(self, other: ArithRef) -> ArithRef: ...
def __div__(self, other: ArithRef) -> ArithRef: ...
def __mod__(self, other: ArithRef) -> ArithRef: ...
class IntNumRef(ArithRef):
def as_long(self) -> int: ...
class BitVecRef(ExprRef):
def __neg__(self) -> ExprRef: ...
def __le__(self, other: BitVecRef) -> ExprRef: ...
def __lt__(self, other: BitVecRef) -> ExprRef: ...
def __ge__(self, other: BitVecRef) -> ExprRef: ...
def __gt__(self, other: BitVecRef) -> ExprRef: ...
def __add__(self, other: BitVecRef) -> BitVecRef: ...
def __sub__(self, other: BitVecRef) -> BitVecRef: ...
def __mul__(self, other: BitVecRef) -> BitVecRef: ...
def __div__(self, other: BitVecRef) -> BitVecRef: ...
def __mod__(self, other: BitVecRef) -> BitVecRef: ...
class BitVecNumRef(BitVecRef):
def as_long(self) -> int: ...
class CheckSatResult: ...
class ModelRef(Z3PPObject):
def __getitem__(self, k: FuncDeclRef) -> IntNumRef: ...
def decls(self) -> Iterable[FuncDeclRef]: ...
class Solver(Z3PPObject):
@overload
def __init__(self) -> None:
self.ctx: Context = ...
@overload
def __init__(self, ctx:Context) -> None:
self.ctx: Context = ...
def add(self, e:ExprRef) -> None: ...
def to_smt2(self) -> str: ...
def check(self) -> CheckSatResult: ...
def push(self) -> None: ...
def pop(self) -> None: ...
def model(self) -> ModelRef: ...
sat: CheckSatResult = ...
unsat: CheckSatResult = ...
@overload
def Int(name: str) -> ArithRef: ...
@overload
def Int(name: str, ctx: Context) -> ArithRef: ...
@overload
def Bool(name: str) -> BoolRef: ...
@overload
def Bool(name: str, ctx: Context) -> BoolRef: ...
def BitVec(name: str, width: int) -> BitVecRef: ...
@overload
def parse_smt2_string(s: str) -> ExprRef: ...
@overload
def parse_smt2_string(s: str, ctx: Context) -> ExprRef: ...
# Can't give more precise types here since func signature is
# a vararg list of ExprRef optionally followed by a Context
def Or(*args: Union[ExprRef, Context]) -> ExprRef: ...
def And(*args: Union[ExprRef, Context]) -> ExprRef: ...
@overload
def Not(p: ExprRef) -> ExprRef: ...
@overload
def Not(p: ExprRef, ctx: Context) -> ExprRef: ...
def Implies(a: ExprRef, b: ExprRef, ctx:Context) -> ExprRef: ...
def If(a: ExprRef, b:TExprRef, c:TExprRef) -> TExprRef: ...
def ZeroExt(width: int, expr: BitVecRef) -> BitVecRef: ...
def SignExt(width: int, expr: BitVecRef) -> BitVecRef: ...
def Extract(hi: int, lo: int, expr: BitVecRef) -> BitVecRef: ...
def Concat(expr1: BitVecRef, expr2: BitVecRef) -> BitVecRef: ...
def Function(name: str, *sig: Tuple[SortRef,...]) -> FuncDeclRef: ...
def IntVal(val: int, ctx: Context) -> IntNumRef: ...
@overload
def BoolVal(val: bool, ctx: Context) -> BoolRef: ...
@overload
def BoolVal(val: bool) -> BoolRef: ...
@overload
def BitVecVal(val: int, bits: int, ctx: Context) -> BitVecNumRef: ...
@overload
def BitVecVal(val: int, bits: int) -> BitVecNumRef: ...

View File

@@ -0,0 +1,3 @@
from .z3types import Ast, ContextObj
def Z3_mk_eq(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ...
def Z3_mk_div(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ...

View File

@@ -0,0 +1,12 @@
from typing import Any
class Z3Exception(Exception):
def __init__(self, a: Any) -> None:
self.value = a
...
class ContextObj:
...
class Ast:
...