Cleanup, typechecking and documentation nits

This commit is contained in:
Dimo
2017-07-27 18:05:38 -07:00
committed by Jakob Stoklund Olesen
parent 9767654dd7
commit a324d60ccc
3 changed files with 43 additions and 23 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import absolute_import
from .ast import Def, Var, Apply from .ast import Def, Var, Apply
from .ti import ti_xform, TypeEnv, get_type_env from .ti import ti_xform, TypeEnv, get_type_env
from functools import reduce from functools import reduce
from .typevar import TypeVar
try: try:
from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa
@@ -13,6 +12,7 @@ try:
from .ast import Expr, VarMap # noqa from .ast import Expr, VarMap # noqa
from .isa import TargetISA # noqa from .isa import TargetISA # noqa
from .ti import TypeConstraint # noqa from .ti import TypeConstraint # noqa
from .typevar import TypeVar # noqa
DefApply = Union[Def, Apply] DefApply = Union[Def, Apply]
except ImportError: except ImportError:
pass pass
@@ -70,12 +70,17 @@ class Rtl(object):
def free_vars(self): def free_vars(self):
# type: () -> Set[Var] # type: () -> Set[Var]
""" Return the set of free Vars used in self""" """Return the set of free Vars corresp. to SSA vals used in self"""
def flow_f(s, d): def flow_f(s, d):
# type: (Set[Var], Def) -> Set[Var] # type: (Set[Var], Def) -> Set[Var]
"""Compute the change in the set of free vars across a Def""" """Compute the change in the set of free vars across a Def"""
s = s.difference(set(d.defs)) s = s.difference(set(d.defs))
return s.union(set(a for a in d.expr.args if isinstance(a, Var))) uses = set(d.expr.args[i] for i in d.expr.inst.value_opnums)
for v in uses:
assert isinstance(v, Var)
s.add(v)
return s
return reduce(flow_f, reversed(self.rtl), set([])) return reduce(flow_f, reversed(self.rtl), set([]))
@@ -107,8 +112,9 @@ class Rtl(object):
# type: (Rtl) -> None # type: (Rtl) -> None
""" """
Given that there is only 1 possible concrete typing T for self, assign Given that there is only 1 possible concrete typing T for self, assign
a singleton TV with the single type t=T[v] for each Var v \in self. a singleton TV with type t=T[v] for each Var v \in self. Its an error
Its an error to call this on an Rtl with more than 1 possible typing. to call this on an Rtl with more than 1 possible typing. This modifies
the Rtl in-place.
""" """
from .ti import ti_rtl, TypeEnv from .ti import ti_rtl, TypeEnv
# 1) Infer the types of all vars in res # 1) Infer the types of all vars in res
@@ -123,10 +129,8 @@ class Rtl(object):
# 3) Assign the only possible type to each variable. # 3) Assign the only possible type to each variable.
for v in typenv.vars: for v in typenv.vars:
if v.get_typevar().singleton_type() is not None: assert typing[v].singleton_type() is not None
continue v.set_typevar(typing[v])
v.set_typevar(TypeVar.singleton(typing[v].singleton_type()))
class XForm(object): class XForm(object):

View File

@@ -44,15 +44,20 @@ def find_matching_xform(d):
def cleanup_semantics(r, outputs): def cleanup_semantics(r, outputs):
# type: (Rtl, Set[Var]) -> Rtl # type: (Rtl, Set[Var]) -> Rtl
""" """
The elaboration process creates a lot of redundant instruction pairs of the The elaboration process creates a lot of redundant prim_to_bv conversions.
shape: Cleanup the following cases:
1) prim_to_bv/prim_from_bv pair:
a.0 << prim_from_bv(bva.0) a.0 << prim_from_bv(bva.0)
... ...
bva.1 << prim_to_bv(a.0) bva.1 << prim_to_bv(a.0) <-- redundant, replace by bva.0
... ...
Contract these to ease manual inspection. 2) prim_to_bv/prim_to-bv pair:
bva.0 << prim_to_bv(a)
...
bva.1 << prim_to_bv(a) <-- redundant, replace by bva.0
...
""" """
new_defs = [] # type: List[Def] new_defs = [] # type: List[Def]
subst_m = {v: v for v in r.vars()} # type: VarMap subst_m = {v: v for v in r.vars()} # type: VarMap

View File

@@ -10,8 +10,9 @@ from .elaborate import elaborate
try: try:
from typing import TYPE_CHECKING, Tuple # noqa from typing import TYPE_CHECKING, Tuple # noqa
from cdsl.xform import Rtl # noqa from cdsl.xform import Rtl, XForm # noqa
from cdsl.ast import VarMap # noqa from cdsl.ast import VarMap # noqa
from cdsl.ti import VarTyping # noqa
except ImportError: except ImportError:
TYPE_CHECKING = False TYPE_CHECKING = False
@@ -34,10 +35,12 @@ def to_smt(r):
assert r.is_concrete() assert r.is_concrete()
# Should contain only primitives # Should contain only primitives
primitives = set(PRIMITIVES.instructions) primitives = set(PRIMITIVES.instructions)
assert all(d.expr.inst in primitives for d in r.rtl) assert set(d.expr.inst for d in r.rtl).issubset(primitives)
q = "" q = ""
m = {} # type: VarMap m = {} # type: VarMap
# Build declarations for any bitvector Vars
for v in r.vars(): for v in r.vars():
typ = v.get_typevar().singleton_type() typ = v.get_typevar().singleton_type()
if not isinstance(typ, BVType): if not isinstance(typ, BVType):
@@ -45,9 +48,11 @@ def to_smt(r):
q += "(declare-fun {} () {})\n".format(v.name, bvtype_to_sort(typ)) q += "(declare-fun {} () {})\n".format(v.name, bvtype_to_sort(typ))
# Encode each instruction as a equality assertion
for d in r.rtl: for d in r.rtl:
inst = d.expr.inst inst = d.expr.inst
# For prim_to_bv/prim_from_bv just update var_m. No assertion needed
if inst == prim_to_bv: if inst == prim_to_bv:
assert isinstance(d.expr.args[0], Var) assert isinstance(d.expr.args[0], Var)
m[d.expr.args[0]] = d.defs[0] m[d.expr.args[0]] = d.defs[0]
@@ -82,13 +87,13 @@ def to_smt(r):
.format(df, toW-fromW, arg) .format(df, toW-fromW, arg)
elif inst == bvsplit: elif inst == bvsplit:
arg = d.expr.args[0] arg = d.expr.args[0]
assert isinstance(arg, Var)
arg_typ = arg.get_typevar().singleton_type() arg_typ = arg.get_typevar().singleton_type()
width = arg_typ.width() width = arg_typ.width()
assert (width % 2 == 0) assert (width % 2 == 0)
lo = d.defs[0] lo = d.defs[0]
hi = d.defs[1] hi = d.defs[1]
assert isinstance(arg, Var)
exp = "(and " exp = "(and "
exp += "(= {} ((_ extract {} {}) {})) "\ exp += "(= {} ((_ extract {} {}) {})) "\
@@ -97,9 +102,10 @@ def to_smt(r):
.format(hi, width-1, width//2, arg) .format(hi, width-1, width//2, arg)
exp += ")" exp += ")"
elif inst == bvconcat: elif inst == bvconcat:
assert isinstance(d.expr.args[0], Var) and \
isinstance(d.expr.args[1], Var)
lo = d.expr.args[0] lo = d.expr.args[0]
hi = d.expr.args[1] hi = d.expr.args[1]
assert isinstance(lo, Var) and isinstance(hi, Var)
df = d.defs[0] df = d.defs[0]
# Z3 Concat expects hi bits first, then lo bits # Z3 Concat expects hi bits first, then lo bits
@@ -127,6 +133,14 @@ def equivalent(r1, r2, inp_m, out_m):
Otherwise, the satisfying example for the query gives us values Otherwise, the satisfying example for the query gives us values
for which the two Rtls disagree. for which the two Rtls disagree.
""" """
# Sanity - inp_m is a bijection from the set of inputs of r1 to the set of
# inputs of r2
assert set(r1.free_vars()) == set(inp_m.keys())
assert set(r2.free_vars()) == set(inp_m.values())
# Note that the same rule is not expected to hold for out_m due to
# temporaries/intermediates.
# Rename the vars in r1 and r2 with unique suffixes to avoid conflicts # Rename the vars in r1 and r2 with unique suffixes to avoid conflicts
src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()}
dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()}
@@ -145,7 +159,6 @@ def equivalent(r1, r2, inp_m, out_m):
args_eq_exp = "(and \n" args_eq_exp = "(and \n"
for v in r1.free_vars(): for v in r1.free_vars():
assert v in inp_m
args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]]) args_eq_exp += "(= {} {})\n".format(m1[v], m2[inp_m[v]])
args_eq_exp += ")" args_eq_exp += ")"
@@ -171,12 +184,12 @@ def equivalent(r1, r2, inp_m, out_m):
def xform_correct(x, typing): def xform_correct(x, typing):
# type: (XForm, VarTyping) -> str # type: (XForm, VarTyping) -> str
""" """
Given an XForm x and a concrete variable typing for x typing, build the Given an XForm x and a concrete variable typing for x build the smtlib
smtlib query asserting that x is correct for the given typing. query asserting that x is correct for the given typing.
""" """
assert x.ti.permits(typing) assert x.ti.permits(typing)
# Create copies of the x.src and x.dst with the concrete types in typing. # Create copies of the x.src and x.dst with their concrete types
src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()} src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()}
src = x.src.copy(src_m) src = x.src.copy(src_m)
dst = x.apply(src) dst = x.apply(src)
@@ -191,8 +204,6 @@ def xform_correct(x, typing):
inp_m[src_m[v]] = dst_m[v] inp_m[src_m[v]] = dst_m[v]
elif v.is_output(): elif v.is_output():
out_m[src_m[v]] = dst_m[v] out_m[src_m[v]] = dst_m[v]
else:
assert False, "Haven't decided what to do with intermediates yet"
# Get the primitive semantic Rtls for src and dst # Get the primitive semantic Rtls for src and dst
prim_src = elaborate(src) prim_src = elaborate(src)