When doing ti on a polymorphic definition first unify the control variable, then the rest.
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
2387745847
commit
7d1a9c7d81
@@ -8,12 +8,13 @@ from itertools import product
|
||||
|
||||
try:
|
||||
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
||||
from typing import Iterable, List, Any # noqa
|
||||
from typing import Iterable, List, Any, TypeVar as MTypeVar # noqa
|
||||
from typing import cast
|
||||
from .xform import Rtl, XForm # noqa
|
||||
from .ast import Expr # noqa
|
||||
from .typevar import TypeSet # noqa
|
||||
if TYPE_CHECKING:
|
||||
T = MTypeVar('T')
|
||||
TypeMap = Dict[TypeVar, TypeVar]
|
||||
VarTyping = Dict[Var, TypeVar]
|
||||
except ImportError:
|
||||
@@ -775,6 +776,11 @@ def unify(tv1, tv2, typ):
|
||||
return typ
|
||||
|
||||
|
||||
def move_first(l, i):
|
||||
# type: (List[T], int) -> List[T]
|
||||
return [l[i]] + l[:i] + l[i+1:]
|
||||
|
||||
|
||||
def ti_def(definition, typ):
|
||||
# type: (Def, TypeEnv) -> TypingOrError
|
||||
"""
|
||||
@@ -821,6 +827,12 @@ def ti_def(definition, typ):
|
||||
typ.register(v)
|
||||
actual_tvs.append(v.get_typevar())
|
||||
|
||||
# Make sure we unify the control typevar first.
|
||||
if inst.is_polymorphic:
|
||||
idx = fresh_formal_tvs.index(m[inst.ctrl_typevar])
|
||||
fresh_formal_tvs = move_first(fresh_formal_tvs, idx)
|
||||
actual_tvs = move_first(actual_tvs, idx)
|
||||
|
||||
# Unify each actual typevar with the correpsonding fresh formal tv
|
||||
for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
|
||||
typ_or_err = unify(actual_tv, formal_tv, typ)
|
||||
|
||||
Reference in New Issue
Block a user