Skip to content

Commit fe473ab

Browse files
committed
Replace class/function distinction with plain/metavariable distinction
Fixes #603.
1 parent 8c0ca90 commit fe473ab

File tree

4 files changed

+62
-20
lines changed

4 files changed

+62
-20
lines changed

mypy/checkexpr.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.types import (
66
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
7-
TupleType, Instance, TypeVarType, ErasedType, UnionType,
7+
TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType,
88
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType
99
)
1010
from mypy.nodes import (
@@ -22,7 +22,7 @@
2222
import mypy.checker
2323
from mypy import types
2424
from mypy.sametypes import is_same_type
25-
from mypy.erasetype import replace_func_type_vars
25+
from mypy.erasetype import replace_meta_vars
2626
from mypy.messages import MessageBuilder
2727
from mypy import messages
2828
from mypy.infer import infer_type_arguments, infer_function_type_arguments
@@ -34,6 +34,7 @@
3434
from mypy.semanal import self_type
3535
from mypy.constraints import get_actual_type
3636
from mypy.checkstrformat import StringFormatterChecker
37+
from mypy.expandtype import expand_type
3738

3839
from mypy import experiments
3940

@@ -234,6 +235,7 @@ def check_call(self, callee: Type, args: List[Node],
234235
lambda i: self.accept(args[i]))
235236

236237
if callee.is_generic():
238+
callee = freshen_generic_callable(callee)
237239
callee = self.infer_function_type_arguments_using_context(
238240
callee, context)
239241
callee = self.infer_function_type_arguments(
@@ -394,12 +396,12 @@ def infer_function_type_arguments_using_context(
394396
ctx = self.chk.type_context[-1]
395397
if not ctx:
396398
return callable
397-
# The return type may have references to function type variables that
399+
# The return type may have references to type metavariables that
398400
# we are inferring right now. We must consider them as indeterminate
399401
# and they are not potential results; thus we replace them with the
400402
# special ErasedType type. On the other hand, class type variables are
401403
# valid results.
402-
erased_ctx = replace_func_type_vars(ctx, ErasedType())
404+
erased_ctx = replace_meta_vars(ctx, ErasedType())
403405
ret_type = callable.ret_type
404406
if isinstance(ret_type, TypeVarType):
405407
if ret_type.values or (not isinstance(ctx, Instance) or
@@ -1362,7 +1364,7 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> CallableType:
13621364
# they must be considered as indeterminate. We use ErasedType since it
13631365
# does not affect type inference results (it is for purposes like this
13641366
# only).
1365-
ctx = replace_func_type_vars(ctx, ErasedType())
1367+
ctx = replace_meta_vars(ctx, ErasedType())
13661368

13671369
callable_ctx = cast(CallableType, ctx)
13681370

@@ -1779,3 +1781,14 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int:
17791781
return 2
17801782
# Fall back to a conservative equality check for the remaining kinds of type.
17811783
return 2 if is_same_type(erasetype.erase_type(actual), erasetype.erase_type(formal)) else 0
1784+
1785+
1786+
def freshen_generic_callable(callee: CallableType) -> CallableType:
1787+
tvdefs = []
1788+
tvmap = {} # type: Dict[TypeVarId, Type]
1789+
for v in callee.variables:
1790+
tvdef = TypeVarDef.new_unification_variable(v)
1791+
tvdefs.append(tvdef)
1792+
tvmap[v.id] = TypeVarType(tvdef)
1793+
1794+
return cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvdefs)

mypy/erasetype.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def erase_id(id: TypeVarId) -> bool:
116116
return t.accept(TypeVarEraser(erase_id, AnyType()))
117117

118118

119-
def replace_func_type_vars(t: Type, target_type: Type) -> Type:
120-
"""Replace function type variables in a type with the target type."""
121-
return t.accept(TypeVarEraser(lambda id: id.is_func_var(), target_type))
119+
def replace_meta_vars(t: Type, target_type: Type) -> Type:
120+
"""Replace unification variables in a type with the target type."""
121+
return t.accept(TypeVarEraser(lambda id: id.is_meta_var(), target_type))
122122

123123

124124
class TypeVarEraser(TypeTranslator):

mypy/types.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import abstractmethod
44
from typing import (
5-
Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional
5+
Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union
66
)
77

88
import mypy.nodes
@@ -49,32 +49,43 @@ def deserialize(cls, data: JsonDict) -> 'Type':
4949

5050
class TypeVarId:
5151
# 1, 2, ... for type-related, -1, ... for function-related
52-
5352
raw_id = 0 # type: int
5453

55-
def __init__(self, raw_id: int) -> None:
54+
# Level of the variable in type inference. Currently either 0 for
55+
# declared types, or 1 for type inference unification variables.
56+
meta_level = 0 # type: int
57+
58+
# Used for allocating fresh ids
59+
next_raw_id = 1 # type: int
60+
61+
def __init__(self, raw_id: int, meta_level: int = 0) -> None:
5662
self.raw_id = raw_id
63+
self.meta_level = meta_level
64+
65+
@staticmethod
66+
def new(meta_level: int) -> 'TypeVarId':
67+
raw_id = TypeVarId.next_raw_id
68+
TypeVarId.next_raw_id += 1
69+
return TypeVarId(raw_id, meta_level)
5770

5871
def __repr__(self) -> str:
5972
return self.raw_id.__repr__()
6073

6174
def __eq__(self, other: object) -> bool:
6275
if isinstance(other, TypeVarId):
63-
return self.raw_id == other.raw_id
76+
return (self.raw_id == other.raw_id and
77+
self.meta_level == other.meta_level)
6478
else:
6579
return False
6680

6781
def __ne__(self, other: object) -> bool:
6882
return not (self == other)
6983

7084
def __hash__(self) -> int:
71-
return hash(self.raw_id)
72-
73-
def is_class_var(self) -> bool:
74-
return self.raw_id > 0
85+
return hash((self.raw_id, self.meta_level))
7586

76-
def is_func_var(self) -> bool:
77-
return self.raw_id < 0
87+
def is_meta_var(self) -> bool:
88+
return self.meta_level > 0
7889

7990

8091
class TypeVarDef(mypy.nodes.Context):
@@ -87,15 +98,23 @@ class TypeVarDef(mypy.nodes.Context):
8798
variance = INVARIANT # type: int
8899
line = 0
89100

90-
def __init__(self, name: str, raw_id: int, values: Optional[List[Type]],
101+
def __init__(self, name: str, id: Union[TypeVarId, int], values: Optional[List[Type]],
91102
upper_bound: Type, variance: int = INVARIANT, line: int = -1) -> None:
92103
self.name = name
93-
self.id = TypeVarId(raw_id)
104+
if isinstance(id, int):
105+
id = TypeVarId(id)
106+
self.id = id
94107
self.values = values
95108
self.upper_bound = upper_bound
96109
self.variance = variance
97110
self.line = line
98111

112+
@staticmethod
113+
def new_unification_variable(old: 'TypeVarDef') -> 'TypeVarDef':
114+
new_id = TypeVarId.new(meta_level=1)
115+
return TypeVarDef(old.name, new_id, old.values,
116+
old.upper_bound, old.variance, old.line)
117+
99118
def get_line(self) -> int:
100119
return self.line
101120

@@ -108,6 +127,7 @@ def __repr__(self) -> str:
108127
return self.name
109128

110129
def serialize(self) -> JsonDict:
130+
assert not self.id.is_meta_var()
111131
return {'.class': 'TypeVarDef',
112132
'name': self.name,
113133
'id': self.id.raw_id,
@@ -424,6 +444,7 @@ def erase_to_union_or_bound(self) -> Type:
424444
return self.upper_bound
425445

426446
def serialize(self) -> JsonDict:
447+
assert not self.id.is_meta_var()
427448
return {'.class': 'TypeVarType',
428449
'name': self.name,
429450
'id': self.id.raw_id,

test-data/unit/check-inference-context.test

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,3 +748,11 @@ def f2(iterable: Iterable[Tuple[str, Any]], **kw: Any) -> None:
748748
pass
749749
[builtins fixtures/dict.py]
750750
[out]
751+
752+
[case testInferenceInGenericFunction]
753+
from typing import TypeVar, List
754+
T = TypeVar('T')
755+
def f(a: T) -> None:
756+
l = [] # type: List[T]
757+
l.append(a)
758+
[builtins fixtures/list.py]

0 commit comments

Comments
 (0)