Change TV ranking to select src vars as a representative during unification; Nit: cleanup dot() emitting code; Nit: fix small bug in verify_semantics() - make an internal copy of src rtl to avoid clobbering of typevars re-used in multiple definitions

This commit is contained in:
Dimo
2017-07-25 15:09:22 -07:00
committed by Jakob Stoklund Olesen
parent 20d96a1ac4
commit e41ddf2a0d
6 changed files with 51 additions and 41 deletions

View File

@@ -158,8 +158,7 @@ class TypeCheckingBaseTest(TestCase):
self.v8 = Var("v8")
self.v9 = Var("v9")
self.imm0 = Var("imm0")
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True,
scalars=False, simd=True)
self.IxN = TypeVar("IxN", "", ints=True, scalars=True, simd=True)
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
scalars=False, simd=True)
self.b1 = TypeVar.singleton(b1)
@@ -176,7 +175,7 @@ class TestRTL(TypeCheckingBaseTest):
self.assertEqual(ti_rtl(r, ti),
"On line 1: fail ti on `typeof_v2` <: `1`: " +
"Error: empty type created when unifying " +
"`typeof_v2` and `half_vector(typeof_v2)`")
"`typeof_v3` and `half_vector(typeof_v3)`")
def test_vselect(self):
# type: () -> None
@@ -202,11 +201,11 @@ class TestRTL(TypeCheckingBaseTest):
)
ti = TypeEnv()
typing = ti_rtl(r, ti)
ixn = self.IxN_nonscalar.get_fresh_copy("IxN1")
ixn = self.IxN.get_fresh_copy("IxN1")
txn = self.TxN.get_fresh_copy("TxN1")
check_typing(typing, ({
self.v0: ixn,
self.v1: ixn.as_bool(),
self.v1: txn.as_bool(),
self.v2: ixn,
self.v3: txn,
self.v4: txn,
@@ -319,7 +318,7 @@ class TestRTL(TypeCheckingBaseTest):
self.assertEqual(typing,
"On line 2: fail ti on `typeof_v4` <: `4`: " +
"Error: empty type created when unifying " +
"`typeof_v4` and `typeof_v5`")
"`i16` and `i32`")
def test_extend_reduce(self):
# type: () -> None
@@ -471,7 +470,7 @@ class TestXForm(TypeCheckingBaseTest):
assert var_m[v0] == var_m[v2] and \
var_m[v3] == var_m[v4] and\
var_m[v5] == var_m[v3] and\
var_m[v1] == var_m[v2].as_bool() and\
var_m[v1] == var_m[v5].as_bool() and\
var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset()
check_concrete_typing_xform(var_m, xform)

View File

@@ -104,7 +104,6 @@ class TypesEqual(TypeConstraint):
"""
def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def _args(self):
@@ -279,7 +278,7 @@ class TypeEnv(object):
:attribute idx: counter used to get fresh ids
"""
RANK_DERIVED = 5
RANK_SINGLETON = 5
RANK_INPUT = 4
RANK_INTERMEDIATE = 3
RANK_OUTPUT = 2
@@ -364,12 +363,18 @@ class TypeEnv(object):
# type: (TypeVar) -> int
"""
Get the rank of tv in the partial order. TVs directly associated with a
Var get their rank from the Var (see register()).
Internally generated non-derived TVs implicitly get the lowest rank (0)
Derived variables get the highest rank.
Var get their rank from the Var (see register()). Internally generated
non-derived TVs implicitly get the lowest rank (0). Derived variables
get their rank from their free typevar. Singletons have the highest
rank. TVs associated with vars in a source pattern have a higher rank
than TVs associted with temporary vars.
"""
default_rank = TypeEnv.RANK_DERIVED if tv.is_derived else\
TypeEnv.RANK_INTERNAL
default_rank = TypeEnv.RANK_INTERNAL if tv.singleton_type() is None \
else TypeEnv.RANK_SINGLETON
if tv.is_derived:
tv = tv.free_typevar()
return self.ranks.get(tv, default_rank)
def register(self, v):
@@ -565,28 +570,36 @@ class TypeEnv(object):
# Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
nodes = set() # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
for (k, v) in self.type_map.items():
def add_nodes(*args):
# type: (*TypeVar) -> None
for tv in args:
nodes.add(tv)
while (tv.is_derived):
nodes.add(tv.base)
edges.add((tv, tv.base, "solid", "forward",
tv.derived_func))
tv = tv.base
for v in self.vars:
add_nodes(v.get_typevar())
for (tv1, tv2) in self.type_map.items():
# Add all intermediate TVs appearing in edges
nodes.add(k)
nodes.add(v)
edges.add((k, v, "dotted", "forward", None))
while (v.is_derived):
nodes.add(v.base)
edges.add((v, v.base, "solid", "forward", v.derived_func))
v = v.base
add_nodes(tv1, tv2)
edges.add((tv1, tv2, "dotted", "forward", None))
for constr in self.constraints:
if isinstance(constr, TypesEqual):
assert constr.tv1 in nodes and constr.tv2 in nodes
add_nodes(constr.tv1, constr.tv2)
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
elif isinstance(constr, WiderOrEq):
assert constr.tv1 in nodes and constr.tv2 in nodes
add_nodes(constr.tv1, constr.tv2)
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
elif isinstance(constr, SameWidth):
assert constr.tv1 in nodes and constr.tv2 in nodes
add_nodes(constr.tv1, constr.tv2)
edges.add((constr.tv1, constr.tv2, "dashed", "none",
"same_width"))
else:
@@ -640,7 +653,9 @@ def get_type_env(typing_or_err):
"""
Helper function to appease mypy when checking the result of typing.
"""
assert isinstance(typing_or_err, TypeEnv)
assert isinstance(typing_or_err, TypeEnv), \
"Unexpected error: {}".format(typing_or_err)
if (TYPE_CHECKING):
return cast(TypeEnv, typing_or_err)
else:
@@ -752,8 +767,6 @@ def unify(tv1, tv2, typ):
typ.equivalent(tv1, tv2)
return typ
assert tv2.is_derived, "Ordering gives us !tv1.is_derived==>tv2.is_derived"
if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)

View File

@@ -31,6 +31,7 @@ def verify_semantics(inst, src, xforms):
# 2) Any possible typing for the instruction should be covered by
# exactly ONE semantic XForm
src = src.copy({})
typenv = get_type_env(ti_rtl(src, TypeEnv()))
typenv.normalize()
typenv = typenv.extract()

View File

@@ -9,7 +9,6 @@ from __future__ import absolute_import
from cdsl.operands import Operand
from cdsl.typevar import TypeVar
from cdsl.instructions import Instruction, InstructionGroup
from cdsl.ti import SameWidth
import base.formats # noqa
GROUP = InstructionGroup("primitive", "Primitive instruction set")
@@ -39,8 +38,7 @@ prim_from_bv = Instruction(
'prim_from_bv', r"""
Convert a flat bitvector to a real SSA Value.
""",
ins=(x), outs=(real),
constraints=SameWidth(BV, Real))
ins=(fromReal), outs=(real))
xh = Operand('xh', BV.half_width(),
doc="A semantic value representing the upper half of X")

View File

@@ -205,8 +205,8 @@ class TestElaborate(TestCase):
assert concrete_rtls_eq(sem, cleanup_concrete_rtl(Rtl(
bvx << prim_to_bv.i32x4(x),
(bvlo, bvhi) << bvsplit.bv128(bvx),
lo << prim_from_bv.i32x2.bv64(bvlo),
hi << prim_from_bv.i32x2.bv64(bvhi))))
lo << prim_from_bv.i32x2(bvlo),
hi << prim_from_bv.i32x2(bvhi))))
def test_elaborate_vconcat(self):
# type: () -> None
@@ -227,7 +227,7 @@ class TestElaborate(TestCase):
bvlo << prim_to_bv.i32x2(lo),
bvhi << prim_to_bv.i32x2(hi),
bvx << bvconcat.bv64(bvlo, bvhi),
x << prim_from_bv.i32x4.bv128(bvx))))
x << prim_from_bv.i32x4(bvx))))
def test_elaborate_iadd_simple(self):
# type: () -> None
@@ -247,7 +247,7 @@ class TestElaborate(TestCase):
bvx << prim_to_bv.i32(x),
bvy << prim_to_bv.i32(y),
bva << bvadd.bv32(bvx, bvy),
a << prim_from_bv.i32.bv32(bva))))
a << prim_from_bv.i32(bva))))
def test_elaborate_iadd_elaborate_1(self):
# type: () -> None
@@ -279,7 +279,7 @@ class TestElaborate(TestCase):
bva_3 << bvadd.bv32(bvlo_1, bvlo_2),
bva_4 << bvadd.bv32(bvhi_1, bvhi_2),
bvx_5 << bvconcat.bv32(bva_3, bva_4),
a << prim_from_bv.i32x2.bv64(bvx_5))))
a << prim_from_bv.i32x2(bvx_5))))
def test_elaborate_iadd_elaborate_2(self):
# type: () -> None
@@ -334,4 +334,4 @@ class TestElaborate(TestCase):
bva_14 << bvadd.bv8(bvhi_11, bvhi_12),
bvx_15 << bvconcat.bv8(bva_13, bva_14),
bvx_5 << bvconcat.bv16(bvx_10, bvx_15),
a << prim_from_bv.i8x4.bv32(bvx_5))))
a << prim_from_bv.i8x4(bvx_5))))

View File

@@ -153,8 +153,7 @@ class TestRuntimeChecks(TestCase):
def test_vselect_imm(self):
# type: () -> None
ts = TypeSet(lanes=(2, 256), ints=(8, 64),
floats=(32, 64), bools=(8, 64))
ts = TypeSet(lanes=(2, 256), ints=(8, 64))
r = Rtl(
self.v0 << iconst(self.imm0),
self.v1 << icmp(intcc.eq, self.v2, self.v0),
@@ -167,7 +166,7 @@ class TestRuntimeChecks(TestCase):
.format(self.v3.get_typevar().name)
self.check_yo_check(
x, sequence(typeset_check(self.v3, ts),
x, sequence(typeset_check(self.v2, ts),
equiv_check(tv2_exp, tv3_exp)))
def test_reduce_extend(self):