* Add Atom and Literal base classes to CDSL Ast. Change substitution() and copy() on Def/Apply/Rtl to support substituting Var->Union[Var, Literal]. Check in Apply() constructor kinds of passed in Literals respect instruction signature

* Change verify_semantics to check all possible instantiations of enumerated immediates (needed to descrive icmp). Add all bitvector comparison primitives and bvite; Change set_semantics to optionally accept XForms; Add semantics for icmp; Fix typing errors in semantics/{smtlib, elaborate, __init__}.py after the change of VarMap->VarAtomMap

* Forgot macros.py

* Nit obscured by testing with mypy enabled present.

* Typo
This commit is contained in:
d1m0
2017-08-14 20:19:47 -07:00
committed by Jakob Stoklund Olesen
parent 591f6c1632
commit 66da171050
13 changed files with 415 additions and 137 deletions

View File

@@ -1,20 +1,33 @@
from __future__ import absolute_import from __future__ import absolute_import
from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\ from semantics.primitives import prim_to_bv, prim_from_bv, bvsplit, bvconcat,\
bvadd, bvult, bvzeroext, bvsignext bvadd, bvzeroext, bvsignext
from semantics.primitives import bveq, bvne, bvsge, bvsgt, bvsle, bvslt,\
bvuge, bvugt, bvule, bvult
from semantics.macros import bool2bv
from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \ from .instructions import vsplit, vconcat, iadd, iadd_cout, icmp, bextend, \
isplit, iconcat, iadd_cin, iadd_carry isplit, iconcat, iadd_cin, iadd_carry
from .immediates import intcc from .immediates import intcc
from cdsl.xform import Rtl from cdsl.xform import Rtl, XForm
from cdsl.ast import Var from cdsl.ast import Var
from cdsl.typevar import TypeSet from cdsl.typevar import TypeSet
from cdsl.ti import InTypeset from cdsl.ti import InTypeset
try:
from typing import TYPE_CHECKING # noqa
if TYPE_CHECKING:
from cdsl.ast import Enumerator # noqa
from cdsl.instructions import Instruction # noqa
except ImportError:
TYPE_CHECKING = False
x = Var('x') x = Var('x')
y = Var('y') y = Var('y')
a = Var('a') a = Var('a')
b = Var('b') b = Var('b')
c_out = Var('c_out') c_out = Var('c_out')
c_in = Var('c_in') c_in = Var('c_in')
CC = Var('CC')
bc_out = Var('bc_out')
bvc_out = Var('bvc_out') bvc_out = Var('bvc_out')
bvc_in = Var('bvc_in') bvc_in = Var('bvc_in')
xhi = Var('xhi') xhi = Var('xhi')
@@ -93,7 +106,8 @@ iadd_cout.set_semantics(
bvx << prim_to_bv(x), bvx << prim_to_bv(x),
bvy << prim_to_bv(y), bvy << prim_to_bv(y),
bva << bvadd(bvx, bvy), bva << bvadd(bvx, bvy),
bvc_out << bvult(bva, bvx), bc_out << bvult(bva, bvx),
bvc_out << bool2bv(bc_out),
a << prim_from_bv(bva), a << prim_from_bv(bva),
c_out << prim_from_bv(bvc_out) c_out << prim_from_bv(bvc_out)
)) ))
@@ -107,7 +121,8 @@ iadd_carry.set_semantics(
bvs << bvzeroext(bvc_in), bvs << bvzeroext(bvc_in),
bvt << bvadd(bvx, bvy), bvt << bvadd(bvx, bvy),
bva << bvadd(bvt, bvs), bva << bvadd(bvt, bvs),
bvc_out << bvult(bva, bvx), bc_out << bvult(bva, bvx),
bvc_out << bool2bv(bc_out),
a << prim_from_bv(bva), a << prim_from_bv(bva),
c_out << prim_from_bv(bvc_out) c_out << prim_from_bv(bvc_out)
)) ))
@@ -126,23 +141,45 @@ bextend.set_semantics(
a << vconcat(alo, ahi) a << vconcat(alo, ahi)
)) ))
icmp.set_semantics(
a << icmp(intcc.ult, x, y), def create_comp_xform(cc, bvcmp_func):
(Rtl( # type: (Enumerator, Instruction) -> XForm
ba = Var('ba')
return XForm(
Rtl(
a << icmp(cc, x, y)
),
Rtl(
bvx << prim_to_bv(x), bvx << prim_to_bv(x),
bvy << prim_to_bv(y), bvy << prim_to_bv(y),
bva << bvult(bvx, bvy), ba << bvcmp_func(bvx, bvy),
bva << bool2bv(ba),
bva_wide << bvzeroext(bva), bva_wide << bvzeroext(bva),
a << prim_from_bv(bva_wide), a << prim_from_bv(bva_wide),
), [InTypeset(x.get_typevar(), ScalarTS)]), ),
constraints=InTypeset(x.get_typevar(), ScalarTS))
icmp.set_semantics(
a << icmp(CC, x, y),
Rtl( Rtl(
(xlo, xhi) << vsplit(x), (xlo, xhi) << vsplit(x),
(ylo, yhi) << vsplit(y), (ylo, yhi) << vsplit(y),
alo << icmp(intcc.ult, xlo, ylo), alo << icmp(CC, xlo, ylo),
ahi << icmp(intcc.ult, xhi, yhi), ahi << icmp(CC, xhi, yhi),
b << vconcat(alo, ahi), b << vconcat(alo, ahi),
a << bextend(b) a << bextend(b)
)) ),
create_comp_xform(intcc.eq, bveq),
create_comp_xform(intcc.ne, bvne),
create_comp_xform(intcc.sge, bvsge),
create_comp_xform(intcc.sgt, bvsgt),
create_comp_xform(intcc.sle, bvsle),
create_comp_xform(intcc.slt, bvslt),
create_comp_xform(intcc.uge, bvuge),
create_comp_xform(intcc.ugt, bvugt),
create_comp_xform(intcc.ule, bvule),
create_comp_xform(intcc.ult, bvult))
# #
# Legalization helper instructions. # Legalization helper instructions.

View File

@@ -11,23 +11,23 @@ from .predicates import IsEqual, And, TypePredicate
try: try:
from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa from typing import Union, Tuple, Sequence, TYPE_CHECKING, Dict, List # noqa
from typing import Optional, Set # noqa from typing import Optional, Set, Any # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
from .operands import ImmediateKind # noqa from .operands import ImmediateKind # noqa
from .predicates import PredNode # noqa from .predicates import PredNode # noqa
VarMap = Dict["Var", "Var"] VarAtomMap = Dict["Var", "Atom"]
except ImportError: except ImportError:
pass pass
def replace_var(arg, m): def replace_var(arg, m):
# type: (Expr, VarMap) -> Expr # type: (Expr, VarAtomMap) -> Expr
""" """
Given a var v return either m[v] or a new variable v' (and remember Given a var v return either m[v] or a new variable v' (and remember
m[v]=v'). Otherwise return the argument unchanged m[v]=v'). Otherwise return the argument unchanged
""" """
if isinstance(arg, Var): if isinstance(arg, Var):
new_arg = m.get(arg, Var(arg.name)) # type: Var new_arg = m.get(arg, Var(arg.name)) # type: Atom
m[arg] = new_arg m[arg] = new_arg
return new_arg return new_arg
return arg return arg
@@ -76,7 +76,7 @@ class Def(object):
', '.join(map(str, self.defs)), self.expr) ', '.join(map(str, self.defs)), self.expr)
def copy(self, m): def copy(self, m):
# type: (VarMap) -> Def # type: (VarAtomMap) -> Def
""" """
Return a copy of this Def with vars replaced with fresh variables, Return a copy of this Def with vars replaced with fresh variables,
in accordance with the map m. Update m as neccessary. in accordance with the map m. Update m as neccessary.
@@ -106,7 +106,7 @@ class Def(object):
return self.definitions().union(self.uses()) return self.definitions().union(self.uses())
def substitution(self, other, s): def substitution(self, other, s):
# type: (Def, VarMap) -> Optional[VarMap] # type: (Def, VarAtomMap) -> Optional[VarAtomMap]
""" """
If the Defs self and other agree structurally, return a variable If the Defs self and other agree structurally, return a variable
substitution to transform self to other. Otherwise return None. Two substitution to transform self to other. Otherwise return None. Two
@@ -133,7 +133,13 @@ class Expr(object):
""" """
class Var(Expr): class Atom(Expr):
"""
An Atom in the DSL is either a literal or a Var
"""
class Var(Atom):
""" """
A free variable. A free variable.
@@ -304,6 +310,16 @@ class Apply(Expr):
self.args = args self.args = args
assert len(self.inst.ins) == len(args) assert len(self.inst.ins) == len(args)
# Check that the kinds of Literals arguments match the expected Operand
for op_idx in self.inst.imm_opnums:
arg = self.args[op_idx]
op = self.inst.ins[op_idx]
if isinstance(arg, Literal):
assert arg.kind == op.kind, \
"Passing literal {} to field of wrong kind {}."\
.format(arg, op.kind)
def __rlshift__(self, other): def __rlshift__(self, other):
# type: (Union[Var, Tuple[Var, ...]]) -> Def # type: (Union[Var, Tuple[Var, ...]]) -> Def
""" """
@@ -377,7 +393,7 @@ class Apply(Expr):
return pred return pred
def copy(self, m): def copy(self, m):
# type: (VarMap) -> Apply # type: (VarAtomMap) -> Apply
""" """
Return a copy of this Expr with vars replaced with fresh variables, Return a copy of this Expr with vars replaced with fresh variables,
in accordance with the map m. Update m as neccessary. in accordance with the map m. Update m as neccessary.
@@ -396,15 +412,12 @@ class Apply(Expr):
return res return res
def substitution(self, other, s): def substitution(self, other, s):
# type: (Apply, VarMap) -> Optional[VarMap] # type: (Apply, VarAtomMap) -> Optional[VarAtomMap]
""" """
If the application self and other agree structurally, return a variable If there is a substituion from Var->Atom that converts self to other,
substitution to transform self to other. Otherwise return None. Two return it, otherwise return None. Note that this is strictly weaker
applications agree structurally if: than unification (see TestXForm.test_subst_enum_bad_var_const for
1) They are over the same instruction example).
2) Every Var v in self, maps to a single Var w in other. I.e for
each use of v in self, w is used in the corresponding place in
other.
""" """
if self.inst != other.inst: if self.inst != other.inst:
return None return None
@@ -413,37 +426,62 @@ class Apply(Expr):
assert (len(self.args) == len(other.args)) assert (len(self.args) == len(other.args))
for (self_a, other_a) in zip(self.args, other.args): for (self_a, other_a) in zip(self.args, other.args):
if (isinstance(self_a, Var)): assert isinstance(self_a, Atom) and isinstance(other_a, Atom)
if not isinstance(other_a, Var):
return None
if (isinstance(self_a, Var)):
if (self_a not in s): if (self_a not in s):
s[self_a] = other_a s[self_a] = other_a
else: else:
if (s[self_a] != other_a): if (s[self_a] != other_a):
return None return None
elif isinstance(self_a, ConstantInt): elif isinstance(other_a, Var):
if not isinstance(other_a, ConstantInt): assert isinstance(self_a, Literal)
return None if (other_a not in s):
assert self_a.kind == other_a.kind s[other_a] = self_a
if (self_a.value != other_a.value): else:
if s[other_a] != self_a:
return None return None
else: else:
assert isinstance(self_a, Enumerator) assert (isinstance(self_a, Literal) and
isinstance(other_a, Literal))
if not isinstance(other_a, Enumerator):
# Currently don't support substitutions Var->Enumerator
return None
# Guaranteed by self.inst == other.inst # Guaranteed by self.inst == other.inst
assert self_a.kind == other_a.kind assert self_a.kind == other_a.kind
if (self_a.value != other_a.value): if (self_a.value != other_a.value):
return None return None
return s return s
class ConstantInt(Expr): class Literal(Atom):
"""
Base Class for all literal expressions in the DSL.
"""
def __init__(self, kind, value):
# type: (ImmediateKind, Any) -> None
self.kind = kind
self.value = value
def __eq__(self, other):
# type: (Any) -> bool
if not isinstance(other, Literal):
return False
if self.kind != other.kind:
return False
# Can't just compare value here, as comparison Any <> Any returns Any
return repr(self) == repr(other)
def __ne__(self, other):
# type: (Any) -> bool
return not self.__eq__(other)
def __repr__(self):
# type: () -> str
return '{}.{}'.format(self.kind, self.value)
class ConstantInt(Literal):
""" """
A value of an integer immediate operand. A value of an integer immediate operand.
@@ -454,8 +492,7 @@ class ConstantInt(Expr):
def __init__(self, kind, value): def __init__(self, kind, value):
# type: (ImmediateKind, int) -> None # type: (ImmediateKind, int) -> None
self.kind = kind super(ConstantInt, self).__init__(kind, value)
self.value = value
def __str__(self): def __str__(self):
# type: () -> str # type: () -> str
@@ -464,12 +501,8 @@ class ConstantInt(Expr):
""" """
return str(self.value) return str(self.value)
def __repr__(self):
# type: () -> str
return '{}({})'.format(self.kind, self.value)
class Enumerator(Literal):
class Enumerator(Expr):
""" """
A value of an enumerated immediate operand. A value of an enumerated immediate operand.
@@ -486,8 +519,7 @@ class Enumerator(Expr):
def __init__(self, kind, value): def __init__(self, kind, value):
# type: (ImmediateKind, str) -> None # type: (ImmediateKind, str) -> None
self.kind = kind super(Enumerator, self).__init__(kind, value)
self.value = value
def __str__(self): def __str__(self):
# type: () -> str # type: () -> str
@@ -495,7 +527,3 @@ class Enumerator(Expr):
Get the Rust expression form of this enumerator. Get the Rust expression form of this enumerator.
""" """
return self.kind.rust_enumerator(self.value) return self.kind.rust_enumerator(self.value)
def __repr__(self):
# type: () -> str
return '{}.{}'.format(self.kind, self.value)

View File

@@ -9,7 +9,7 @@ try:
from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa
from typing import Dict # noqa from typing import Dict # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
from .ast import Expr, Apply, Var, Def # noqa from .ast import Expr, Apply, Var, Def, VarAtomMap # noqa
from .typevar import TypeVar # noqa from .typevar import TypeVar # noqa
from .ti import TypeConstraint # noqa from .ti import TypeConstraint # noqa
from .xform import XForm, Rtl from .xform import XForm, Rtl
@@ -18,7 +18,7 @@ try:
ConstrList = Union[Sequence[TypeConstraint], TypeConstraint] ConstrList = Union[Sequence[TypeConstraint], TypeConstraint]
MaybeBoundInst = Union['Instruction', 'BoundInstruction'] MaybeBoundInst = Union['Instruction', 'BoundInstruction']
InstructionSemantics = Sequence[XForm] InstructionSemantics = Sequence[XForm]
RtlCase = Union[Rtl, Tuple[Rtl, Sequence[TypeConstraint]]] SemDefCase = Union[Rtl, Tuple[Rtl, Sequence[TypeConstraint]], XForm]
except ImportError: except ImportError:
pass pass
@@ -349,7 +349,7 @@ class Instruction(object):
return Apply(self, args) return Apply(self, args)
def set_semantics(self, src, *dsts): def set_semantics(self, src, *dsts):
# type: (Union[Def, Apply], *RtlCase) -> None # type: (Union[Def, Apply], *SemDefCase) -> None
"""Set our semantics.""" """Set our semantics."""
from semantics import verify_semantics from semantics import verify_semantics
from .xform import XForm, Rtl from .xform import XForm, Rtl
@@ -358,6 +358,11 @@ class Instruction(object):
for dst in dsts: for dst in dsts:
if isinstance(dst, Rtl): if isinstance(dst, Rtl):
sem.append(XForm(Rtl(src).copy({}), dst)) sem.append(XForm(Rtl(src).copy({}), dst))
elif isinstance(dst, XForm):
sem.append(XForm(
dst.src.copy({}),
dst.dst.copy({}),
dst.constraints))
else: else:
assert isinstance(dst, tuple) assert isinstance(dst, tuple)
sem.append(XForm(Rtl(src).copy({}), dst[0], sem.append(XForm(Rtl(src).copy({}), dst[0],

View File

@@ -5,10 +5,10 @@ from .types import ValueType
from .typevar import TypeVar from .typevar import TypeVar
try: try:
from typing import Union, Dict, TYPE_CHECKING # noqa from typing import Union, Dict, TYPE_CHECKING, Iterable # noqa
OperandSpec = Union['OperandKind', ValueType, TypeVar] OperandSpec = Union['OperandKind', ValueType, TypeVar]
if TYPE_CHECKING: if TYPE_CHECKING:
from .ast import Enumerator, ConstantInt # noqa from .ast import Enumerator, ConstantInt, Literal # noqa
except ImportError: except ImportError:
pass pass
@@ -128,6 +128,17 @@ class ImmediateKind(OperandKind):
""" """
return '{}::{}'.format(self.rust_type, self.values[value]) return '{}::{}'.format(self.rust_type, self.values[value])
def is_enumerable(self):
# type: () -> bool
return self.values is not None
def possible_values(self):
# type: () -> Iterable[Literal]
from cdsl.ast import Enumerator # noqa
assert self.is_enumerable()
for v in self.values.keys():
yield Enumerator(self, v)
# Instances of entity reference operand types are provided in the # Instances of entity reference operand types are provided in the
# `cretonne.entities` module. # `cretonne.entities` module.

View File

@@ -80,15 +80,52 @@ class TestXForm(TestCase):
dst = Rtl(b << icmp(intcc.eq, z, u)) dst = Rtl(b << icmp(intcc.eq, z, u))
assert src.substitution(dst, {}) == {a: b, x: z, y: u} assert src.substitution(dst, {}) == {a: b, x: z, y: u}
def test_subst_enum_bad(self): def test_subst_enum_var_const(self):
src = Rtl(a << icmp(CC1, x, y)) src = Rtl(a << icmp(CC1, x, y))
dst = Rtl(b << icmp(intcc.eq, z, u)) dst = Rtl(b << icmp(intcc.eq, z, u))
assert src.substitution(dst, {}) is None assert src.substitution(dst, {}) == {CC1: intcc.eq, x: z, y: u, a: b},\
"{} != {}".format(src.substitution(dst, {}),
{CC1: intcc.eq, x: z, y: u, a: b})
src = Rtl(a << icmp(intcc.eq, x, y)) src = Rtl(a << icmp(intcc.eq, x, y))
dst = Rtl(b << icmp(CC1, z, u)) dst = Rtl(b << icmp(CC1, z, u))
assert src.substitution(dst, {}) is None assert src.substitution(dst, {}) == {CC1: intcc.eq, x: z, y: u, a: b}
def test_subst_enum_bad(self):
src = Rtl(a << icmp(intcc.eq, x, y)) src = Rtl(a << icmp(intcc.eq, x, y))
dst = Rtl(b << icmp(intcc.sge, z, u)) dst = Rtl(b << icmp(intcc.sge, z, u))
assert src.substitution(dst, {}) is None assert src.substitution(dst, {}) is None
def test_subst_enum_bad_var_const(self):
a1 = Var('a1')
x1 = Var('x1')
y1 = Var('y1')
b1 = Var('b1')
z1 = Var('z1')
u1 = Var('u1')
# Var mapping to 2 different constants
src = Rtl(a << icmp(CC1, x, y),
a1 << icmp(CC1, x1, y1))
dst = Rtl(b << icmp(intcc.eq, z, u),
b1 << icmp(intcc.sge, z1, u1))
assert src.substitution(dst, {}) is None
# 2 different constants mapping to the same var
src = Rtl(a << icmp(intcc.eq, x, y),
a1 << icmp(intcc.sge, x1, y1))
dst = Rtl(b << icmp(CC1, z, u),
b1 << icmp(CC1, z1, u1))
assert src.substitution(dst, {}) is None
# Var mapping to var and constant - note that full unification would
# have allowed this.
src = Rtl(a << icmp(CC1, x, y),
a1 << icmp(CC1, x1, y1))
dst = Rtl(b << icmp(CC2, z, u),
b1 << icmp(intcc.sge, z1, u1))
assert src.substitution(dst, {}) is None

View File

@@ -3,16 +3,16 @@ Instruction transformations.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from .ast import Def, Var, Apply from .ast import Def, Var, Apply
from .ti import ti_xform, TypeEnv, get_type_env from .ti import ti_xform, TypeEnv, get_type_env, TypeConstraint
from functools import reduce from functools import reduce
try: try:
from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa from typing import Union, Iterator, Sequence, Iterable, List, Dict # noqa
from typing import Optional, Set # noqa from typing import Optional, Set # noqa
from .ast import Expr, VarMap # noqa from .ast import Expr, VarAtomMap # noqa
from .isa import TargetISA # noqa from .isa import TargetISA # noqa
from .ti import TypeConstraint # noqa
from .typevar import TypeVar # noqa from .typevar import TypeVar # noqa
from .instructions import ConstrList # noqa
DefApply = Union[Def, Apply] DefApply = Union[Def, Apply]
except ImportError: except ImportError:
pass pass
@@ -47,7 +47,7 @@ class Rtl(object):
self.rtl = tuple(map(canonicalize_defapply, args)) self.rtl = tuple(map(canonicalize_defapply, args))
def copy(self, m): def copy(self, m):
# type: (VarMap) -> Rtl # type: (VarAtomMap) -> Rtl
""" """
Return a copy of this rtl with all Vars substituted with copies or Return a copy of this rtl with all Vars substituted with copies or
according to m. Update m as neccessary. according to m. Update m as neccessary.
@@ -85,7 +85,7 @@ class Rtl(object):
return reduce(flow_f, reversed(self.rtl), set([])) return reduce(flow_f, reversed(self.rtl), set([]))
def substitution(self, other, s): def substitution(self, other, s):
# type: (Rtl, VarMap) -> Optional[VarMap] # type: (Rtl, VarAtomMap) -> Optional[VarAtomMap]
""" """
If the Rtl self agrees structurally with the Rtl other, return a If the Rtl self agrees structurally with the Rtl other, return a
substitution to transform self to other. Two Rtls agree structurally if substitution to transform self to other. Two Rtls agree structurally if
@@ -132,6 +132,10 @@ class Rtl(object):
assert typing[v].singleton_type() is not None assert typing[v].singleton_type() is not None
v.set_typevar(typing[v]) v.set_typevar(typing[v])
def __str__(self):
# type: () -> str
return "\n".join(map(str, self.rtl))
class XForm(object): class XForm(object):
""" """
@@ -162,7 +166,7 @@ class XForm(object):
""" """
def __init__(self, src, dst, constraints=None): def __init__(self, src, dst, constraints=None):
# type: (Rtl, Rtl, Optional[Sequence[TypeConstraint]]) -> None # type: (Rtl, Rtl, Optional[ConstrList]) -> None
self.src = src self.src = src
self.dst = dst self.dst = dst
# Variables that are inputs to the source pattern. # Variables that are inputs to the source pattern.
@@ -203,10 +207,18 @@ class XForm(object):
return tv return tv
return symtab[tv.name[len("typeof_"):]].get_typevar() return symtab[tv.name[len("typeof_"):]].get_typevar()
self.constraints = [] # type: List[TypeConstraint]
if constraints is not None: if constraints is not None:
for c in constraints: if isinstance(constraints, TypeConstraint):
constr_list = [constraints] # type: Sequence[TypeConstraint]
else:
constr_list = constraints
for c in constr_list:
type_m = {tv: interp_tv(tv) for tv in c.tvs()} type_m = {tv: interp_tv(tv) for tv in c.tvs()}
self.ti.add_constraint(c.translate(type_m)) inner_c = c.translate(type_m)
self.constraints.append(inner_c)
self.ti.add_constraint(inner_c)
# Sanity: The set of inferred free typevars should be a subset of the # Sanity: The set of inferred free typevars should be a subset of the
# TVs corresponding to Vars appearing in src # TVs corresponding to Vars appearing in src
@@ -333,7 +345,7 @@ class XForm(object):
defs are renamed with '.suffix' appended to their old name. defs are renamed with '.suffix' appended to their old name.
""" """
assert r.is_concrete() assert r.is_concrete()
s = self.src.substitution(r, {}) # type: VarMap s = self.src.substitution(r, {}) # type: VarAtomMap
assert s is not None assert s is not None
if (suffix is not None): if (suffix is not None):

View File

@@ -21,7 +21,7 @@ from cdsl.typevar import TypeVar
try: try:
from typing import Sequence, List, Dict, Set, DefaultDict # noqa from typing import Sequence, List, Dict, Set, DefaultDict # noqa
from cdsl.isa import TargetISA # noqa from cdsl.isa import TargetISA # noqa
from cdsl.ast import Def # noqa from cdsl.ast import Def, VarAtomMap # noqa
from cdsl.xform import XForm, XFormGroup # noqa from cdsl.xform import XForm, XFormGroup # noqa
from cdsl.typevar import TypeSet # noqa from cdsl.typevar import TypeSet # noqa
from cdsl.ti import TypeConstraint # noqa from cdsl.ti import TypeConstraint # noqa
@@ -45,7 +45,7 @@ def get_runtime_typechecks(xform):
# 1) Perform ti only on the source RTL. Accumulate any free tvs that have a # 1) Perform ti only on the source RTL. Accumulate any free tvs that have a
# different inferred type in src, compared to the type inferred for both # different inferred type in src, compared to the type inferred for both
# src and dst. # src and dst.
symtab = {} # type: Dict[Var, Var] symtab = {} # type: VarAtomMap
src_copy = xform.src.copy(symtab) src_copy = xform.src.copy(symtab)
src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv())) src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv()))
@@ -62,7 +62,9 @@ def get_runtime_typechecks(xform):
assert v.get_typevar().singleton_type() is not None assert v.get_typevar().singleton_type() is not None
continue continue
src_ts = src_typenv[symtab[v]].get_typeset() inner_v = symtab[v]
assert isinstance(inner_v, Var)
src_ts = src_typenv[inner_v].get_typeset()
xform_ts = xform.ti[v].get_typeset() xform_ts = xform.ti[v].get_typeset()
assert xform_ts.issubset(src_ts) assert xform_ts.issubset(src_ts)

View File

@@ -1,9 +1,11 @@
"""Definitions for the semantics segment of the Cretonne language.""" """Definitions for the semantics segment of the Cretonne language."""
from cdsl.ti import TypeEnv, ti_rtl, get_type_env from cdsl.ti import TypeEnv, ti_rtl, get_type_env
from cdsl.operands import ImmediateKind
from cdsl.ast import Var
try: try:
from typing import List, Dict, Tuple # noqa from typing import List, Dict, Tuple # noqa
from cdsl.ast import Var # noqa from cdsl.ast import VarAtomMap # noqa
from cdsl.xform import XForm, Rtl # noqa from cdsl.xform import XForm, Rtl # noqa
from cdsl.ti import VarTyping # noqa from cdsl.ti import VarTyping # noqa
from cdsl.instructions import Instruction, InstructionSemantics # noqa from cdsl.instructions import Instruction, InstructionSemantics # noqa
@@ -16,21 +18,44 @@ def verify_semantics(inst, src, xforms):
""" """
Verify that the semantics transforms in xforms correctly describe the Verify that the semantics transforms in xforms correctly describe the
instruction described by the src Rtl. This involves checking that: instruction described by the src Rtl. This involves checking that:
1) For all XForms x \in xforms, there is a Var substitution form src to 0) src is a single instance of inst
x.src 1) For all x\in xforms x.src is a single instance of inst
2) For any possible concrete typing of src there is exactly 1 XForm x 2) For any concrete values V of Literals in inst:
in xforms that applies. For all concrete typing T of inst:
Exists single x \in xforms that applies to src conretazied to V
and T
""" """
# 0) The source rtl is always a single instruction # 0) The source rtl is always a single instance of inst
assert len(src.rtl) == 1 assert len(src.rtl) == 1 and src.rtl[0].expr.inst == inst
# 1) For all XForms x, x.src is structurally equivalent to src # 1) For all XForms x, x.src is a single instance of inst
for x in xforms: for x in xforms:
assert src.substitution(x.src, {}) is not None,\ assert len(x.src.rtl) == 1 and x.src.rtl[0].expr.inst == inst
"XForm {} doesn't describe instruction {}.".format(x, src)
# 2) Any possible typing for the instruction should be covered by variants = [src] # type: List[Rtl]
# exactly ONE semantic XForm
# 2) For all enumerated immediates, compute all the possible
# versions of src with the concrete value filled in.
for i in inst.imm_opnums:
op = inst.ins[i]
if not (isinstance(op.kind, ImmediateKind) and
op.kind.is_enumerable()):
continue
new_variants = [] # type: List[Rtl]
for rtl_var in variants:
s = {v: v for v in rtl_var.vars()} # type: VarAtomMap
arg = rtl_var.rtl[0].expr.args[i]
assert isinstance(arg, Var)
for val in op.kind.possible_values():
s[arg] = val
new_variants.append(rtl_var.copy(s))
variants = new_variants
# For any possible version of the src with concrete enumerated immediates
for src in variants:
# 2) Any possible typing should be covered by exactly ONE semantic
# XForm
src = src.copy({}) src = src.copy({})
typenv = get_type_env(ti_rtl(src, TypeEnv())) typenv = get_type_env(ti_rtl(src, TypeEnv()))
typenv.normalize() typenv.normalize()
@@ -39,6 +64,9 @@ def verify_semantics(inst, src, xforms):
for t in typenv.concrete_typings(): for t in typenv.concrete_typings():
matching_xforms = [] # type: List[XForm] matching_xforms = [] # type: List[XForm]
for x in xforms: for x in xforms:
if src.substitution(x.src, {}) is None:
continue
# Translate t using x.symtab # Translate t using x.symtab
t = {x.symtab[str(v)]: tv for (v, tv) in t.items()} t = {x.symtab[str(v)]: tv for (v, tv) in t.items()}
if (x.ti.permits(t)): if (x.ti.permits(t)):
@@ -46,4 +74,4 @@ def verify_semantics(inst, src, xforms):
assert len(matching_xforms) == 1,\ assert len(matching_xforms) == 1,\
("Possible typing {} of {} not matched by exactly one case " + ("Possible typing {} of {} not matched by exactly one case " +
": {}").format(t, inst, matching_xforms) ": {}").format(t, src.rtl[0], matching_xforms)

View File

@@ -10,7 +10,7 @@ from cdsl.ast import Var
try: try:
from typing import TYPE_CHECKING, Dict, Union, List, Set, Tuple # noqa from typing import TYPE_CHECKING, Dict, Union, List, Set, Tuple # noqa
from cdsl.xform import XForm # noqa from cdsl.xform import XForm # noqa
from cdsl.ast import Def, VarMap # noqa from cdsl.ast import Def, VarAtomMap # noqa
from cdsl.ti import VarTyping # noqa from cdsl.ti import VarTyping # noqa
except ImportError: except ImportError:
TYPE_CHECKING = False TYPE_CHECKING = False
@@ -34,7 +34,13 @@ def find_matching_xform(d):
if (subst is None): if (subst is None):
continue continue
if x.ti.permits({subst[v]: tv for (v, tv) in typing.items()}): inner_typing = {} # type: VarTyping
for (v, tv) in typing.items():
inner_v = subst[v]
assert isinstance(inner_v, Var)
inner_typing[inner_v] = tv
if x.ti.permits(inner_typing):
res.append(x) res.append(x)
assert len(res) == 1, "Couldn't find semantic transform for {}".format(d) assert len(res) == 1, "Couldn't find semantic transform for {}".format(d)
@@ -60,7 +66,7 @@ def cleanup_semantics(r, outputs):
... ...
""" """
new_defs = [] # type: List[Def] new_defs = [] # type: List[Def]
subst_m = {v: v for v in r.vars()} # type: VarMap subst_m = {v: v for v in r.vars()} # type: VarAtomMap
definition = {} # type: Dict[Var, Def] definition = {} # type: Dict[Var, Def]
prim_to_bv_map = {} # type: Dict[Var, Def] prim_to_bv_map = {} # type: Dict[Var, Def]

View File

@@ -0,0 +1,45 @@
"""
Useful semantics "macro" instructions built on top of
the primitives.
"""
from __future__ import absolute_import
from cdsl.operands import Operand
from cdsl.typevar import TypeVar
from cdsl.instructions import Instruction, InstructionGroup
from base.types import b1
from base.immediates import imm64
from cdsl.ast import Var
from cdsl.xform import Rtl
from semantics.primitives import bv_from_imm64, bvite
import base.formats # noqa
GROUP = InstructionGroup("primitive_macros", "Semantic macros instruction set")
AnyBV = TypeVar('AnyBV', bitvecs=True, doc="")
x = Var('x')
y = Var('y')
imm = Var('imm')
a = Var('a')
#
# Bool-to-bv1
#
BV1 = TypeVar("BV1", bitvecs=(1, 1), doc="")
bv1_op = Operand('bv1_op', BV1, doc="")
cond_op = Operand("cond", b1, doc="")
bool2bv = Instruction(
'bool2bv', r"""Convert a b1 value to a 1-bit BV""",
ins=cond_op, outs=bv1_op)
v1 = Var('v1')
v2 = Var('v2')
bvone = Var('bvone')
bvzero = Var('bvzero')
bool2bv.set_semantics(
v1 << bool2bv(v2),
Rtl(
bvone << bv_from_imm64(imm64(1)),
bvzero << bv_from_imm64(imm64(0)),
v1 << bvite(v2, bvone, bvzero)
))
GROUP.close()

View File

@@ -10,6 +10,8 @@ from cdsl.operands import Operand
from cdsl.typevar import TypeVar from cdsl.typevar import TypeVar
from cdsl.instructions import Instruction, InstructionGroup from cdsl.instructions import Instruction, InstructionGroup
from cdsl.ti import WiderOrEq from cdsl.ti import WiderOrEq
from base.types import b1
from base.immediates import imm64
import base.formats # noqa import base.formats # noqa
GROUP = InstructionGroup("primitive", "Primitive instruction set") GROUP = InstructionGroup("primitive", "Primitive instruction set")
@@ -22,26 +24,40 @@ Real = TypeVar('Real', 'Any real type.', ints=True, floats=True,
x = Operand('x', BV, doc="A semantic value X") x = Operand('x', BV, doc="A semantic value X")
y = Operand('x', BV, doc="A semantic value Y (same width as X)") y = Operand('x', BV, doc="A semantic value Y (same width as X)")
a = Operand('a', BV, doc="A semantic value A (same width as X)") a = Operand('a', BV, doc="A semantic value A (same width as X)")
cond = Operand('b', TypeVar.singleton(b1), doc='A b1 value')
real = Operand('real', Real, doc="A real cretonne value") real = Operand('real', Real, doc="A real cretonne value")
fromReal = Operand('fromReal', Real.to_bitvec(), fromReal = Operand('fromReal', Real.to_bitvec(),
doc="A real cretonne value converted to a BV") doc="A real cretonne value converted to a BV")
#
# BV Conversion/Materialization
#
prim_to_bv = Instruction( prim_to_bv = Instruction(
'prim_to_bv', r""" 'prim_to_bv', r"""
Convert an SSA Value to a flat bitvector Convert an SSA Value to a flat bitvector
""", """,
ins=(real), outs=(fromReal)) ins=(real), outs=(fromReal))
# Note that when converting from BV->real values, we use a constraint and not a
# derived function. This reflects that fact that to_bitvec() is not a
# bijection.
prim_from_bv = Instruction( prim_from_bv = Instruction(
'prim_from_bv', r""" 'prim_from_bv', r"""
Convert a flat bitvector to a real SSA Value. Convert a flat bitvector to a real SSA Value.
""", """,
ins=(fromReal), outs=(real)) ins=(fromReal), outs=(real))
N = Operand('N', imm64)
bv_from_imm64 = Instruction(
'bv_from_imm64', r"""Materialize an imm64 as a bitvector.""",
ins=(N), outs=a)
#
# Generics
#
bvite = Instruction(
'bvite', r"""Bitvector ternary operator""",
ins=(cond, x, y), outs=a)
xh = Operand('xh', BV.half_width(), xh = Operand('xh', BV.half_width(),
doc="A semantic value representing the upper half of X") doc="A semantic value representing the upper half of X")
xl = Operand('xl', BV.half_width(), xl = Operand('xl', BV.half_width(),
@@ -67,12 +83,40 @@ bvadd = Instruction(
of the operands. of the operands.
""", """,
ins=(x, y), outs=a) ins=(x, y), outs=a)
#
# Bitvector comparisons # Bitvector comparisons
cmp_res = Operand('cmp_res', BV1, doc="Single bit boolean") #
bveq = Instruction(
'bveq', r"""Unsigned bitvector equality""",
ins=(x, y), outs=cond)
bvne = Instruction(
'bveq', r"""Unsigned bitvector inequality""",
ins=(x, y), outs=cond)
bvsge = Instruction(
'bvsge', r"""Signed bitvector greater or equal""",
ins=(x, y), outs=cond)
bvsgt = Instruction(
'bvsgt', r"""Signed bitvector greater than""",
ins=(x, y), outs=cond)
bvsle = Instruction(
'bvsle', r"""Signed bitvector less than or equal""",
ins=(x, y), outs=cond)
bvslt = Instruction(
'bvslt', r"""Signed bitvector less than""",
ins=(x, y), outs=cond)
bvuge = Instruction(
'bvuge', r"""Unsigned bitvector greater or equal""",
ins=(x, y), outs=cond)
bvugt = Instruction(
'bvugt', r"""Unsigned bitvector greater than""",
ins=(x, y), outs=cond)
bvule = Instruction(
'bvule', r"""Unsigned bitvector less than or equal""",
ins=(x, y), outs=cond)
bvult = Instruction( bvult = Instruction(
'bvult', r"""Unsigned bitvector comparison""", 'bvult', r"""Unsigned bitvector less than""",
ins=(x, y), outs=cmp_res) ins=(x, y), outs=cond)
# Extensions # Extensions
ToBV = TypeVar('ToBV', 'A bitvector type.', bitvecs=True) ToBV = TypeVar('ToBV', 'A bitvector type.', bitvecs=True)

View File

@@ -14,7 +14,7 @@ from z3.z3core import Z3_mk_eq
try: try:
from typing import TYPE_CHECKING, Tuple, Dict, List # noqa from typing import TYPE_CHECKING, Tuple, Dict, List # noqa
from cdsl.xform import Rtl, XForm # noqa from cdsl.xform import Rtl, XForm # noqa
from cdsl.ast import VarMap # noqa from cdsl.ast import VarAtomMap, Atom # noqa
from cdsl.ti import VarTyping # noqa from cdsl.ti import VarTyping # noqa
if TYPE_CHECKING: if TYPE_CHECKING:
from z3 import ExprRef, BitVecRef # noqa from z3 import ExprRef, BitVecRef # noqa
@@ -137,13 +137,13 @@ def to_smt(r):
def equivalent(r1, r2, inp_m, out_m): def equivalent(r1, r2, inp_m, out_m):
# type: (Rtl, Rtl, VarMap, VarMap) -> List[ExprRef] # type: (Rtl, Rtl, VarAtomMap, VarAtomMap) -> List[ExprRef]
""" """
Given: Given:
- concrete source Rtl r1 - concrete source Rtl r1
- concrete dest Rtl r2 - concrete dest Rtl r2
- VarMap inp_m mapping r1's non-bitvector inputs to r2 - VarAtomMap inp_m mapping r1's non-bitvector inputs to r2
- VarMap out_m mapping r1's non-bitvector outputs to r2 - VarAtomMap out_m mapping r1's non-bitvector outputs to r2
Build a query checking whether r1 and r2 are semantically equivalent. Build a query checking whether r1 and r2 are semantically equivalent.
If the returned query is unsatisfiable, then r1 and r2 are equivalent. If the returned query is unsatisfiable, then r1 and r2 are equivalent.
@@ -156,17 +156,31 @@ def equivalent(r1, r2, inp_m, out_m):
assert set(r2.free_vars()) == set(inp_m.values()) assert set(r2.free_vars()) == set(inp_m.values())
# Note that the same rule is not expected to hold for out_m due to # Note that the same rule is not expected to hold for out_m due to
# temporaries/intermediates. # temporaries/intermediates. out_m specified which values are enough for
# equivalence.
# Rename the vars in r1 and r2 with unique suffixes to avoid conflicts # Rename the vars in r1 and r2 with unique suffixes to avoid conflicts
src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()} # type: VarAtomMap # noqa
dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()} # type: VarAtomMap # noqa
r1 = r1.copy(src_m) r1 = r1.copy(src_m)
r2 = r2.copy(dst_m) r2 = r2.copy(dst_m)
def _translate(m, k_m, v_m):
# type: (VarAtomMap, VarAtomMap, VarAtomMap) -> VarAtomMap
"""Obtain a new map from m, by mapping m's keys with k_m and m's values
with v_m"""
res = {} # type: VarAtomMap
for (k, v) in m1.items():
new_k = k_m[k]
new_v = v_m[v]
assert isinstance(new_k, Var)
res[new_k] = new_v
return res
# Convert inp_m, out_m in terms of variables with the .a/.b suffixes # Convert inp_m, out_m in terms of variables with the .a/.b suffixes
inp_m = {src_m[k]: dst_m[v] for (k, v) in inp_m.items()} inp_m = _translate(inp_m, src_m, dst_m)
out_m = {src_m[k]: dst_m[v] for (k, v) in out_m.items()} out_m = _translate(out_m, src_m, dst_m)
# Encode r1 and r2 as SMT queries # Encode r1 and r2 as SMT queries
(q1, m1) = to_smt(r1) (q1, m1) = to_smt(r1)
@@ -175,12 +189,14 @@ def equivalent(r1, r2, inp_m, out_m):
# Build an expression for the equality of real Cretone inputs of r1 and r2 # Build an expression for the equality of real Cretone inputs of r1 and r2
args_eq_exp = [] # type: List[ExprRef] args_eq_exp = [] # type: List[ExprRef]
for v in r1.free_vars(): for (v1, v2) in inp_m.items():
args_eq_exp.append(mk_eq(m1[v], m2[inp_m[v]])) assert isinstance(v2, Var)
args_eq_exp.append(mk_eq(m1[v1], m2[v2]))
# Build an expression for the equality of real Cretone outputs of r1 and r2 # Build an expression for the equality of real Cretone outputs of r1 and r2
results_eq_exp = [] # type: List[ExprRef] results_eq_exp = [] # type: List[ExprRef]
for (v1, v2) in out_m.items(): for (v1, v2) in out_m.items():
assert isinstance(v2, Var)
results_eq_exp.append(mk_eq(m1[v1], m2[v2])) results_eq_exp.append(mk_eq(m1[v1], m2[v2]))
# Put the whole query toghether # Put the whole query toghether
@@ -196,20 +212,22 @@ def xform_correct(x, typing):
assert x.ti.permits(typing) assert x.ti.permits(typing)
# Create copies of the x.src and x.dst with their concrete types # Create copies of the x.src and x.dst with their concrete types
src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()} src_m = {v: Var(v.name, typing[v]) for v in x.src.vars()} # type: VarAtomMap # noqa
src = x.src.copy(src_m) src = x.src.copy(src_m)
dst = x.apply(src) dst = x.apply(src)
dst_m = x.dst.substitution(dst, {}) dst_m = x.dst.substitution(dst, {})
# Build maps for the inputs/outputs for src->dst # Build maps for the inputs/outputs for src->dst
inp_m = {} inp_m = {} # type: VarAtomMap
out_m = {} out_m = {} # type: VarAtomMap
for v in x.src.vars(): for v in x.src.vars():
src_v = src_m[v]
assert isinstance(src_v, Var)
if v.is_input(): if v.is_input():
inp_m[src_m[v]] = dst_m[v] inp_m[src_v] = dst_m[v]
elif v.is_output(): elif v.is_output():
out_m[src_m[v]] = dst_m[v] out_m[src_v] = dst_m[v]
# Get the primitive semantic Rtls for src and dst # Get the primitive semantic Rtls for src and dst
prim_src = elaborate(src) prim_src = elaborate(src)

View File

@@ -1,7 +1,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint from base.instructions import vselect, vsplit, vconcat, iconst, iadd, bint
from base.instructions import b1, icmp, ireduce, iadd_cout from base.instructions import b1, icmp, ireduce, iadd_cout
from base.immediates import intcc from base.immediates import intcc, imm64
from base.types import i64, i8, b32, i32, i16, f32 from base.types import i64, i8, b32, i32, i16, f32
from cdsl.typevar import TypeVar from cdsl.typevar import TypeVar
from cdsl.ast import Var from cdsl.ast import Var
@@ -9,7 +9,7 @@ from cdsl.xform import Rtl
from unittest import TestCase from unittest import TestCase
from .elaborate import elaborate from .elaborate import elaborate
from .primitives import prim_to_bv, bvsplit, prim_from_bv, bvconcat, bvadd, \ from .primitives import prim_to_bv, bvsplit, prim_from_bv, bvconcat, bvadd, \
bvult bvult, bv_from_imm64, bvite
import base.semantics # noqa import base.semantics # noqa
@@ -366,9 +366,12 @@ class TestElaborate(TestCase):
a = Var('a') a = Var('a')
c_out = Var('c_out') c_out = Var('c_out')
bvc_out = Var('bvc_out') bvc_out = Var('bvc_out')
bc_out = Var('bc_out')
bvx = Var('bvx') bvx = Var('bvx')
bvy = Var('bvy') bvy = Var('bvy')
bva = Var('bva') bva = Var('bva')
bvone = Var('bvone')
bvzero = Var('bvzero')
r = Rtl( r = Rtl(
(a, c_out) << iadd_cout.i32(x, y), (a, c_out) << iadd_cout.i32(x, y),
) )
@@ -378,10 +381,12 @@ class TestElaborate(TestCase):
bvx << prim_to_bv.i32(x), bvx << prim_to_bv.i32(x),
bvy << prim_to_bv.i32(y), bvy << prim_to_bv.i32(y),
bva << bvadd.bv32(bvx, bvy), bva << bvadd.bv32(bvx, bvy),
bvc_out << bvult.bv32(bva, bvx), bc_out << bvult.bv32(bva, bvx),
bvone << bv_from_imm64(imm64(1)),
bvzero << bv_from_imm64(imm64(0)),
bvc_out << bvite(bc_out, bvone, bvzero),
a << prim_from_bv.i32(bva), a << prim_from_bv.i32(bva),
c_out << prim_from_bv.b1(bvc_out) c_out << prim_from_bv.b1(bvc_out)
) )
exp.cleanup_concrete_rtl() exp.cleanup_concrete_rtl()
assert concrete_rtls_eq(sem, exp) assert concrete_rtls_eq(sem, exp)