Skip to content

Commit a8b6d6f

Browse files
authored
Fix inference of protocol against overloaded function (#12227)
We used to infer a callable in a protocol against all overload items. This could result in incorrect results, if only one of the overload items would actually match the protocol. Fix the issue by only considering the first matching overload item. This seems to help with protocols involving `__getitem__`. In particular, this fixes regressions related to `SupportsLenAndGetItem`, which is used for `random.choice`.
1 parent 2c9a8e7 commit a8b6d6f

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

mypy/constraints.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,12 @@ def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Co
658658
return res
659659

660660
def visit_overloaded(self, template: Overloaded) -> List[Constraint]:
661+
if isinstance(self.actual, CallableType):
662+
items = find_matching_overload_items(template, self.actual)
663+
else:
664+
items = template.items
661665
res: List[Constraint] = []
662-
for t in template.items:
666+
for t in items:
663667
res.extend(infer_constraints(t, self.actual, self.direction))
664668
return res
665669

@@ -701,3 +705,22 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
701705
# Fall back to the first item if we can't find a match. This is totally arbitrary --
702706
# maybe we should just bail out at this point.
703707
return items[0]
708+
709+
710+
def find_matching_overload_items(overloaded: Overloaded,
711+
template: CallableType) -> List[CallableType]:
712+
"""Like find_matching_overload_item, but return all matches, not just the first."""
713+
items = overloaded.items
714+
res = []
715+
for item in items:
716+
# Return type may be indeterminate in the template, so ignore it when performing a
717+
# subtype check.
718+
if mypy.subtypes.is_callable_compatible(item, template,
719+
is_compat=mypy.subtypes.is_subtype,
720+
ignore_return=True):
721+
res.append(item)
722+
if not res:
723+
# Falling back to all items if we can't find a match is pretty arbitrary, but
724+
# it maintains backward compatibility.
725+
res = items[:]
726+
return res

test-data/unit/check-protocols.test

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2806,3 +2806,86 @@ class MyClass:
28062806
assert isinstance(self, MyProtocol)
28072807
[builtins fixtures/isinstance.pyi]
28082808
[typing fixtures/typing-full.pyi]
2809+
2810+
[case testMatchProtocolAgainstOverloadWithAmbiguity]
2811+
from typing import TypeVar, Protocol, Union, Generic, overload
2812+
2813+
T = TypeVar("T", covariant=True)
2814+
2815+
class slice: pass
2816+
2817+
class GetItem(Protocol[T]):
2818+
def __getitem__(self, k: int) -> T: ...
2819+
2820+
class Str: # Resembles 'str'
2821+
def __getitem__(self, k: Union[int, slice]) -> Str: ...
2822+
2823+
class Lst(Generic[T]): # Resembles 'list'
2824+
def __init__(self, x: T): ...
2825+
@overload
2826+
def __getitem__(self, k: int) -> T: ...
2827+
@overload
2828+
def __getitem__(self, k: slice) -> Lst[T]: ...
2829+
def __getitem__(self, k): pass
2830+
2831+
def f(x: GetItem[GetItem[Str]]) -> None: ...
2832+
2833+
a: Lst[Str]
2834+
f(Lst(a))
2835+
2836+
class Lst2(Generic[T]):
2837+
def __init__(self, x: T): ...
2838+
# The overload items are tweaked but still compatible
2839+
@overload
2840+
def __getitem__(self, k: Str) -> None: ...
2841+
@overload
2842+
def __getitem__(self, k: slice) -> Lst2[T]: ...
2843+
@overload
2844+
def __getitem__(self, k: Union[int, str]) -> T: ...
2845+
def __getitem__(self, k): pass
2846+
2847+
b: Lst2[Str]
2848+
f(Lst2(b))
2849+
2850+
class Lst3(Generic[T]): # Resembles 'list'
2851+
def __init__(self, x: T): ...
2852+
# The overload items are no longer compatible (too narrow argument type)
2853+
@overload
2854+
def __getitem__(self, k: slice) -> Lst3[T]: ...
2855+
@overload
2856+
def __getitem__(self, k: bool) -> T: ...
2857+
def __getitem__(self, k): pass
2858+
2859+
c: Lst3[Str]
2860+
f(Lst3(c)) # E: Argument 1 to "f" has incompatible type "Lst3[Lst3[Str]]"; expected "GetItem[GetItem[Str]]" \
2861+
# N: Following member(s) of "Lst3[Lst3[Str]]" have conflicts: \
2862+
# N: Expected: \
2863+
# N: def __getitem__(self, int) -> GetItem[Str] \
2864+
# N: Got: \
2865+
# N: @overload \
2866+
# N: def __getitem__(self, slice) -> Lst3[Lst3[Str]] \
2867+
# N: @overload \
2868+
# N: def __getitem__(self, bool) -> Lst3[Str]
2869+
2870+
[builtins fixtures/list.pyi]
2871+
[typing fixtures/typing-full.pyi]
2872+
2873+
[case testMatchProtocolAgainstOverloadWithMultipleMatchingItems]
2874+
from typing import Protocol, overload, TypeVar, Any
2875+
2876+
_T_co = TypeVar("_T_co", covariant=True)
2877+
_T = TypeVar("_T")
2878+
2879+
class SupportsRound(Protocol[_T_co]):
2880+
@overload
2881+
def __round__(self) -> int: ...
2882+
@overload
2883+
def __round__(self, __ndigits: int) -> _T_co: ...
2884+
2885+
class C:
2886+
# This matches both overload items of SupportsRound
2887+
def __round__(self, __ndigits: int = ...) -> int: ...
2888+
2889+
def round(number: SupportsRound[_T], ndigits: int) -> _T: ...
2890+
2891+
round(C(), 1)

0 commit comments

Comments
 (0)