Skip to content

Commit b875205

Browse files
ilevkivskyiJukkaL
authored andcommitted
Narrow types after 'in' operator (#4072)
Fixes #4071.
1 parent 861f8fa commit b875205

File tree

8 files changed

+260
-1
lines changed

8 files changed

+260
-1
lines changed

mypy/checker.py

+47
Original file line numberDiff line numberDiff line change
@@ -2953,6 +2953,39 @@ def remove_optional(typ: Type) -> Type:
29532953
return typ
29542954

29552955

2956+
def builtin_item_type(tp: Type) -> Optional[Type]:
2957+
"""Get the item type of a builtin container.
2958+
2959+
If 'tp' is not one of the built containers (these includes NamedTuple and TypedDict)
2960+
or if the container is not parameterized (like List or List[Any])
2961+
return None. This function is used to narrow optional types in situations like this:
2962+
2963+
x: Optional[int]
2964+
if x in (1, 2, 3):
2965+
x + 42 # OK
2966+
2967+
Note: this is only OK for built-in containers, where we know the behavior
2968+
of __contains__.
2969+
"""
2970+
if isinstance(tp, Instance):
2971+
if tp.type.fullname() in ['builtins.list', 'builtins.tuple', 'builtins.dict',
2972+
'builtins.set', 'builtins.frozenset']:
2973+
if not tp.args:
2974+
# TODO: fix tuple in lib-stub/builtins.pyi (it should be generic).
2975+
return None
2976+
if not isinstance(tp.args[0], AnyType):
2977+
return tp.args[0]
2978+
elif isinstance(tp, TupleType) and all(not isinstance(it, AnyType) for it in tp.items):
2979+
return UnionType.make_simplified_union(tp.items) # this type is not externally visible
2980+
elif isinstance(tp, TypedDictType):
2981+
# TypedDict always has non-optional string keys.
2982+
if tp.fallback.type.fullname() == 'typing.Mapping':
2983+
return tp.fallback.args[0]
2984+
elif tp.fallback.type.bases[0].type.fullname() == 'typing.Mapping':
2985+
return tp.fallback.type.bases[0].args[0]
2986+
return None
2987+
2988+
29562989
def and_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
29572990
"""Calculate what information we can learn from the truth of (e1 and e2)
29582991
in terms of the information that we can learn from the truth of e1 and
@@ -3099,6 +3132,20 @@ def find_isinstance_check(node: Expression,
30993132
optional_expr = node.operands[1]
31003133
if is_overlapping_types(optional_type, comp_type):
31013134
return {optional_expr: remove_optional(optional_type)}, {}
3135+
elif node.operators in [['in'], ['not in']]:
3136+
expr = node.operands[0]
3137+
left_type = type_map[expr]
3138+
right_type = builtin_item_type(type_map[node.operands[1]])
3139+
right_ok = right_type and (not is_optional(right_type) and
3140+
(not isinstance(right_type, Instance) or
3141+
right_type.type.fullname() != 'builtins.object'))
3142+
if (right_type and right_ok and is_optional(left_type) and
3143+
literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and
3144+
is_overlapping_types(left_type, right_type)):
3145+
if node.operators == ['in']:
3146+
return {expr: remove_optional(left_type)}, {}
3147+
if node.operators == ['not in']:
3148+
return {}, {expr: remove_optional(left_type)}
31023149
elif isinstance(node, RefExpr):
31033150
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
31043151
# respectively

test-data/unit/check-isinstance.test

+206-1
Original file line numberDiff line numberDiff line change
@@ -1757,7 +1757,6 @@ if isinstance(x, str, 1): # E: Too many arguments for "isinstance"
17571757
reveal_type(x) # E: Revealed type is 'builtins.int'
17581758
[builtins fixtures/isinstancelist.pyi]
17591759

1760-
17611760
[case testIsinstanceNarrowAny]
17621761
from typing import Any
17631762

@@ -1770,3 +1769,209 @@ def narrow_any_to_str_then_reassign_to_int() -> None:
17701769
reveal_type(v) # E: Revealed type is 'Any'
17711770

17721771
[builtins fixtures/isinstance.pyi]
1772+
1773+
[case testNarrowTypeAfterInList]
1774+
# flags: --strict-optional
1775+
from typing import List, Optional
1776+
1777+
x: List[int]
1778+
y: Optional[int]
1779+
1780+
if y in x:
1781+
reveal_type(y) # E: Revealed type is 'builtins.int'
1782+
else:
1783+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1784+
if y not in x:
1785+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1786+
else:
1787+
reveal_type(y) # E: Revealed type is 'builtins.int'
1788+
[builtins fixtures/list.pyi]
1789+
[out]
1790+
1791+
[case testNarrowTypeAfterInListOfOptional]
1792+
# flags: --strict-optional
1793+
from typing import List, Optional
1794+
1795+
x: List[Optional[int]]
1796+
y: Optional[int]
1797+
1798+
if y not in x:
1799+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1800+
else:
1801+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1802+
[builtins fixtures/list.pyi]
1803+
[out]
1804+
1805+
[case testNarrowTypeAfterInListNonOverlapping]
1806+
# flags: --strict-optional
1807+
from typing import List, Optional
1808+
1809+
x: List[str]
1810+
y: Optional[int]
1811+
1812+
if y in x:
1813+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1814+
else:
1815+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1816+
[builtins fixtures/list.pyi]
1817+
[out]
1818+
1819+
[case testNarrowTypeAfterInListNested]
1820+
# flags: --strict-optional
1821+
from typing import List, Optional, Any
1822+
1823+
x: Optional[int]
1824+
lst: Optional[List[int]]
1825+
nested_any: List[List[Any]]
1826+
1827+
if lst in nested_any:
1828+
reveal_type(lst) # E: Revealed type is 'builtins.list[builtins.int]'
1829+
if x in nested_any:
1830+
reveal_type(x) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1831+
[builtins fixtures/list.pyi]
1832+
[out]
1833+
1834+
[case testNarrowTypeAfterInTuple]
1835+
# flags: --strict-optional
1836+
from typing import Optional
1837+
class A: pass
1838+
class B(A): pass
1839+
class C(A): pass
1840+
1841+
y: Optional[B]
1842+
if y in (B(), C()):
1843+
reveal_type(y) # E: Revealed type is '__main__.B'
1844+
else:
1845+
reveal_type(y) # E: Revealed type is 'Union[__main__.B, builtins.None]'
1846+
[builtins fixtures/tuple.pyi]
1847+
[out]
1848+
1849+
[case testNarrowTypeAfterInNamedTuple]
1850+
# flags: --strict-optional
1851+
from typing import NamedTuple, Optional
1852+
class NT(NamedTuple):
1853+
x: int
1854+
y: int
1855+
nt: NT
1856+
1857+
y: Optional[int]
1858+
if y not in nt:
1859+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1860+
else:
1861+
reveal_type(y) # E: Revealed type is 'builtins.int'
1862+
[builtins fixtures/tuple.pyi]
1863+
[out]
1864+
1865+
[case testNarrowTypeAfterInDict]
1866+
# flags: --strict-optional
1867+
from typing import Dict, Optional
1868+
x: Dict[str, int]
1869+
y: Optional[str]
1870+
1871+
if y in x:
1872+
reveal_type(y) # E: Revealed type is 'builtins.str'
1873+
else:
1874+
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
1875+
if y not in x:
1876+
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
1877+
else:
1878+
reveal_type(y) # E: Revealed type is 'builtins.str'
1879+
[builtins fixtures/dict.pyi]
1880+
[out]
1881+
1882+
[case testNarrowTypeAfterInList_python2]
1883+
# flags: --strict-optional
1884+
from typing import List, Optional
1885+
1886+
x = [] # type: List[int]
1887+
y = None # type: Optional[int]
1888+
1889+
# TODO: Fix running tests on Python 2: "Iterator[int]" has no attribute "next"
1890+
if y in x: # type: ignore
1891+
reveal_type(y) # E: Revealed type is 'builtins.int'
1892+
else:
1893+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1894+
if y not in x: # type: ignore
1895+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1896+
else:
1897+
reveal_type(y) # E: Revealed type is 'builtins.int'
1898+
1899+
[builtins_py2 fixtures/python2.pyi]
1900+
[out]
1901+
1902+
[case testNarrowTypeAfterInNoAnyOrObject]
1903+
# flags: --strict-optional
1904+
from typing import Any, List, Optional
1905+
x: List[Any]
1906+
z: List[object]
1907+
1908+
y: Optional[int]
1909+
if y in x:
1910+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1911+
else:
1912+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1913+
1914+
if y not in z:
1915+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1916+
else:
1917+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1918+
[typing fixtures/typing-full.pyi]
1919+
[builtins fixtures/list.pyi]
1920+
[out]
1921+
1922+
[case testNarrowTypeAfterInUserDefined]
1923+
# flags: --strict-optional
1924+
from typing import Container, Optional
1925+
1926+
class C(Container[int]):
1927+
def __contains__(self, item: object) -> bool:
1928+
return item is 'surprise'
1929+
1930+
y: Optional[int]
1931+
# We never trust user defined types
1932+
if y in C():
1933+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1934+
else:
1935+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1936+
if y not in C():
1937+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1938+
else:
1939+
reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]'
1940+
[typing fixtures/typing-full.pyi]
1941+
[builtins fixtures/list.pyi]
1942+
[out]
1943+
1944+
[case testNarrowTypeAfterInSet]
1945+
# flags: --strict-optional
1946+
from typing import Optional, Set
1947+
s: Set[str]
1948+
1949+
y: Optional[str]
1950+
if y in {'a', 'b', 'c'}:
1951+
reveal_type(y) # E: Revealed type is 'builtins.str'
1952+
else:
1953+
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
1954+
if y not in s:
1955+
reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]'
1956+
else:
1957+
reveal_type(y) # E: Revealed type is 'builtins.str'
1958+
[builtins fixtures/set.pyi]
1959+
[out]
1960+
1961+
[case testNarrowTypeAfterInTypedDict]
1962+
# flags: --strict-optional
1963+
from typing import Optional
1964+
from mypy_extensions import TypedDict
1965+
class TD(TypedDict):
1966+
a: int
1967+
b: str
1968+
td: TD
1969+
1970+
def f() -> None:
1971+
x: Optional[str]
1972+
if x not in td:
1973+
return
1974+
reveal_type(x) # E: Revealed type is 'builtins.str'
1975+
[typing fixtures/typing-full.pyi]
1976+
[builtins fixtures/dict.pyi]
1977+
[out]

test-data/unit/fixtures/dict.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class dict(Generic[KT, VT]):
1919
def __getitem__(self, key: KT) -> VT: pass
2020
def __setitem__(self, k: KT, v: VT) -> None: pass
2121
def __iter__(self) -> Iterator[KT]: pass
22+
def __contains__(self, item: object) -> bool: pass
2223
def update(self, a: Mapping[KT, VT]) -> None: pass
2324
@overload
2425
def get(self, k: KT) -> Optional[VT]: pass

test-data/unit/fixtures/list.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class list(Generic[T]):
1616
@overload
1717
def __init__(self, x: Iterable[T]) -> None: pass
1818
def __iter__(self) -> Iterator[T]: pass
19+
def __contains__(self, item: object) -> bool: pass
1920
def __add__(self, x: list[T]) -> list[T]: pass
2021
def __mul__(self, x: int) -> list[T]: pass
2122
def __getitem__(self, x: int) -> T: pass

test-data/unit/fixtures/python2.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class function: pass
1111
class int: pass
1212
class str: pass
1313
class unicode: pass
14+
class bool: pass
1415

1516
T = TypeVar('T')
1617
class list(Iterable[T], Generic[T]): pass

test-data/unit/fixtures/set.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ class function: pass
1313

1414
class int: pass
1515
class str: pass
16+
class bool: pass
1617

1718
class set(Iterable[T], Generic[T]):
1819
def __iter__(self) -> Iterator[T]: pass
20+
def __contains__(self, item: object) -> bool: pass
1921
def add(self, x: T) -> None: pass
2022
def discard(self, x: T) -> None: pass
2123
def update(self, x: Set[T]) -> None: pass

test-data/unit/fixtures/tuple.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class type:
1212
def __call__(self, *a) -> object: pass
1313
class tuple(Sequence[Tco], Generic[Tco]):
1414
def __iter__(self) -> Iterator[Tco]: pass
15+
def __contains__(self, item: object) -> bool: pass
1516
def __getitem__(self, x: int) -> Tco: pass
1617
def count(self, obj: Any) -> int: pass
1718
class function: pass

test-data/unit/fixtures/typing-full.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class Mapping(Iterable[T], Protocol[T, T_co]):
126126
def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass
127127
def values(self) -> Iterable[T_co]: pass # Approximate return type
128128
def __len__(self) -> int: ...
129+
def __contains__(self, arg: object) -> int: pass
129130

130131
@runtime
131132
class MutableMapping(Mapping[T, U], Protocol):

0 commit comments

Comments
 (0)