Change TV ranking to select src vars as a representative during unification; Nit: cleanup dot() emitting code; Nit: fix small bug in verify_semantics() - make an internal copy of src rtl to avoid clobbering of typevars re-used in multiple definitions
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
20d96a1ac4
commit
e41ddf2a0d
@@ -158,8 +158,7 @@ class TypeCheckingBaseTest(TestCase):
|
||||
self.v8 = Var("v8")
|
||||
self.v9 = Var("v9")
|
||||
self.imm0 = Var("imm0")
|
||||
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True,
|
||||
scalars=False, simd=True)
|
||||
self.IxN = TypeVar("IxN", "", ints=True, scalars=True, simd=True)
|
||||
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
|
||||
scalars=False, simd=True)
|
||||
self.b1 = TypeVar.singleton(b1)
|
||||
@@ -176,7 +175,7 @@ class TestRTL(TypeCheckingBaseTest):
|
||||
self.assertEqual(ti_rtl(r, ti),
|
||||
"On line 1: fail ti on `typeof_v2` <: `1`: " +
|
||||
"Error: empty type created when unifying " +
|
||||
"`typeof_v2` and `half_vector(typeof_v2)`")
|
||||
"`typeof_v3` and `half_vector(typeof_v3)`")
|
||||
|
||||
def test_vselect(self):
|
||||
# type: () -> None
|
||||
@@ -202,11 +201,11 @@ class TestRTL(TypeCheckingBaseTest):
|
||||
)
|
||||
ti = TypeEnv()
|
||||
typing = ti_rtl(r, ti)
|
||||
ixn = self.IxN_nonscalar.get_fresh_copy("IxN1")
|
||||
ixn = self.IxN.get_fresh_copy("IxN1")
|
||||
txn = self.TxN.get_fresh_copy("TxN1")
|
||||
check_typing(typing, ({
|
||||
self.v0: ixn,
|
||||
self.v1: ixn.as_bool(),
|
||||
self.v1: txn.as_bool(),
|
||||
self.v2: ixn,
|
||||
self.v3: txn,
|
||||
self.v4: txn,
|
||||
@@ -319,7 +318,7 @@ class TestRTL(TypeCheckingBaseTest):
|
||||
self.assertEqual(typing,
|
||||
"On line 2: fail ti on `typeof_v4` <: `4`: " +
|
||||
"Error: empty type created when unifying " +
|
||||
"`typeof_v4` and `typeof_v5`")
|
||||
"`i16` and `i32`")
|
||||
|
||||
def test_extend_reduce(self):
|
||||
# type: () -> None
|
||||
@@ -471,7 +470,7 @@ class TestXForm(TypeCheckingBaseTest):
|
||||
assert var_m[v0] == var_m[v2] and \
|
||||
var_m[v3] == var_m[v4] and\
|
||||
var_m[v5] == var_m[v3] and\
|
||||
var_m[v1] == var_m[v2].as_bool() and\
|
||||
var_m[v1] == var_m[v5].as_bool() and\
|
||||
var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset()
|
||||
check_concrete_typing_xform(var_m, xform)
|
||||
|
||||
|
||||
@@ -104,7 +104,6 @@ class TypesEqual(TypeConstraint):
|
||||
"""
|
||||
def __init__(self, tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> None
|
||||
assert tv1.is_derived and tv2.is_derived
|
||||
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
|
||||
|
||||
def _args(self):
|
||||
@@ -279,7 +278,7 @@ class TypeEnv(object):
|
||||
:attribute idx: counter used to get fresh ids
|
||||
"""
|
||||
|
||||
RANK_DERIVED = 5
|
||||
RANK_SINGLETON = 5
|
||||
RANK_INPUT = 4
|
||||
RANK_INTERMEDIATE = 3
|
||||
RANK_OUTPUT = 2
|
||||
@@ -364,12 +363,18 @@ class TypeEnv(object):
|
||||
# type: (TypeVar) -> int
|
||||
"""
|
||||
Get the rank of tv in the partial order. TVs directly associated with a
|
||||
Var get their rank from the Var (see register()).
|
||||
Internally generated non-derived TVs implicitly get the lowest rank (0)
|
||||
Derived variables get the highest rank.
|
||||
Var get their rank from the Var (see register()). Internally generated
|
||||
non-derived TVs implicitly get the lowest rank (0). Derived variables
|
||||
get their rank from their free typevar. Singletons have the highest
|
||||
rank. TVs associated with vars in a source pattern have a higher rank
|
||||
than TVs associted with temporary vars.
|
||||
"""
|
||||
default_rank = TypeEnv.RANK_DERIVED if tv.is_derived else\
|
||||
TypeEnv.RANK_INTERNAL
|
||||
default_rank = TypeEnv.RANK_INTERNAL if tv.singleton_type() is None \
|
||||
else TypeEnv.RANK_SINGLETON
|
||||
|
||||
if tv.is_derived:
|
||||
tv = tv.free_typevar()
|
||||
|
||||
return self.ranks.get(tv, default_rank)
|
||||
|
||||
def register(self, v):
|
||||
@@ -565,28 +570,36 @@ class TypeEnv(object):
|
||||
|
||||
# Add all registered TVs (as some of them may be singleton nodes not
|
||||
# appearing in the graph
|
||||
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar]
|
||||
nodes = set() # type: Set[TypeVar]
|
||||
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
|
||||
|
||||
for (k, v) in self.type_map.items():
|
||||
def add_nodes(*args):
|
||||
# type: (*TypeVar) -> None
|
||||
for tv in args:
|
||||
nodes.add(tv)
|
||||
while (tv.is_derived):
|
||||
nodes.add(tv.base)
|
||||
edges.add((tv, tv.base, "solid", "forward",
|
||||
tv.derived_func))
|
||||
tv = tv.base
|
||||
|
||||
for v in self.vars:
|
||||
add_nodes(v.get_typevar())
|
||||
|
||||
for (tv1, tv2) in self.type_map.items():
|
||||
# Add all intermediate TVs appearing in edges
|
||||
nodes.add(k)
|
||||
nodes.add(v)
|
||||
edges.add((k, v, "dotted", "forward", None))
|
||||
while (v.is_derived):
|
||||
nodes.add(v.base)
|
||||
edges.add((v, v.base, "solid", "forward", v.derived_func))
|
||||
v = v.base
|
||||
add_nodes(tv1, tv2)
|
||||
edges.add((tv1, tv2, "dotted", "forward", None))
|
||||
|
||||
for constr in self.constraints:
|
||||
if isinstance(constr, TypesEqual):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
add_nodes(constr.tv1, constr.tv2)
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
|
||||
elif isinstance(constr, WiderOrEq):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
add_nodes(constr.tv1, constr.tv2)
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
|
||||
elif isinstance(constr, SameWidth):
|
||||
assert constr.tv1 in nodes and constr.tv2 in nodes
|
||||
add_nodes(constr.tv1, constr.tv2)
|
||||
edges.add((constr.tv1, constr.tv2, "dashed", "none",
|
||||
"same_width"))
|
||||
else:
|
||||
@@ -640,7 +653,9 @@ def get_type_env(typing_or_err):
|
||||
"""
|
||||
Helper function to appease mypy when checking the result of typing.
|
||||
"""
|
||||
assert isinstance(typing_or_err, TypeEnv)
|
||||
assert isinstance(typing_or_err, TypeEnv), \
|
||||
"Unexpected error: {}".format(typing_or_err)
|
||||
|
||||
if (TYPE_CHECKING):
|
||||
return cast(TypeEnv, typing_or_err)
|
||||
else:
|
||||
@@ -752,8 +767,6 @@ def unify(tv1, tv2, typ):
|
||||
typ.equivalent(tv1, tv2)
|
||||
return typ
|
||||
|
||||
assert tv2.is_derived, "Ordering gives us !tv1.is_derived==>tv2.is_derived"
|
||||
|
||||
if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
|
||||
inv_f = TypeVar.inverse_func(tv1.derived_func)
|
||||
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)
|
||||
|
||||
@@ -31,6 +31,7 @@ def verify_semantics(inst, src, xforms):
|
||||
|
||||
# 2) Any possible typing for the instruction should be covered by
|
||||
# exactly ONE semantic XForm
|
||||
src = src.copy({})
|
||||
typenv = get_type_env(ti_rtl(src, TypeEnv()))
|
||||
typenv.normalize()
|
||||
typenv = typenv.extract()
|
||||
|
||||
@@ -9,7 +9,6 @@ from __future__ import absolute_import
|
||||
from cdsl.operands import Operand
|
||||
from cdsl.typevar import TypeVar
|
||||
from cdsl.instructions import Instruction, InstructionGroup
|
||||
from cdsl.ti import SameWidth
|
||||
import base.formats # noqa
|
||||
|
||||
GROUP = InstructionGroup("primitive", "Primitive instruction set")
|
||||
@@ -39,8 +38,7 @@ prim_from_bv = Instruction(
|
||||
'prim_from_bv', r"""
|
||||
Convert a flat bitvector to a real SSA Value.
|
||||
""",
|
||||
ins=(x), outs=(real),
|
||||
constraints=SameWidth(BV, Real))
|
||||
ins=(fromReal), outs=(real))
|
||||
|
||||
xh = Operand('xh', BV.half_width(),
|
||||
doc="A semantic value representing the upper half of X")
|
||||
|
||||
@@ -205,8 +205,8 @@ class TestElaborate(TestCase):
|
||||
assert concrete_rtls_eq(sem, cleanup_concrete_rtl(Rtl(
|
||||
bvx << prim_to_bv.i32x4(x),
|
||||
(bvlo, bvhi) << bvsplit.bv128(bvx),
|
||||
lo << prim_from_bv.i32x2.bv64(bvlo),
|
||||
hi << prim_from_bv.i32x2.bv64(bvhi))))
|
||||
lo << prim_from_bv.i32x2(bvlo),
|
||||
hi << prim_from_bv.i32x2(bvhi))))
|
||||
|
||||
def test_elaborate_vconcat(self):
|
||||
# type: () -> None
|
||||
@@ -227,7 +227,7 @@ class TestElaborate(TestCase):
|
||||
bvlo << prim_to_bv.i32x2(lo),
|
||||
bvhi << prim_to_bv.i32x2(hi),
|
||||
bvx << bvconcat.bv64(bvlo, bvhi),
|
||||
x << prim_from_bv.i32x4.bv128(bvx))))
|
||||
x << prim_from_bv.i32x4(bvx))))
|
||||
|
||||
def test_elaborate_iadd_simple(self):
|
||||
# type: () -> None
|
||||
@@ -247,7 +247,7 @@ class TestElaborate(TestCase):
|
||||
bvx << prim_to_bv.i32(x),
|
||||
bvy << prim_to_bv.i32(y),
|
||||
bva << bvadd.bv32(bvx, bvy),
|
||||
a << prim_from_bv.i32.bv32(bva))))
|
||||
a << prim_from_bv.i32(bva))))
|
||||
|
||||
def test_elaborate_iadd_elaborate_1(self):
|
||||
# type: () -> None
|
||||
@@ -279,7 +279,7 @@ class TestElaborate(TestCase):
|
||||
bva_3 << bvadd.bv32(bvlo_1, bvlo_2),
|
||||
bva_4 << bvadd.bv32(bvhi_1, bvhi_2),
|
||||
bvx_5 << bvconcat.bv32(bva_3, bva_4),
|
||||
a << prim_from_bv.i32x2.bv64(bvx_5))))
|
||||
a << prim_from_bv.i32x2(bvx_5))))
|
||||
|
||||
def test_elaborate_iadd_elaborate_2(self):
|
||||
# type: () -> None
|
||||
@@ -334,4 +334,4 @@ class TestElaborate(TestCase):
|
||||
bva_14 << bvadd.bv8(bvhi_11, bvhi_12),
|
||||
bvx_15 << bvconcat.bv8(bva_13, bva_14),
|
||||
bvx_5 << bvconcat.bv16(bvx_10, bvx_15),
|
||||
a << prim_from_bv.i8x4.bv32(bvx_5))))
|
||||
a << prim_from_bv.i8x4(bvx_5))))
|
||||
|
||||
@@ -153,8 +153,7 @@ class TestRuntimeChecks(TestCase):
|
||||
|
||||
def test_vselect_imm(self):
|
||||
# type: () -> None
|
||||
ts = TypeSet(lanes=(2, 256), ints=(8, 64),
|
||||
floats=(32, 64), bools=(8, 64))
|
||||
ts = TypeSet(lanes=(2, 256), ints=(8, 64))
|
||||
r = Rtl(
|
||||
self.v0 << iconst(self.imm0),
|
||||
self.v1 << icmp(intcc.eq, self.v2, self.v0),
|
||||
@@ -167,7 +166,7 @@ class TestRuntimeChecks(TestCase):
|
||||
.format(self.v3.get_typevar().name)
|
||||
|
||||
self.check_yo_check(
|
||||
x, sequence(typeset_check(self.v3, ts),
|
||||
x, sequence(typeset_check(self.v2, ts),
|
||||
equiv_check(tv2_exp, tv3_exp)))
|
||||
|
||||
def test_reduce_extend(self):
|
||||
|
||||
Reference in New Issue
Block a user