Add semantics for several more iadd with carry; Add xform_correct() and doc cleanup

This commit is contained in:
Dimo
2017-07-27 17:16:26 -07:00
committed by Jakob Stoklund Olesen
parent b5e1e4d454
commit 1bbe644080
2 changed files with 171 additions and 42 deletions

View File

@@ -1,7 +1,8 @@
from __future__ import absolute_import from __future__ import absolute_import
from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\ from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\
bvadd, bvult, bvzeroext bvadd, bvult, bvzeroext
from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \
isplit, iconcat, iadd_cin, iadd_carry
from .immediates import intcc from .immediates import intcc
from cdsl.xform import Rtl from cdsl.xform import Rtl
from cdsl.ast import Var from cdsl.ast import Var
@@ -13,18 +14,24 @@ y = Var('y')
a = Var('a') a = Var('a')
b = Var('b') b = Var('b')
c_out = Var('c_out') c_out = Var('c_out')
c_in = Var('c_in')
bvc_out = Var('bvc_out') bvc_out = Var('bvc_out')
bvc_in = Var('bvc_in')
xhi = Var('xhi') xhi = Var('xhi')
yhi = Var('yhi') yhi = Var('yhi')
ahi = Var('ahi') ahi = Var('ahi')
bhi = Var('bhi')
xlo = Var('xlo') xlo = Var('xlo')
ylo = Var('ylo') ylo = Var('ylo')
alo = Var('alo') alo = Var('alo')
blo = Var('blo')
lo = Var('lo') lo = Var('lo')
hi = Var('hi') hi = Var('hi')
bvx = Var('bvx') bvx = Var('bvx')
bvy = Var('bvy') bvy = Var('bvy')
bva = Var('bva') bva = Var('bva')
bvt = Var('bvt')
bvs = Var('bvs')
bva_wide = Var('bva_wide') bva_wide = Var('bva_wide')
bvlo = Var('bvlo') bvlo = Var('bvlo')
bvhi = Var('bvhi') bvhi = Var('bvhi')
@@ -51,16 +58,34 @@ vconcat.set_semantics(
iadd.set_semantics( iadd.set_semantics(
a << iadd(x, y), a << iadd(x, y),
(Rtl(bvx << prim_to_bv(x), (Rtl(
bvy << prim_to_bv(y), bvx << prim_to_bv(x),
bva << bvadd(bvx, bvy), bvy << prim_to_bv(y),
a << prim_from_bv(bva)), bva << bvadd(bvx, bvy),
[InTypeset(x.get_typevar(), ScalarTS)]), a << prim_from_bv(bva)
Rtl((xlo, xhi) << vsplit(x), ), [InTypeset(x.get_typevar(), ScalarTS)]),
Rtl(
(xlo, xhi) << vsplit(x),
(ylo, yhi) << vsplit(y), (ylo, yhi) << vsplit(y),
alo << iadd(xlo, ylo), alo << iadd(xlo, ylo),
ahi << iadd(xhi, yhi), ahi << iadd(xhi, yhi),
a << vconcat(alo, ahi))) a << vconcat(alo, ahi)
))
#
# Integer arithmetic with carry and/or borrow.
#
# Semantics of iadd_cin: add x, y and the carry-in c_in, described as a
# sequence of primitive bitvector operations.
iadd_cin.set_semantics(
a << iadd_cin(x, y, c_in),
Rtl(
# Convert both addends and the carry-in flag to bitvectors.
bvx << prim_to_bv(x),
bvy << prim_to_bv(y),
bvc_in << prim_to_bv(c_in),
# Zero-extend the carry-in so it can be used as a bvadd operand below
# (presumably to the width of bvx/bvy; the target width is not spelled
# out here -- TODO confirm it is inferred from the bvadd that uses bvs).
bvs << bvzeroext(bvc_in),
# Two-step sum: x + y first, then add the extended carry-in.
bvt << bvadd(bvx, bvy),
bva << bvadd(bvt, bvs),
# Convert the bitvector sum back into the result value a.
a << prim_from_bv(bva)
))
iadd_cout.set_semantics( iadd_cout.set_semantics(
(a, c_out) << iadd_cout(x, y), (a, c_out) << iadd_cout(x, y),
@@ -73,6 +98,20 @@ iadd_cout.set_semantics(
c_out << prim_from_bv(bvc_out) c_out << prim_from_bv(bvc_out)
)) ))
# Semantics of iadd_carry: add x, y and the carry-in c_in, producing the sum a
# and the carry-out c_out, described as primitive bitvector operations.
iadd_carry.set_semantics(
(a, c_out) << iadd_carry(x, y, c_in),
Rtl(
# Convert both addends and the carry-in flag to bitvectors.
bvx << prim_to_bv(x),
bvy << prim_to_bv(y),
bvc_in << prim_to_bv(c_in),
# Zero-extend the carry-in so it can be used as a bvadd operand below
# (presumably to the width of bvx/bvy -- TODO confirm).
bvs << bvzeroext(bvc_in),
# Two-step sum: x + y first, then add the extended carry-in.
bvt << bvadd(bvx, bvy),
bva << bvadd(bvt, bvs),
# Carry-out is computed as "sum is unsigned-less-than x".
# NOTE(review): assuming bvadd wraps modulo 2^width, x + y + c_in can wrap
# around to exactly x (y all ones and c_in = 1). Then bva == bvx and this
# comparison reports no carry although one occurred -- confirm whether the
# carry should instead be derived from both partial sums (bvt and bva).
bvc_out << bvult(bva, bvx),
# Convert the sum and the carry bit back into the result values.
a << prim_from_bv(bva),
c_out << prim_from_bv(bvc_out)
))
bextend.set_semantics( bextend.set_semantics(
a << bextend(x), a << bextend(x),
(Rtl( (Rtl(
@@ -80,10 +119,12 @@ bextend.set_semantics(
bvy << bvzeroext(bvx), bvy << bvzeroext(bvx),
a << prim_from_bv(bvy) a << prim_from_bv(bvy)
), [InTypeset(x.get_typevar(), ScalarTS)]), ), [InTypeset(x.get_typevar(), ScalarTS)]),
Rtl((xlo, xhi) << vsplit(x), Rtl(
(xlo, xhi) << vsplit(x),
alo << bextend(xlo), alo << bextend(xlo),
ahi << bextend(xhi), ahi << bextend(xhi),
a << vconcat(alo, ahi))) a << vconcat(alo, ahi)
))
icmp.set_semantics( icmp.set_semantics(
a << icmp(intcc.ult, x, y), a << icmp(intcc.ult, x, y),
@@ -94,9 +135,47 @@ icmp.set_semantics(
bva_wide << bvzeroext(bva), bva_wide << bvzeroext(bva),
a << prim_from_bv(bva_wide), a << prim_from_bv(bva_wide),
), [InTypeset(x.get_typevar(), ScalarTS)]), ), [InTypeset(x.get_typevar(), ScalarTS)]),
Rtl((xlo, xhi) << vsplit(x), Rtl(
(xlo, xhi) << vsplit(x),
(ylo, yhi) << vsplit(y), (ylo, yhi) << vsplit(y),
alo << icmp(intcc.ult, xlo, ylo), alo << icmp(intcc.ult, xlo, ylo),
ahi << icmp(intcc.ult, xhi, yhi), ahi << icmp(intcc.ult, xhi, yhi),
b << vconcat(alo, ahi), b << vconcat(alo, ahi),
a << bextend(b))) a << bextend(b)
))
#
# Legalization helper instructions.
#
# Semantics of isplit: split x into two results, xlo and xhi.
isplit.set_semantics(
    (xlo, xhi) << isplit(x),
    # Scalar case: convert x to a bitvector, split it with bvsplit and
    # convert both parts back.
    (Rtl(
        bvx << prim_to_bv(x),
        (bvlo, bvhi) << bvsplit(bvx),
        xlo << prim_from_bv(bvlo),
        xhi << prim_from_bv(bvhi)
    ), [InTypeset(x.get_typevar(), ScalarTS)]),
    # Vector case: split the vector into two sub-vectors a and b, apply
    # isplit to each of them, then reassemble the "lo" results and the
    # "hi" results into the two result vectors.
    Rtl(
        (a, b) << vsplit(x),
        (alo, ahi) << isplit(a),
        (blo, bhi) << isplit(b),
        xlo << vconcat(alo, blo),
        # The hi result must pair ahi with bhi, mirroring the lo result
        # above; concatenating bhi with itself would drop the hi parts of
        # sub-vector a and leave ahi unused.
        xhi << vconcat(ahi, bhi)
    ))
# Semantics of iconcat: build x from the two operands xlo and xhi.
iconcat.set_semantics(
x << iconcat(xlo, xhi),
# Scalar case: convert both operands to bitvectors, join them with bvconcat
# (lo operand first, hi operand second) and convert the result back.
(Rtl(
bvlo << prim_to_bv(xlo),
bvhi << prim_to_bv(xhi),
bvx << bvconcat(bvlo, bvhi),
x << prim_from_bv(bvx)
), [InTypeset(x.get_typevar(), ScalarTS)]),
# Vector case: split both operand vectors in two, apply iconcat to the
# matching sub-vectors, then join the two partial results with vconcat.
Rtl(
# Here alo/ahi are the two sub-vectors of xlo, and blo/bhi those of xhi.
(alo, ahi) << vsplit(xlo),
(blo, bhi) << vsplit(xhi),
a << iconcat(alo, blo),
b << iconcat(ahi, bhi),
x << vconcat(a, b),
))

View File

@@ -3,9 +3,10 @@ Tools to emit SMTLIB bitvector queries encoding concrete RTLs containing only
primitive instructions. primitive instructions.
""" """
from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\ from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\
bvult, bvzeroext bvult, bvzeroext, bvsplit, bvconcat
from cdsl.ast import Var from cdsl.ast import Var
from cdsl.types import BVType from cdsl.types import BVType
from .elaborate import elaborate
try: try:
from typing import TYPE_CHECKING, Tuple # noqa from typing import TYPE_CHECKING, Tuple # noqa
@@ -78,7 +79,32 @@ def to_smt(r):
toW = df.get_typevar().singleton_type().width() toW = df.get_typevar().singleton_type().width()
exp = "(= {} ((_ zero_extend {}) {}))"\ exp = "(= {} ((_ zero_extend {}) {}))"\
.format(df, toW-fromW, arg, df) .format(df, toW-fromW, arg)
elif inst == bvsplit:
arg = d.expr.args[0]
arg_typ = arg.get_typevar().singleton_type()
width = arg_typ.width()
assert (width % 2 == 0)
lo = d.defs[0]
hi = d.defs[1]
assert isinstance(arg, Var)
exp = "(and "
exp += "(= {} ((_ extract {} {}) {})) "\
.format(lo, width//2-1, 0, arg)
exp += "(= {} ((_ extract {} {}) {}))"\
.format(hi, width-1, width//2, arg)
exp += ")"
elif inst == bvconcat:
lo = d.expr.args[0]
hi = d.expr.args[1]
assert isinstance(lo, Var) and isinstance(hi, Var)
df = d.defs[0]
# Z3 Concat expects hi bits first, then lo bits
exp = "(= {} (concat {} {}))"\
.format(df, hi, lo)
else: else:
assert False, "Unknown primitive instruction {}".format(inst) assert False, "Unknown primitive instruction {}".format(inst)
@@ -87,57 +113,49 @@ def to_smt(r):
return (q, m) return (q, m)
def equivalent(r1, r2, m): def equivalent(r1, r2, inp_m, out_m):
# type: (Rtl, Rtl, VarMap) -> str # type: (Rtl, Rtl, VarMap, VarMap) -> str
""" """
Given concrete primitive Rtls r1 and r2, and a VarMap m, mapping all Given:
non-primitive vars in r1 onto r2, return a query checking that the - concrete source Rtl r1
two Rtls are semantically equivalent. - concrete dest Rtl r2
- VarMap inp_m mapping r1's non-bitvector inputs to r2
- VarMap out_m mapping r1's non-bitvector outputs to r2
Build a query checking whether r1 and r2 are semantically equivalent.
If the returned query is unsatisfiable, then r1 and r2 are equivalent. If the returned query is unsatisfiable, then r1 and r2 are equivalent.
Otherwise, the satisfying example for the query gives us values Otherwise, the satisfying example for the query gives us values
for which the two Rtls disagree. for which the two Rtls disagree.
""" """
# Rename the vars in r1 and r2 to avoid conflicts # Rename the vars in r1 and r2 with unique suffixes to avoid conflicts
src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()}
dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()}
m = {src_m[k]: dst_m[v] for (k, v) in m.items()}
r1 = r1.copy(src_m) r1 = r1.copy(src_m)
r2 = r2.copy(dst_m) r2 = r2.copy(dst_m)
r1_nonprim_vars = set( # Convert inp_m, out_m in terms of variables with the .a/.b suffixes
[v for v in r1.vars() inp_m = {src_m[k]: dst_m[v] for (k, v) in inp_m.items()}
if not isinstance(v.get_typevar().singleton_type(), BVType)]) out_m = {src_m[k]: dst_m[v] for (k, v) in out_m.items()}
r2_nonprim_vars = set(
[v for v in r2.vars()
if not isinstance(v.get_typevar().singleton_type(), BVType)])
# Check that the map m maps all non real Cretone Vars from r1 onto r2
assert r1_nonprim_vars == set(m.keys())
assert r2_nonprim_vars == set(m.values())
# Encode r1 and r2 as SMT queries
(q1, m1) = to_smt(r1) (q1, m1) = to_smt(r1)
(q2, m2) = to_smt(r2) (q2, m2) = to_smt(r2)
# Build an expression for the equality of real Cretone inputs # Build an expression for the equality of real Cretone inputs of r1 and r2
args_eq_exp = "(and \n" args_eq_exp = "(and \n"
for v in r1.free_vars(): for v in r1.free_vars():
assert v in r1_nonprim_vars assert v in inp_m
args_eq_exp += "(= {} {})\n".format(m1[v], m2[m[v]]) args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]])
args_eq_exp += ")" args_eq_exp += ")"
# Build an expression for the equality of real Cretone defs # Build an expression for the equality of real Cretone outputs of r1 and r2
results_eq_exp = "(and \n" results_eq_exp = "(and \n"
for v in r1.definitions(): for (v1, v2) in out_m.items():
if (v not in r1_nonprim_vars): results_eq_exp += "(= {} {})\n".format(m1[v1], m2[v2])
continue
results_eq_exp += "(= {} {})\n".format(m1[v], m2[m[v]])
results_eq_exp += ")" results_eq_exp += ")"
# Put the whole query together
q = '; Rtl 1 declarations and assertions\n' + q1 q = '; Rtl 1 declarations and assertions\n' + q1
q += '; Rtl 2 declarations and assertions\n' + q2 q += '; Rtl 2 declarations and assertions\n' + q2
@@ -148,3 +166,35 @@ def equivalent(r1, r2, m):
'(assert (not {}))\n'.format(results_eq_exp) '(assert (not {}))\n'.format(results_eq_exp)
return q return q
def xform_correct(x, typing):
    # type: (XForm, VarTyping) -> str
    """
    Build the smtlib query that checks XForm x for one variable typing.

    `typing` is indexed by every variable of x.src and must be permitted by
    x.ti. The source pattern is copied with those types, x is applied to the
    copy to obtain the destination Rtl, and the elaborated forms of both are
    handed to equivalent() together with the input and output variable maps.

    Returns the query string produced by equivalent().
    """
    assert x.ti.permits(typing)

    # Typed copy of the source pattern and the destination derived from it.
    typed_vars = {}
    for v in x.src.vars():
        typed_vars[v] = Var(v.name, typing[v])
    src = x.src.copy(typed_vars)
    dst = x.apply(src)
    dst_vars = x.dst.substitution(dst, {})

    # Sort every source variable into the input map or the output map,
    # keyed by its typed copy and pointing at its destination counterpart.
    inputs = {}
    outputs = {}
    for v in x.src.vars():
        if v.is_input():
            inputs[typed_vars[v]] = dst_vars[v]
            continue
        if v.is_output():
            outputs[typed_vars[v]] = dst_vars[v]
            continue
        assert False, "Haven't decided what to do with intermediates yet"

    # Compare the primitive semantic Rtls of the source and destination.
    return equivalent(elaborate(src), elaborate(dst), inputs, outputs)