* Fix bextend semantics; Change smtlib.py to use z3 python bindings for query building instead of raw strings * Forgot the mypy stubs for z3
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
b74723cb68
commit
9e3f4e9195
@@ -1,6 +1,6 @@
|
||||
from __future__ import absolute_import
|
||||
from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\
|
||||
bvadd, bvult, bvzeroext
|
||||
bvadd, bvult, bvzeroext, bvsignext
|
||||
from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \
|
||||
isplit, iconcat, iadd_cin, iadd_carry
|
||||
from .immediates import intcc
|
||||
@@ -116,7 +116,7 @@ bextend.set_semantics(
|
||||
a << bextend(x),
|
||||
(Rtl(
|
||||
bvx << prim_to_bv(x),
|
||||
bvy << bvzeroext(bvx),
|
||||
bvy << bvsignext(bvx),
|
||||
a << prim_from_bv(bvy)
|
||||
), [InTypeset(x.get_typevar(), ScalarTS)]),
|
||||
Rtl(
|
||||
|
||||
@@ -82,4 +82,8 @@ bvzeroext = Instruction(
|
||||
'bvzeroext', r"""Unsigned bitvector extension""",
|
||||
ins=x, outs=x1, constraints=WiderOrEq(ToBV, BV))
|
||||
|
||||
bvsignext = Instruction(
|
||||
'bvsignext', r"""Signed bitvector extension""",
|
||||
ins=x, outs=x1, constraints=WiderOrEq(ToBV, BV))
|
||||
|
||||
GROUP.close()
|
||||
|
||||
@@ -3,64 +3,74 @@ Tools to emit SMTLIB bitvector queries encoding concrete RTLs containing only
|
||||
primitive instructions.
|
||||
"""
|
||||
from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\
|
||||
bvult, bvzeroext, bvsplit, bvconcat
|
||||
bvult, bvzeroext, bvsplit, bvconcat, bvsignext
|
||||
from cdsl.ast import Var
|
||||
from cdsl.types import BVType
|
||||
from .elaborate import elaborate
|
||||
from z3 import BitVec, ZeroExt, SignExt, And, Extract, Concat, Not, Solver,\
|
||||
unsat, BoolRef, BitVecVal, If
|
||||
from z3.z3core import Z3_mk_eq
|
||||
|
||||
try:
|
||||
from typing import TYPE_CHECKING, Tuple # noqa
|
||||
from typing import TYPE_CHECKING, Tuple, Dict, List # noqa
|
||||
from cdsl.xform import Rtl, XForm # noqa
|
||||
from cdsl.ast import VarMap # noqa
|
||||
from cdsl.ti import VarTyping # noqa
|
||||
if TYPE_CHECKING:
|
||||
from z3 import ExprRef, BitVecRef # noqa
|
||||
Z3VarMap = Dict[Var, BitVecRef]
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
|
||||
|
||||
def bvtype_to_sort(typ):
|
||||
# type: (BVType) -> str
|
||||
"""Return the BitVec sort corresponding to a BVType"""
|
||||
return "(_ BitVec {})".format(typ.bits)
|
||||
# Use this for constructing a == b instead of == since MyPy doesn't
|
||||
# accept overloading of __eq__ that doesn't return bool
|
||||
def mk_eq(e1, e2):
|
||||
# type: (ExprRef, ExprRef) -> ExprRef
|
||||
"""Return a z3 expression equivalent to e1 == e2"""
|
||||
return BoolRef(Z3_mk_eq(e1.ctx_ref(), e1.as_ast(), e2.as_ast()), e1.ctx)
|
||||
|
||||
|
||||
def to_smt(r):
|
||||
# type: (Rtl) -> Tuple[str, VarMap]
|
||||
# type: (Rtl) -> Tuple[List[ExprRef], Z3VarMap]
|
||||
"""
|
||||
Encode a concrete primitive Rtl r sa SMTLIB 2.0 query.
|
||||
Encode a concrete primitive Rtl r sa z3 query.
|
||||
Returns a tuple (query, var_m) where:
|
||||
- query is the resulting query.
|
||||
- var_m is a map from Vars v with non-BVType to their Vars v' with
|
||||
BVType s.t. v' holds the flattend bitvector value of v.
|
||||
- query is a list of z3 expressions
|
||||
- var_m is a map from Vars v with non-BVType to their correspodning z3
|
||||
bitvector variable.
|
||||
"""
|
||||
assert r.is_concrete()
|
||||
# Should contain only primitives
|
||||
primitives = set(PRIMITIVES.instructions)
|
||||
assert set(d.expr.inst for d in r.rtl).issubset(primitives)
|
||||
|
||||
q = ""
|
||||
m = {} # type: VarMap
|
||||
q = [] # type: List[ExprRef]
|
||||
m = {} # type: Z3VarMap
|
||||
|
||||
# Build declarations for any bitvector Vars
|
||||
var_to_bv = {} # type: Z3VarMap
|
||||
for v in r.vars():
|
||||
typ = v.get_typevar().singleton_type()
|
||||
if not isinstance(typ, BVType):
|
||||
continue
|
||||
|
||||
q += "(declare-fun {} () {})\n".format(v.name, bvtype_to_sort(typ))
|
||||
var_to_bv[v] = BitVec(v.name, typ.bits)
|
||||
|
||||
# Encode each instruction as a equality assertion
|
||||
for d in r.rtl:
|
||||
inst = d.expr.inst
|
||||
|
||||
exp = None # type: ExprRef
|
||||
# For prim_to_bv/prim_from_bv just update var_m. No assertion needed
|
||||
if inst == prim_to_bv:
|
||||
assert isinstance(d.expr.args[0], Var)
|
||||
m[d.expr.args[0]] = d.defs[0]
|
||||
m[d.expr.args[0]] = var_to_bv[d.defs[0]]
|
||||
continue
|
||||
|
||||
if inst == prim_from_bv:
|
||||
assert isinstance(d.expr.args[0], Var)
|
||||
m[d.defs[0]] = d.expr.args[0]
|
||||
m[d.defs[0]] = var_to_bv[d.expr.args[0]]
|
||||
continue
|
||||
|
||||
if inst in [bvadd, bvult]: # Binary instructions
|
||||
@@ -70,12 +80,15 @@ def to_smt(r):
|
||||
df = d.defs[0]
|
||||
assert isinstance(lhs, Var) and isinstance(rhs, Var)
|
||||
|
||||
if inst in [bvadd]: # Normal binary - output type same as args
|
||||
exp = "(= {} ({} {} {}))".format(df, inst.name, lhs, rhs)
|
||||
if inst == bvadd: # Normal binary - output type same as args
|
||||
exp = (var_to_bv[lhs] + var_to_bv[rhs])
|
||||
else:
|
||||
assert inst == bvult
|
||||
exp = (var_to_bv[lhs] < var_to_bv[rhs])
|
||||
# Comparison binary - need to convert bool to BitVec 1
|
||||
exp = "(= {} (ite ({} {} {}) #b1 #b0))"\
|
||||
.format(df, inst.name, lhs, rhs)
|
||||
exp = If(exp, BitVecVal(1, 1), BitVecVal(0, 1))
|
||||
|
||||
exp = mk_eq(var_to_bv[df], exp)
|
||||
elif inst == bvzeroext:
|
||||
arg = d.expr.args[0]
|
||||
df = d.defs[0]
|
||||
@@ -83,8 +96,15 @@ def to_smt(r):
|
||||
fromW = arg.get_typevar().singleton_type().width()
|
||||
toW = df.get_typevar().singleton_type().width()
|
||||
|
||||
exp = "(= {} ((_ zero_extend {}) {}))"\
|
||||
.format(df, toW-fromW, arg)
|
||||
exp = mk_eq(var_to_bv[df], ZeroExt(toW-fromW, var_to_bv[arg]))
|
||||
elif inst == bvsignext:
|
||||
arg = d.expr.args[0]
|
||||
df = d.defs[0]
|
||||
assert isinstance(arg, Var)
|
||||
fromW = arg.get_typevar().singleton_type().width()
|
||||
toW = df.get_typevar().singleton_type().width()
|
||||
|
||||
exp = mk_eq(var_to_bv[df], SignExt(toW-fromW, var_to_bv[arg]))
|
||||
elif inst == bvsplit:
|
||||
arg = d.expr.args[0]
|
||||
assert isinstance(arg, Var)
|
||||
@@ -95,12 +115,10 @@ def to_smt(r):
|
||||
lo = d.defs[0]
|
||||
hi = d.defs[1]
|
||||
|
||||
exp = "(and "
|
||||
exp += "(= {} ((_ extract {} {}) {})) "\
|
||||
.format(lo, width//2-1, 0, arg)
|
||||
exp += "(= {} ((_ extract {} {}) {}))"\
|
||||
.format(hi, width-1, width//2, arg)
|
||||
exp += ")"
|
||||
exp = And(mk_eq(var_to_bv[lo],
|
||||
Extract(width//2-1, 0, var_to_bv[arg])),
|
||||
mk_eq(var_to_bv[hi],
|
||||
Extract(width-1, width//2, var_to_bv[arg])))
|
||||
elif inst == bvconcat:
|
||||
assert isinstance(d.expr.args[0], Var) and \
|
||||
isinstance(d.expr.args[1], Var)
|
||||
@@ -109,18 +127,17 @@ def to_smt(r):
|
||||
df = d.defs[0]
|
||||
|
||||
# Z3 Concat expects hi bits first, then lo bits
|
||||
exp = "(= {} (concat {} {}))"\
|
||||
.format(df, hi, lo)
|
||||
exp = mk_eq(var_to_bv[df], Concat(var_to_bv[hi], var_to_bv[lo]))
|
||||
else:
|
||||
assert False, "Unknown primitive instruction {}".format(inst)
|
||||
|
||||
q += "(assert {})\n".format(exp)
|
||||
q.append(exp)
|
||||
|
||||
return (q, m)
|
||||
|
||||
|
||||
def equivalent(r1, r2, inp_m, out_m):
|
||||
# type: (Rtl, Rtl, VarMap, VarMap) -> str
|
||||
# type: (Rtl, Rtl, VarMap, VarMap) -> List[ExprRef]
|
||||
"""
|
||||
Given:
|
||||
- concrete source Rtl r1
|
||||
@@ -156,36 +173,25 @@ def equivalent(r1, r2, inp_m, out_m):
|
||||
(q2, m2) = to_smt(r2)
|
||||
|
||||
# Build an expression for the equality of real Cretone inputs of r1 and r2
|
||||
args_eq_exp = "(and \n"
|
||||
args_eq_exp = [] # type: List[ExprRef]
|
||||
|
||||
for v in r1.free_vars():
|
||||
args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]])
|
||||
args_eq_exp += ")"
|
||||
args_eq_exp.append(mk_eq(m1[v], m2[inp_m[v]]))
|
||||
|
||||
# Build an expression for the equality of real Cretone outputs of r1 and r2
|
||||
results_eq_exp = "(and \n"
|
||||
results_eq_exp = [] # type: List[ExprRef]
|
||||
for (v1, v2) in out_m.items():
|
||||
results_eq_exp += "(= {} {})\n".format(m1[v1], m2[v2])
|
||||
results_eq_exp += ")"
|
||||
results_eq_exp.append(mk_eq(m1[v1], m2[v2]))
|
||||
|
||||
# Put the whole query toghether
|
||||
q = '; Rtl 1 declarations and assertions\n' + q1
|
||||
q += '; Rtl 2 declarations and assertions\n' + q2
|
||||
|
||||
q += '; Assert that the inputs of Rtl1 and Rtl2 are equal\n' + \
|
||||
'(assert {})\n'.format(args_eq_exp)
|
||||
|
||||
q += '; Assert that the outputs of Rtl1 and Rtl2 are not equal\n' + \
|
||||
'(assert (not {}))\n'.format(results_eq_exp)
|
||||
|
||||
return q
|
||||
return q1 + q2 + args_eq_exp + [Not(And(*results_eq_exp))]
|
||||
|
||||
|
||||
def xform_correct(x, typing):
|
||||
# type: (XForm, VarTyping) -> str
|
||||
# type: (XForm, VarTyping) -> bool
|
||||
"""
|
||||
Given an XForm x and a concrete variable typing for x build the smtlib
|
||||
query asserting that x is correct for the given typing.
|
||||
Given an XForm x and a concrete variable typing for x check whether x is
|
||||
semantically preserving for the concrete typing.
|
||||
"""
|
||||
assert x.ti.permits(typing)
|
||||
|
||||
@@ -208,4 +214,8 @@ def xform_correct(x, typing):
|
||||
# Get the primitive semantic Rtls for src and dst
|
||||
prim_src = elaborate(src)
|
||||
prim_dst = elaborate(dst)
|
||||
return equivalent(prim_src, prim_dst, inp_m, out_m)
|
||||
asserts = equivalent(prim_src, prim_dst, inp_m, out_m)
|
||||
|
||||
s = Solver()
|
||||
s.add(*asserts)
|
||||
return s.check() == unsat
|
||||
|
||||
151
lib/cretonne/meta/stubs/z3/__init__.pyi
Normal file
151
lib/cretonne/meta/stubs/z3/__init__.pyi
Normal file
@@ -0,0 +1,151 @@
|
||||
from typing import overload, Tuple, Any, List, Iterable, Union, TypeVar
|
||||
from .z3types import Ast, ContextObj
|
||||
|
||||
TExprRef = TypeVar("TExprRef", bound="ExprRef")
|
||||
|
||||
class Context:
|
||||
...
|
||||
|
||||
class Z3PPObject:
|
||||
...
|
||||
|
||||
class AstRef(Z3PPObject):
|
||||
@overload
|
||||
def __init__(self, ast: Ast, ctx: Context) -> None:
|
||||
self.ast: Ast = ...
|
||||
self.ctx: Context= ...
|
||||
|
||||
@overload
|
||||
def __init__(self, ast: Ast) -> None:
|
||||
self.ast: Ast = ...
|
||||
self.ctx: Context= ...
|
||||
def ctx_ref(self) -> ContextObj: ...
|
||||
def as_ast(self) -> Ast: ...
|
||||
def children(self) -> List[AstRef]: ...
|
||||
|
||||
class SortRef(AstRef):
|
||||
...
|
||||
|
||||
class FuncDeclRef(AstRef):
|
||||
def arity(self) -> int: ...
|
||||
def name(self) -> str: ...
|
||||
|
||||
class ExprRef(AstRef):
|
||||
def eq(self, other: ExprRef) -> ExprRef: ...
|
||||
def sort(self) -> SortRef: ...
|
||||
def decl(self) -> FuncDeclRef: ...
|
||||
|
||||
class BoolSortRef(SortRef):
|
||||
...
|
||||
|
||||
class BoolRef(ExprRef):
|
||||
...
|
||||
|
||||
|
||||
def is_true(a: BoolRef) -> bool: ...
|
||||
def is_false(a: BoolRef) -> bool: ...
|
||||
def is_int_value(a: AstRef) -> bool: ...
|
||||
def substitute(a: AstRef, *m: Tuple[AstRef, AstRef]) -> AstRef: ...
|
||||
|
||||
|
||||
class ArithSortRef(SortRef):
|
||||
...
|
||||
|
||||
class ArithRef(ExprRef):
|
||||
def __neg__(self) -> ExprRef: ...
|
||||
def __le__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __lt__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __ge__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __gt__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __add__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __sub__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __mul__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __div__(self, other: ArithRef) -> ArithRef: ...
|
||||
def __mod__(self, other: ArithRef) -> ArithRef: ...
|
||||
|
||||
class IntNumRef(ArithRef):
|
||||
def as_long(self) -> int: ...
|
||||
|
||||
class BitVecRef(ExprRef):
|
||||
def __neg__(self) -> ExprRef: ...
|
||||
def __le__(self, other: BitVecRef) -> ExprRef: ...
|
||||
def __lt__(self, other: BitVecRef) -> ExprRef: ...
|
||||
def __ge__(self, other: BitVecRef) -> ExprRef: ...
|
||||
def __gt__(self, other: BitVecRef) -> ExprRef: ...
|
||||
def __add__(self, other: BitVecRef) -> BitVecRef: ...
|
||||
def __sub__(self, other: BitVecRef) -> BitVecRef: ...
|
||||
def __mul__(self, other: BitVecRef) -> BitVecRef: ...
|
||||
def __div__(self, other: BitVecRef) -> BitVecRef: ...
|
||||
def __mod__(self, other: BitVecRef) -> BitVecRef: ...
|
||||
|
||||
class BitVecNumRef(BitVecRef):
|
||||
def as_long(self) -> int: ...
|
||||
|
||||
class CheckSatResult: ...
|
||||
|
||||
class ModelRef(Z3PPObject):
|
||||
def __getitem__(self, k: FuncDeclRef) -> IntNumRef: ...
|
||||
def decls(self) -> Iterable[FuncDeclRef]: ...
|
||||
|
||||
class Solver(Z3PPObject):
|
||||
@overload
|
||||
def __init__(self) -> None:
|
||||
self.ctx: Context = ...
|
||||
@overload
|
||||
def __init__(self, ctx:Context) -> None:
|
||||
self.ctx: Context = ...
|
||||
|
||||
def add(self, e:ExprRef) -> None: ...
|
||||
def to_smt2(self) -> str: ...
|
||||
def check(self) -> CheckSatResult: ...
|
||||
def push(self) -> None: ...
|
||||
def pop(self) -> None: ...
|
||||
def model(self) -> ModelRef: ...
|
||||
|
||||
sat: CheckSatResult = ...
|
||||
unsat: CheckSatResult = ...
|
||||
|
||||
@overload
|
||||
def Int(name: str) -> ArithRef: ...
|
||||
@overload
|
||||
def Int(name: str, ctx: Context) -> ArithRef: ...
|
||||
|
||||
@overload
|
||||
def Bool(name: str) -> BoolRef: ...
|
||||
@overload
|
||||
def Bool(name: str, ctx: Context) -> BoolRef: ...
|
||||
|
||||
def BitVec(name: str, width: int) -> BitVecRef: ...
|
||||
|
||||
@overload
|
||||
def parse_smt2_string(s: str) -> ExprRef: ...
|
||||
@overload
|
||||
def parse_smt2_string(s: str, ctx: Context) -> ExprRef: ...
|
||||
|
||||
# Can't give more precise types here since func signature is
|
||||
# a vararg list of ExprRef optionally followed by a Context
|
||||
def Or(*args: Union[ExprRef, Context]) -> ExprRef: ...
|
||||
def And(*args: Union[ExprRef, Context]) -> ExprRef: ...
|
||||
@overload
|
||||
def Not(p: ExprRef) -> ExprRef: ...
|
||||
@overload
|
||||
def Not(p: ExprRef, ctx: Context) -> ExprRef: ...
|
||||
def Implies(a: ExprRef, b: ExprRef, ctx:Context) -> ExprRef: ...
|
||||
def If(a: ExprRef, b:TExprRef, c:TExprRef) -> TExprRef: ...
|
||||
|
||||
def ZeroExt(width: int, expr: BitVecRef) -> BitVecRef: ...
|
||||
def SignExt(width: int, expr: BitVecRef) -> BitVecRef: ...
|
||||
def Extract(hi: int, lo: int, expr: BitVecRef) -> BitVecRef: ...
|
||||
def Concat(expr1: BitVecRef, expr2: BitVecRef) -> BitVecRef: ...
|
||||
|
||||
def Function(name: str, *sig: Tuple[SortRef,...]) -> FuncDeclRef: ...
|
||||
|
||||
def IntVal(val: int, ctx: Context) -> IntNumRef: ...
|
||||
@overload
|
||||
def BoolVal(val: bool, ctx: Context) -> BoolRef: ...
|
||||
@overload
|
||||
def BoolVal(val: bool) -> BoolRef: ...
|
||||
@overload
|
||||
def BitVecVal(val: int, bits: int, ctx: Context) -> BitVecNumRef: ...
|
||||
@overload
|
||||
def BitVecVal(val: int, bits: int) -> BitVecNumRef: ...
|
||||
3
lib/cretonne/meta/stubs/z3/z3core.pyi
Normal file
3
lib/cretonne/meta/stubs/z3/z3core.pyi
Normal file
@@ -0,0 +1,3 @@
|
||||
from .z3types import Ast, ContextObj
|
||||
def Z3_mk_eq(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ...
|
||||
def Z3_mk_div(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ...
|
||||
12
lib/cretonne/meta/stubs/z3/z3types.pyi
Normal file
12
lib/cretonne/meta/stubs/z3/z3types.pyi
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
class Z3Exception(Exception):
|
||||
def __init__(self, a: Any) -> None:
|
||||
self.value = a
|
||||
...
|
||||
|
||||
class ContextObj:
|
||||
...
|
||||
|
||||
class Ast:
|
||||
...
|
||||
Reference in New Issue
Block a user