From 8d88223f48eef97147bb991deb0b90bfcbc80b7a Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 26 Jun 2017 11:42:33 +0100 Subject: [PATCH] Support overloading with TypedDict Fix #3609. --- mypy/checkexpr.py | 6 ++ mypy/meet.py | 4 + test-data/unit/check-typeddict.test | 119 ++++++++++++++++++++++++ test-data/unit/fixtures/dict.pyi | 2 +- test-data/unit/fixtures/typing-full.pyi | 4 +- 5 files changed, 132 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 00c95cb2e6d4..92fab7d92a50 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2674,6 +2674,12 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int: return overload_arg_similarity(actual.ret_type, formal.item) else: return 0 + if isinstance(actual, TypedDictType): + if isinstance(formal, TypedDictType): + # Don't support overloading based on the keys or value types of a TypedDict since + # that would be complicated and probably only marginally useful. + return 2 + return overload_arg_similarity(actual.fallback, formal) if isinstance(formal, Instance): if isinstance(actual, CallableType): actual = actual.fallback diff --git a/mypy/meet.py b/mypy/meet.py index f0dcd8b56e34..2f7cbb482ce3 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -84,6 +84,10 @@ class C(A, B): ... t = t.erase_to_union_or_bound() if isinstance(s, TypeVarType): s = s.erase_to_union_or_bound() + if isinstance(t, TypedDictType): + t = t.as_anonymous().fallback + if isinstance(s, TypedDictType): + s = s.as_anonymous().fallback if isinstance(t, Instance): if isinstance(s, Instance): # Consider two classes non-disjoint if one is included in the mro diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index d50cf27344cc..bcf06f754d80 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -1032,6 +1032,125 @@ Point = TypedDict('Point', {'x': 1, 'y': 1}) # E: Invalid field type [builtins fixtures/dict.pyi] +-- Overloading + +[case testTypedDictOverloading] +from typing import overload, Iterable +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int}) + +@overload +def f(x: Iterable[str]) -> str: ... +@overload +def f(x: int) -> int: ... +def f(x): pass + +a: A +reveal_type(f(a)) # E: Revealed type is 'builtins.str' +reveal_type(f(1)) # E: Revealed type is 'builtins.int' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictOverloading2] +from typing import overload, Iterable +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int}) + +@overload +def f(x: Iterable[int]) -> None: ... +@overload +def f(x: int) -> None: ... +def f(x): pass + +a: A +f(a) # E: Argument 1 to "f" has incompatible type "A"; expected Iterable[int] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictOverloading3] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int}) + +@overload +def f(x: str) -> None: ... +@overload +def f(x: int) -> None: ... +def f(x): pass + +a: A +f(a) # E: No overload variant of "f" matches argument types [TypedDict(x=builtins.int, _fallback=__main__.A)] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictOverloading4] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int}) +B = TypedDict('B', {'x': str}) + +@overload +def f(x: A) -> int: ... +@overload +def f(x: int) -> str: ... +def f(x): pass + +a: A +b: B +reveal_type(f(a)) # E: Revealed type is 'builtins.int' +reveal_type(f(1)) # E: Revealed type is 'builtins.str' +f(b) # E: Argument 1 to "f" has incompatible type "B"; expected "A" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictOverloading5] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int}) +B = TypedDict('B', {'y': str}) +C = TypedDict('C', {'y': int}) + +@overload +def f(x: A) -> None: ... +@overload +def f(x: B) -> None: ... +def f(x): pass + +a: A +b: B +c: C +f(a) +f(b) +f(c) # E: Argument 1 to "f" has incompatible type "C"; expected "A" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictOverloading6] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int}) +B = TypedDict('B', {'y': str}) + +@overload +def f(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: B) -> str: ... +def f(x): pass + +a: A +b: B +reveal_type(f(a)) # E: Revealed type is 'Any' +reveal_type(f(b)) # E: Revealed type is 'Any' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + + -- Special cases [case testForwardReferenceInTypedDict] diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index e920512274dd..62234c69d0b2 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -11,7 +11,7 @@ class object: class type: pass -class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]): +class dict(Mapping[KT, VT], Iterable[KT], Generic[KT, VT]): @overload def __init__(self, **kwargs: VT) -> None: pass @overload diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 463b117db48d..cb632b3de304 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -103,12 +103,12 @@ class Sequence(Iterable[T], Generic[T]): @abstractmethod def __getitem__(self, n: Any) -> T: pass -class Mapping(Generic[T, U]): +class Mapping(Iterable[T], Generic[T, U]): @overload def get(self, k: T) -> Optional[U]: ... @overload def get(self, k: T, default: Union[U, V]) -> Union[U, V]: ... -class MutableMapping(Generic[T, U]): pass +class MutableMapping(Mapping[T, U]): pass TYPE_CHECKING = 1