Cleanup, typechecking and documentation nits
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
9767654dd7
commit
a324d60ccc
@@ -10,8 +10,9 @@ from .elaborate import elaborate
|
||||
|
||||
try:
|
||||
from typing import TYPE_CHECKING, Tuple # noqa
|
||||
from cdsl.xform import Rtl # noqa
|
||||
from cdsl.xform import Rtl, XForm # noqa
|
||||
from cdsl.ast import VarMap # noqa
|
||||
from cdsl.ti import VarTyping # noqa
|
||||
except ImportError:
|
||||
TYPE_CHECKING = False
|
||||
|
||||
@@ -34,10 +35,12 @@ def to_smt(r):
|
||||
assert r.is_concrete()
|
||||
# Should contain only primitives
|
||||
primitives = set(PRIMITIVES.instructions)
|
||||
assert all(d.expr.inst in primitives for d in r.rtl)
|
||||
assert set(d.expr.inst for d in r.rtl).issubset(primitives)
|
||||
|
||||
q = ""
|
||||
m = {} # type: VarMap
|
||||
|
||||
# Build declarations for any bitvector Vars
|
||||
for v in r.vars():
|
||||
typ = v.get_typevar().singleton_type()
|
||||
if not isinstance(typ, BVType):
|
||||
@@ -45,9 +48,11 @@ def to_smt(r):
|
||||
|
||||
q += "(declare-fun {} () {})\n".format(v.name, bvtype_to_sort(typ))
|
||||
|
||||
# Encode each instruction as a equality assertion
|
||||
for d in r.rtl:
|
||||
inst = d.expr.inst
|
||||
|
||||
# For prim_to_bv/prim_from_bv just update var_m. No assertion needed
|
||||
if inst == prim_to_bv:
|
||||
assert isinstance(d.expr.args[0], Var)
|
||||
m[d.expr.args[0]] = d.defs[0]
|
||||
@@ -82,13 +87,13 @@ def to_smt(r):
|
||||
.format(df, toW-fromW, arg)
|
||||
elif inst == bvsplit:
|
||||
arg = d.expr.args[0]
|
||||
assert isinstance(arg, Var)
|
||||
arg_typ = arg.get_typevar().singleton_type()
|
||||
width = arg_typ.width()
|
||||
assert (width % 2 == 0)
|
||||
|
||||
lo = d.defs[0]
|
||||
hi = d.defs[1]
|
||||
assert isinstance(arg, Var)
|
||||
|
||||
exp = "(and "
|
||||
exp += "(= {} ((_ extract {} {}) {})) "\
|
||||
@@ -97,9 +102,10 @@ def to_smt(r):
|
||||
.format(hi, width-1, width//2, arg)
|
||||
exp += ")"
|
||||
elif inst == bvconcat:
|
||||
assert isinstance(d.expr.args[0], Var) and \
|
||||
isinstance(d.expr.args[1], Var)
|
||||
lo = d.expr.args[0]
|
||||
hi = d.expr.args[1]
|
||||
assert isinstance(lo, Var) and isinstance(hi, Var)
|
||||
df = d.defs[0]
|
||||
|
||||
# Z3 Concat expects hi bits first, then lo bits
|
||||
@@ -127,6 +133,14 @@ def equivalent(r1, r2, inp_m, out_m):
|
||||
Otherwise, the satisfying example for the query gives us values
|
||||
for which the two Rtls disagree.
|
||||
"""
|
||||
# Sanity - inp_m is a bijection from the set of inputs of r1 to the set of
|
||||
# inputs of r2
|
||||
assert set(r1.free_vars()) == set(inp_m.keys())
|
||||
assert set(r2.free_vars()) == set(inp_m.values())
|
||||
|
||||
# Note that the same rule is not expected to hold for out_m due to
|
||||
# temporaries/intermediates.
|
||||
|
||||
# Rename the vars in r1 and r2 with unique suffixes to avoid conflicts
|
||||
src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()}
|
||||
dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()}
|
||||
@@ -145,7 +159,6 @@ def equivalent(r1, r2, inp_m, out_m):
|
||||
args_eq_exp = "(and \n"
|
||||
|
||||
for v in r1.free_vars():
|
||||
assert v in inp_m
|
||||
args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]])
|
||||
args_eq_exp += ")"
|
||||
|
||||
@@ -171,12 +184,12 @@ def equivalent(r1, r2, inp_m, out_m):
|
||||
def xform_correct(x, typing):
|
||||
# type: (XForm, VarTyping) -> str
|
||||
"""
|
||||
Given an XForm x and a concrete variable typing for x typing, build the
|
||||
smtlib query asserting that x is correct for the given typing.
|
||||
Given an XForm x and a concrete variable typing for x build the smtlib
|
||||
query asserting that x is correct for the given typing.
|
||||
"""
|
||||
assert x.ti.permits(typing)
|
||||
|
||||
# Create copies of the x.src and x.dst with the concrete types in typing.
|
||||
# Create copies of the x.src and x.dst with their concrete types
|
||||
src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()}
|
||||
src = x.src.copy(src_m)
|
||||
dst = x.apply(src)
|
||||
@@ -191,8 +204,6 @@ def xform_correct(x, typing):
|
||||
inp_m[src_m[v]] = dst_m[v]
|
||||
elif v.is_output():
|
||||
out_m[src_m[v]] = dst_m[v]
|
||||
else:
|
||||
assert False, "Haven't decided what to do with intermediates yet"
|
||||
|
||||
# Get the primitive semantic Rtls for src and dst
|
||||
prim_src = elaborate(src)
|
||||
|
||||
Reference in New Issue
Block a user