diff --git a/lib/cretonne/meta/cdsl/ti.py b/lib/cretonne/meta/cdsl/ti.py index 7a95daf425..028579857d 100644 --- a/lib/cretonne/meta/cdsl/ti.py +++ b/lib/cretonne/meta/cdsl/ti.py @@ -8,12 +8,13 @@ from itertools import product try: from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa - from typing import Iterable, List, Any # noqa + from typing import Iterable, List, Any, TypeVar as MTypeVar # noqa from typing import cast from .xform import Rtl, XForm # noqa from .ast import Expr # noqa from .typevar import TypeSet # noqa if TYPE_CHECKING: + T = MTypeVar('T') TypeMap = Dict[TypeVar, TypeVar] VarTyping = Dict[Var, TypeVar] except ImportError: @@ -775,6 +776,11 @@ def unify(tv1, tv2, typ): return typ +def move_first(l, i): + # type: (List[T], int) -> List[T] + return [l[i]] + l[:i] + l[i+1:] + + def ti_def(definition, typ): # type: (Def, TypeEnv) -> TypingOrError """ @@ -821,6 +827,12 @@ def ti_def(definition, typ): typ.register(v) actual_tvs.append(v.get_typevar()) + # Make sure we unify the control typevar first. + if inst.is_polymorphic: + idx = fresh_formal_tvs.index(m[inst.ctrl_typevar]) + fresh_formal_tvs = move_first(fresh_formal_tvs, idx) + actual_tvs = move_first(actual_tvs, idx) + # Unify each actual typevar with the correpsonding fresh formal tv for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs): typ_or_err = unify(actual_tv, formal_tv, typ)