From 1bbe6440801721bbef558fb266e8a59fe4003df7 Mon Sep 17 00:00:00 2001 From: Dimo Date: Thu, 27 Jul 2017 17:16:26 -0700 Subject: [PATCH] Add semantics for several more iadd with carry; Add xform_correct() and doc cleanup --- lib/cretonne/meta/base/semantics.py | 103 +++++++++++++++++++++--- lib/cretonne/meta/semantics/smtlib.py | 110 +++++++++++++++++++------- 2 files changed, 171 insertions(+), 42 deletions(-) diff --git a/lib/cretonne/meta/base/semantics.py b/lib/cretonne/meta/base/semantics.py index cdf071f7d9..582fdc0889 100644 --- a/lib/cretonne/meta/base/semantics.py +++ b/lib/cretonne/meta/base/semantics.py @@ -1,7 +1,8 @@ from __future__ import absolute_import from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\ bvadd, bvult, bvzeroext -from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend +from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \ + isplit, iconcat, iadd_cin, iadd_carry from .immediates import intcc from cdsl.xform import Rtl from cdsl.ast import Var @@ -13,18 +14,24 @@ y = Var('y') a = Var('a') b = Var('b') c_out = Var('c_out') +c_in = Var('c_in') bvc_out = Var('bvc_out') +bvc_in = Var('bvc_in') xhi = Var('xhi') yhi = Var('yhi') ahi = Var('ahi') +bhi = Var('bhi') xlo = Var('xlo') ylo = Var('ylo') alo = Var('alo') +blo = Var('blo') lo = Var('lo') hi = Var('hi') bvx = Var('bvx') bvy = Var('bvy') bva = Var('bva') +bvt = Var('bvt') +bvs = Var('bvs') bva_wide = Var('bva_wide') bvlo = Var('bvlo') bvhi = Var('bvhi') @@ -51,16 +58,34 @@ vconcat.set_semantics( iadd.set_semantics( a << iadd(x, y), - (Rtl(bvx << prim_to_bv(x), - bvy << prim_to_bv(y), - bva << bvadd(bvx, bvy), - a << prim_from_bv(bva)), - [InTypeset(x.get_typevar(), ScalarTS)]), - Rtl((xlo, xhi) << vsplit(x), + (Rtl( + bvx << prim_to_bv(x), + bvy << prim_to_bv(y), + bva << bvadd(bvx, bvy), + a << prim_from_bv(bva) + ), [InTypeset(x.get_typevar(), ScalarTS)]), + Rtl( + (xlo, xhi) << vsplit(x), (ylo, yhi) << vsplit(y), alo << iadd(xlo, ylo), ahi << iadd(xhi, yhi), - a << vconcat(alo, ahi))) + a << vconcat(alo, ahi) + )) + +# +# Integer arithmetic with carry and/or borrow. +# +iadd_cin.set_semantics( + a << iadd_cin(x, y, c_in), + Rtl( + bvx << prim_to_bv(x), + bvy << prim_to_bv(y), + bvc_in << prim_to_bv(c_in), + bvs << bvzeroext(bvc_in), + bvt << bvadd(bvx, bvy), + bva << bvadd(bvt, bvs), + a << prim_from_bv(bva) + )) iadd_cout.set_semantics( (a, c_out) << iadd_cout(x, y), @@ -73,6 +98,20 @@ iadd_cout.set_semantics( c_out << prim_from_bv(bvc_out) )) +iadd_carry.set_semantics( + (a, c_out) << iadd_carry(x, y, c_in), + Rtl( + bvx << prim_to_bv(x), + bvy << prim_to_bv(y), + bvc_in << prim_to_bv(c_in), + bvs << bvzeroext(bvc_in), + bvt << bvadd(bvx, bvy), + bva << bvadd(bvt, bvs), + bvc_out << bvult(bva, bvx), + a << prim_from_bv(bva), + c_out << prim_from_bv(bvc_out) + )) + bextend.set_semantics( a << bextend(x), (Rtl( @@ -80,10 +119,12 @@ bextend.set_semantics( bvy << bvzeroext(bvx), a << prim_from_bv(bvy) ), [InTypeset(x.get_typevar(), ScalarTS)]), - Rtl((xlo, xhi) << vsplit(x), + Rtl( + (xlo, xhi) << vsplit(x), alo << bextend(xlo), ahi << bextend(xhi), - a << vconcat(alo, ahi))) + a << vconcat(alo, ahi) + )) icmp.set_semantics( a << icmp(intcc.ult, x, y), @@ -94,9 +135,47 @@ icmp.set_semantics( bva_wide << bvzeroext(bva), a << prim_from_bv(bva_wide), ), [InTypeset(x.get_typevar(), ScalarTS)]), - Rtl((xlo, xhi) << vsplit(x), + Rtl( + (xlo, xhi) << vsplit(x), (ylo, yhi) << vsplit(y), alo << icmp(intcc.ult, xlo, ylo), ahi << icmp(intcc.ult, xhi, yhi), b << vconcat(alo, ahi), - a << bextend(b))) + a << bextend(b) + )) + +# +# Legalization helper instructions. +# + +isplit.set_semantics( + (xlo, xhi) << isplit(x), + (Rtl( + bvx << prim_to_bv(x), + (bvlo, bvhi) << bvsplit(bvx), + xlo << prim_from_bv(bvlo), + xhi << prim_from_bv(bvhi) + ), [InTypeset(x.get_typevar(), ScalarTS)]), + Rtl( + (a, b) << vsplit(x), + (alo, ahi) << isplit(a), + (blo, bhi) << isplit(b), + xlo << vconcat(alo, blo), + xhi << vconcat(bhi, bhi) + )) + +iconcat.set_semantics( + x << iconcat(xlo, xhi), + (Rtl( + bvlo << prim_to_bv(xlo), + bvhi << prim_to_bv(xhi), + bvx << bvconcat(bvlo, bvhi), + x << prim_from_bv(bvx) + ), [InTypeset(x.get_typevar(), ScalarTS)]), + Rtl( + (alo, ahi) << vsplit(xlo), + (blo, bhi) << vsplit(xhi), + a << iconcat(alo, blo), + b << iconcat(ahi, bhi), + x << vconcat(a, b), + )) diff --git a/lib/cretonne/meta/semantics/smtlib.py b/lib/cretonne/meta/semantics/smtlib.py index 53dac682e5..2bf94515ab 100644 --- a/lib/cretonne/meta/semantics/smtlib.py +++ b/lib/cretonne/meta/semantics/smtlib.py @@ -3,9 +3,10 @@ Tools to emit SMTLIB bitvector queries encoding concrete RTLs containing only primitive instructions. """ from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\ - bvult, bvzeroext + bvult, bvzeroext, bvsplit, bvconcat from cdsl.ast import Var from cdsl.types import BVType +from .elaborate import elaborate try: from typing import TYPE_CHECKING, Tuple # noqa @@ -78,7 +79,32 @@ def to_smt(r): toW = df.get_typevar().singleton_type().width() exp = "(= {} ((_ zero_extend {}) {}))"\ - .format(df, toW-fromW, arg, df) + .format(df, toW-fromW, arg) + elif inst == bvsplit: + arg = d.expr.args[0] + arg_typ = arg.get_typevar().singleton_type() + width = arg_typ.width() + assert (width % 2 == 0) + + lo = d.defs[0] + hi = d.defs[1] + assert isinstance(arg, Var) + + exp = "(and " + exp += "(= {} ((_ extract {} {}) {})) "\ + .format(lo, width//2-1, 0, arg) + exp += "(= {} ((_ extract {} {}) {}))"\ + .format(hi, width-1, width//2, arg) + exp += ")" + elif inst == bvconcat: + lo = d.expr.args[0] + hi = d.expr.args[1] + assert isinstance(lo, Var) and isinstance(hi, Var) + df = d.defs[0] + + # Z3 Concat expects hi bits first, then lo bits + exp = "(= {} (concat {} {}))"\ + .format(df, hi, lo) else: assert False, "Unknown primitive instruction {}".format(inst) @@ -87,57 +113,49 @@ def to_smt(r): return (q, m) -def equivalent(r1, r2, m): - # type: (Rtl, Rtl, VarMap) -> str +def equivalent(r1, r2, inp_m, out_m): + # type: (Rtl, Rtl, VarMap, VarMap) -> str """ - Given concrete primitive Rtls r1 and r2, and a VarMap m, mapping all - non-primitive vars in r1 onto r2, return a query checking that the - two Rtls are semantically equivalent. + Given: + - concrete source Rtl r1 + - concrete dest Rtl r2 + - VarMap inp_m mapping r1's non-bitvector inputs to r2 + - VarMap out_m mapping r1's non-bitvector outputs to r2 + Build a query checking whether r1 and r2 are semantically equivalent. If the returned query is unsatisfiable, then r1 and r2 are equivalent. Otherwise, the satisfying example for the query gives us values for which the two Rtls disagree. """ - # Rename the vars in r1 and r2 to avoid conflicts + # Rename the vars in r1 and r2 with unique suffixes to avoid conflicts src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} - m = {src_m[k]: dst_m[v] for (k, v) in m.items()} - r1 = r1.copy(src_m) r2 = r2.copy(dst_m) - r1_nonprim_vars = set( - [v for v in r1.vars() - if not isinstance(v.get_typevar().singleton_type(), BVType)]) - - r2_nonprim_vars = set( - [v for v in r2.vars() - if not isinstance(v.get_typevar().singleton_type(), BVType)]) - - # Check that the map m maps all non real Cretone Vars from r1 onto r2 - assert r1_nonprim_vars == set(m.keys()) - assert r2_nonprim_vars == set(m.values()) + # Convert inp_m, out_m in terms of variables with the .a/.b suffixes + inp_m = {src_m[k]: dst_m[v] for (k, v) in inp_m.items()} + out_m = {src_m[k]: dst_m[v] for (k, v) in out_m.items()} + # Encode r1 and r2 as SMT queries (q1, m1) = to_smt(r1) (q2, m2) = to_smt(r2) - # Build an expression for the equality of real Cretone inputs + # Build an expression for the equality of real Cretone inputs of r1 and r2 args_eq_exp = "(and \n" for v in r1.free_vars(): - assert v in r1_nonprim_vars - args_eq_exp += "(= {} {})\n".format(m1[v], m2[m[v]]) + assert v in inp_m + args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]]) args_eq_exp += ")" - # Build an expression for the equality of real Cretone defs + # Build an expression for the equality of real Cretone outputs of r1 and r2 results_eq_exp = "(and \n" - for v in r1.definitions(): - if (v not in r1_nonprim_vars): - continue - - results_eq_exp += "(= {} {})\n".format(m1[v], m2[m[v]]) + for (v1, v2) in out_m.items(): + results_eq_exp += "(= {} {})\n".format(m1[v1], m2[v2]) results_eq_exp += ")" + # Put the whole query toghether q = '; Rtl 1 declarations and assertions\n' + q1 q += '; Rtl 2 declarations and assertions\n' + q2 @@ -148,3 +166,35 @@ def equivalent(r1, r2, m): '(assert (not {}))\n'.format(results_eq_exp) return q + + +def xform_correct(x, typing): + # type: (XForm, VarTyping) -> str + """ + Given an XForm x and a concrete variable typing for x typing, build the + smtlib query asserting that x is correct for the given typing. + """ + assert x.ti.permits(typing) + + # Create copies of the x.src and x.dst with the concrete types in typing. + src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()} + src = x.src.copy(src_m) + dst = x.apply(src) + dst_m = x.dst.substitution(dst, {}) + + # Build maps for the inputs/outputs for src->dst + inp_m = {} + out_m = {} + + for v in x.src.vars(): + if v.is_input(): + inp_m[src_m[v]] = dst_m[v] + elif v.is_output(): + out_m[src_m[v]] = dst_m[v] + else: + assert False, "Haven't decided what to do with intermediates yet" + + # Get the primitive semantic Rtls for src and dst + prim_src = elaborate(src) + prim_dst = elaborate(dst) + return equivalent(prim_src, prim_dst, inp_m, out_m)