Change TV ranking to select src vars as a representative during unification; Nit: cleanup dot() emitting code; Nit: fix small bug in verify_semantics() - make an internal copy of src rtl to avoid clobbering of typevars re-used in multiple definitions

This commit is contained in:
Dimo
2017-07-25 15:09:22 -07:00
committed by Jakob Stoklund Olesen
parent 20d96a1ac4
commit e41ddf2a0d
6 changed files with 51 additions and 41 deletions

View File

@@ -158,8 +158,7 @@ class TypeCheckingBaseTest(TestCase):
self.v8 = Var("v8") self.v8 = Var("v8")
self.v9 = Var("v9") self.v9 = Var("v9")
self.imm0 = Var("imm0") self.imm0 = Var("imm0")
self.IxN_nonscalar = TypeVar("IxN_nonscalar", "", ints=True, self.IxN = TypeVar("IxN", "", ints=True, scalars=True, simd=True)
scalars=False, simd=True)
self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True, self.TxN = TypeVar("TxN", "", ints=True, bools=True, floats=True,
scalars=False, simd=True) scalars=False, simd=True)
self.b1 = TypeVar.singleton(b1) self.b1 = TypeVar.singleton(b1)
@@ -176,7 +175,7 @@ class TestRTL(TypeCheckingBaseTest):
self.assertEqual(ti_rtl(r, ti), self.assertEqual(ti_rtl(r, ti),
"On line 1: fail ti on `typeof_v2` <: `1`: " + "On line 1: fail ti on `typeof_v2` <: `1`: " +
"Error: empty type created when unifying " + "Error: empty type created when unifying " +
"`typeof_v2` and `half_vector(typeof_v2)`") "`typeof_v3` and `half_vector(typeof_v3)`")
def test_vselect(self): def test_vselect(self):
# type: () -> None # type: () -> None
@@ -202,11 +201,11 @@ class TestRTL(TypeCheckingBaseTest):
) )
ti = TypeEnv() ti = TypeEnv()
typing = ti_rtl(r, ti) typing = ti_rtl(r, ti)
ixn = self.IxN_nonscalar.get_fresh_copy("IxN1") ixn = self.IxN.get_fresh_copy("IxN1")
txn = self.TxN.get_fresh_copy("TxN1") txn = self.TxN.get_fresh_copy("TxN1")
check_typing(typing, ({ check_typing(typing, ({
self.v0: ixn, self.v0: ixn,
self.v1: ixn.as_bool(), self.v1: txn.as_bool(),
self.v2: ixn, self.v2: ixn,
self.v3: txn, self.v3: txn,
self.v4: txn, self.v4: txn,
@@ -319,7 +318,7 @@ class TestRTL(TypeCheckingBaseTest):
self.assertEqual(typing, self.assertEqual(typing,
"On line 2: fail ti on `typeof_v4` <: `4`: " + "On line 2: fail ti on `typeof_v4` <: `4`: " +
"Error: empty type created when unifying " + "Error: empty type created when unifying " +
"`typeof_v4` and `typeof_v5`") "`i16` and `i32`")
def test_extend_reduce(self): def test_extend_reduce(self):
# type: () -> None # type: () -> None
@@ -471,7 +470,7 @@ class TestXForm(TypeCheckingBaseTest):
assert var_m[v0] == var_m[v2] and \ assert var_m[v0] == var_m[v2] and \
var_m[v3] == var_m[v4] and\ var_m[v3] == var_m[v4] and\
var_m[v5] == var_m[v3] and\ var_m[v5] == var_m[v3] and\
var_m[v1] == var_m[v2].as_bool() and\ var_m[v1] == var_m[v5].as_bool() and\
var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset() var_m[v1].get_typeset() == var_m[v3].as_bool().get_typeset()
check_concrete_typing_xform(var_m, xform) check_concrete_typing_xform(var_m, xform)

View File

@@ -104,7 +104,6 @@ class TypesEqual(TypeConstraint):
""" """
def __init__(self, tv1, tv2): def __init__(self, tv1, tv2):
# type: (TypeVar, TypeVar) -> None # type: (TypeVar, TypeVar) -> None
assert tv1.is_derived and tv2.is_derived
(self.tv1, self.tv2) = sorted([tv1, tv2], key=repr) (self.tv1, self.tv2) = sorted([tv1, tv2], key=repr)
def _args(self): def _args(self):
@@ -279,7 +278,7 @@ class TypeEnv(object):
:attribute idx: counter used to get fresh ids :attribute idx: counter used to get fresh ids
""" """
RANK_DERIVED = 5 RANK_SINGLETON = 5
RANK_INPUT = 4 RANK_INPUT = 4
RANK_INTERMEDIATE = 3 RANK_INTERMEDIATE = 3
RANK_OUTPUT = 2 RANK_OUTPUT = 2
@@ -364,12 +363,18 @@ class TypeEnv(object):
# type: (TypeVar) -> int # type: (TypeVar) -> int
""" """
Get the rank of tv in the partial order. TVs directly associated with a Get the rank of tv in the partial order. TVs directly associated with a
Var get their rank from the Var (see register()). Var get their rank from the Var (see register()). Internally generated
Internally generated non-derived TVs implicitly get the lowest rank (0) non-derived TVs implicitly get the lowest rank (0). Derived variables
Derived variables get the highest rank. get their rank from their free typevar. Singletons have the highest
rank. TVs associated with vars in a source pattern have a higher rank
than TVs associted with temporary vars.
""" """
default_rank = TypeEnv.RANK_DERIVED if tv.is_derived else\ default_rank = TypeEnv.RANK_INTERNAL if tv.singleton_type() is None \
TypeEnv.RANK_INTERNAL else TypeEnv.RANK_SINGLETON
if tv.is_derived:
tv = tv.free_typevar()
return self.ranks.get(tv, default_rank) return self.ranks.get(tv, default_rank)
def register(self, v): def register(self, v):
@@ -565,28 +570,36 @@ class TypeEnv(object):
# Add all registered TVs (as some of them may be singleton nodes not # Add all registered TVs (as some of them may be singleton nodes not
# appearing in the graph # appearing in the graph
nodes = set([v.get_typevar() for v in self.vars]) # type: Set[TypeVar] nodes = set() # type: Set[TypeVar]
edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa edges = set() # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa
for (k, v) in self.type_map.items(): def add_nodes(*args):
# type: (*TypeVar) -> None
for tv in args:
nodes.add(tv)
while (tv.is_derived):
nodes.add(tv.base)
edges.add((tv, tv.base, "solid", "forward",
tv.derived_func))
tv = tv.base
for v in self.vars:
add_nodes(v.get_typevar())
for (tv1, tv2) in self.type_map.items():
# Add all intermediate TVs appearing in edges # Add all intermediate TVs appearing in edges
nodes.add(k) add_nodes(tv1, tv2)
nodes.add(v) edges.add((tv1, tv2, "dotted", "forward", None))
edges.add((k, v, "dotted", "forward", None))
while (v.is_derived):
nodes.add(v.base)
edges.add((v, v.base, "solid", "forward", v.derived_func))
v = v.base
for constr in self.constraints: for constr in self.constraints:
if isinstance(constr, TypesEqual): if isinstance(constr, TypesEqual):
assert constr.tv1 in nodes and constr.tv2 in nodes add_nodes(constr.tv1, constr.tv2)
edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal")) edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
elif isinstance(constr, WiderOrEq): elif isinstance(constr, WiderOrEq):
assert constr.tv1 in nodes and constr.tv2 in nodes add_nodes(constr.tv1, constr.tv2)
edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">=")) edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
elif isinstance(constr, SameWidth): elif isinstance(constr, SameWidth):
assert constr.tv1 in nodes and constr.tv2 in nodes add_nodes(constr.tv1, constr.tv2)
edges.add((constr.tv1, constr.tv2, "dashed", "none", edges.add((constr.tv1, constr.tv2, "dashed", "none",
"same_width")) "same_width"))
else: else:
@@ -640,7 +653,9 @@ def get_type_env(typing_or_err):
""" """
Helper function to appease mypy when checking the result of typing. Helper function to appease mypy when checking the result of typing.
""" """
assert isinstance(typing_or_err, TypeEnv) assert isinstance(typing_or_err, TypeEnv), \
"Unexpected error: {}".format(typing_or_err)
if (TYPE_CHECKING): if (TYPE_CHECKING):
return cast(TypeEnv, typing_or_err) return cast(TypeEnv, typing_or_err)
else: else:
@@ -752,8 +767,6 @@ def unify(tv1, tv2, typ):
typ.equivalent(tv1, tv2) typ.equivalent(tv1, tv2)
return typ return typ
assert tv2.is_derived, "Ordering gives us !tv1.is_derived==>tv2.is_derived"
if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)): if (tv1.is_derived and TypeVar.is_bijection(tv1.derived_func)):
inv_f = TypeVar.inverse_func(tv1.derived_func) inv_f = TypeVar.inverse_func(tv1.derived_func)
return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ) return unify(tv1.base, normalize_tv(TypeVar.derived(tv2, inv_f)), typ)

View File

@@ -31,6 +31,7 @@ def verify_semantics(inst, src, xforms):
# 2) Any possible typing for the instruction should be covered by # 2) Any possible typing for the instruction should be covered by
# exactly ONE semantic XForm # exactly ONE semantic XForm
src = src.copy({})
typenv = get_type_env(ti_rtl(src, TypeEnv())) typenv = get_type_env(ti_rtl(src, TypeEnv()))
typenv.normalize() typenv.normalize()
typenv = typenv.extract() typenv = typenv.extract()

View File

@@ -9,7 +9,6 @@ from __future__ import absolute_import
from cdsl.operands import Operand from cdsl.operands import Operand
from cdsl.typevar import TypeVar from cdsl.typevar import TypeVar
from cdsl.instructions import Instruction, InstructionGroup from cdsl.instructions import Instruction, InstructionGroup
from cdsl.ti import SameWidth
import base.formats # noqa import base.formats # noqa
GROUP = InstructionGroup("primitive", "Primitive instruction set") GROUP = InstructionGroup("primitive", "Primitive instruction set")
@@ -39,8 +38,7 @@ prim_from_bv = Instruction(
'prim_from_bv', r""" 'prim_from_bv', r"""
Convert a flat bitvector to a real SSA Value. Convert a flat bitvector to a real SSA Value.
""", """,
ins=(x), outs=(real), ins=(fromReal), outs=(real))
constraints=SameWidth(BV, Real))
xh = Operand('xh', BV.half_width(), xh = Operand('xh', BV.half_width(),
doc="A semantic value representing the upper half of X") doc="A semantic value representing the upper half of X")

View File

@@ -205,8 +205,8 @@ class TestElaborate(TestCase):
assert concrete_rtls_eq(sem, cleanup_concrete_rtl(Rtl( assert concrete_rtls_eq(sem, cleanup_concrete_rtl(Rtl(
bvx << prim_to_bv.i32x4(x), bvx << prim_to_bv.i32x4(x),
(bvlo, bvhi) << bvsplit.bv128(bvx), (bvlo, bvhi) << bvsplit.bv128(bvx),
lo << prim_from_bv.i32x2.bv64(bvlo), lo << prim_from_bv.i32x2(bvlo),
hi << prim_from_bv.i32x2.bv64(bvhi)))) hi << prim_from_bv.i32x2(bvhi))))
def test_elaborate_vconcat(self): def test_elaborate_vconcat(self):
# type: () -> None # type: () -> None
@@ -227,7 +227,7 @@ class TestElaborate(TestCase):
bvlo << prim_to_bv.i32x2(lo), bvlo << prim_to_bv.i32x2(lo),
bvhi << prim_to_bv.i32x2(hi), bvhi << prim_to_bv.i32x2(hi),
bvx << bvconcat.bv64(bvlo, bvhi), bvx << bvconcat.bv64(bvlo, bvhi),
x << prim_from_bv.i32x4.bv128(bvx)))) x << prim_from_bv.i32x4(bvx))))
def test_elaborate_iadd_simple(self): def test_elaborate_iadd_simple(self):
# type: () -> None # type: () -> None
@@ -247,7 +247,7 @@ class TestElaborate(TestCase):
bvx << prim_to_bv.i32(x), bvx << prim_to_bv.i32(x),
bvy << prim_to_bv.i32(y), bvy << prim_to_bv.i32(y),
bva << bvadd.bv32(bvx, bvy), bva << bvadd.bv32(bvx, bvy),
a << prim_from_bv.i32.bv32(bva)))) a << prim_from_bv.i32(bva))))
def test_elaborate_iadd_elaborate_1(self): def test_elaborate_iadd_elaborate_1(self):
# type: () -> None # type: () -> None
@@ -279,7 +279,7 @@ class TestElaborate(TestCase):
bva_3 << bvadd.bv32(bvlo_1, bvlo_2), bva_3 << bvadd.bv32(bvlo_1, bvlo_2),
bva_4 << bvadd.bv32(bvhi_1, bvhi_2), bva_4 << bvadd.bv32(bvhi_1, bvhi_2),
bvx_5 << bvconcat.bv32(bva_3, bva_4), bvx_5 << bvconcat.bv32(bva_3, bva_4),
a << prim_from_bv.i32x2.bv64(bvx_5)))) a << prim_from_bv.i32x2(bvx_5))))
def test_elaborate_iadd_elaborate_2(self): def test_elaborate_iadd_elaborate_2(self):
# type: () -> None # type: () -> None
@@ -334,4 +334,4 @@ class TestElaborate(TestCase):
bva_14 << bvadd.bv8(bvhi_11, bvhi_12), bva_14 << bvadd.bv8(bvhi_11, bvhi_12),
bvx_15 << bvconcat.bv8(bva_13, bva_14), bvx_15 << bvconcat.bv8(bva_13, bva_14),
bvx_5 << bvconcat.bv16(bvx_10, bvx_15), bvx_5 << bvconcat.bv16(bvx_10, bvx_15),
a << prim_from_bv.i8x4.bv32(bvx_5)))) a << prim_from_bv.i8x4(bvx_5))))

View File

@@ -153,8 +153,7 @@ class TestRuntimeChecks(TestCase):
def test_vselect_imm(self): def test_vselect_imm(self):
# type: () -> None # type: () -> None
ts = TypeSet(lanes=(2, 256), ints=(8, 64), ts = TypeSet(lanes=(2, 256), ints=(8, 64))
floats=(32, 64), bools=(8, 64))
r = Rtl( r = Rtl(
self.v0 << iconst(self.imm0), self.v0 << iconst(self.imm0),
self.v1 << icmp(intcc.eq, self.v2, self.v0), self.v1 << icmp(intcc.eq, self.v2, self.v0),
@@ -167,7 +166,7 @@ class TestRuntimeChecks(TestCase):
.format(self.v3.get_typevar().name) .format(self.v3.get_typevar().name)
self.check_yo_check( self.check_yo_check(
x, sequence(typeset_check(self.v3, ts), x, sequence(typeset_check(self.v2, ts),
equiv_check(tv2_exp, tv3_exp))) equiv_check(tv2_exp, tv3_exp)))
def test_reduce_extend(self): def test_reduce_extend(self):