Skip to content

Commit 648d99a

Browse files
authored
Refine parent type when narrowing "lookup" expressions (#7917)
This diff adds support for the following pattern: ```python from typing import Enum, List from enum import Enum class Key(Enum): A = 1 B = 2 class Foo: key: Literal[Key.A] blah: List[int] class Bar: key: Literal[Key.B] something: List[str] x: Union[Foo, Bar] if x.key is Key.A: reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` In short, when we do `x.key is Key.A`, we "propagate" the information we discovered about `x.key` up to refine the type of `x`. We perform this propagation only when `x` is a Union and only when we are doing member or index lookups into instances, typeddicts, namedtuples, and tuples. For indexing operations, we have one additional limitation: we *must* use a literal expression in order for narrowing to work at all. Using Literal types or Final instances won't work; See #7905 for more details. To put it another way, this adds support for tagged unions, I guess. This more or less resolves #7344. We currently don't have support for narrowing based on string or int literals, but that's a separate issue and should be resolved by #7169 (which I resumed work on earlier this week).
1 parent a37ab53 commit 648d99a

File tree

6 files changed

+708
-23
lines changed

6 files changed

+708
-23
lines changed

mypy/checker.py

Lines changed: 186 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from contextlib import contextmanager
66

77
from typing import (
8-
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Sequence
8+
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Sequence,
9+
Mapping,
910
)
1011
from typing_extensions import Final
1112

@@ -42,12 +43,14 @@
4243
)
4344
import mypy.checkexpr
4445
from mypy.checkmember import (
45-
analyze_descriptor_access, type_object_type,
46+
analyze_member_access, analyze_descriptor_access, type_object_type,
4647
)
4748
from mypy.typeops import (
4849
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
49-
erase_def_to_union_or_bound, erase_to_union_or_bound,
50-
true_only, false_only, function_type, TypeVarExtractor
50+
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
51+
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
52+
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
53+
true_only, false_only, function_type, TypeVarExtractor,
5154
)
5255
from mypy import message_registry
5356
from mypy.subtypes import (
@@ -71,9 +74,6 @@
7174
from mypy.plugin import Plugin, CheckerPluginInterface
7275
from mypy.sharedparse import BINARY_MAGIC_METHODS
7376
from mypy.scope import Scope
74-
from mypy.typeops import (
75-
tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union
76-
)
7777
from mypy import state, errorcodes as codes
7878
from mypy.traverser import has_return_statement, all_return_statements
7979
from mypy.errorcodes import ErrorCode
@@ -3708,6 +3708,12 @@ def find_isinstance_check(self, node: Expression
37083708
37093709
Guaranteed to not return None, None. (But may return {}, {})
37103710
"""
3711+
if_map, else_map = self.find_isinstance_check_helper(node)
3712+
new_if_map = self.propagate_up_typemap_info(self.type_map, if_map)
3713+
new_else_map = self.propagate_up_typemap_info(self.type_map, else_map)
3714+
return new_if_map, new_else_map
3715+
3716+
def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]:
37113717
type_map = self.type_map
37123718
if is_true_literal(node):
37133719
return {}, None
@@ -3834,28 +3840,196 @@ def find_isinstance_check(self, node: Expression
38343840
else None)
38353841
return if_map, else_map
38363842
elif isinstance(node, OpExpr) and node.op == 'and':
3837-
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
3838-
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
3843+
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
3844+
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
38393845

38403846
# (e1 and e2) is true if both e1 and e2 are true,
38413847
# and false if at least one of e1 and e2 is false.
38423848
return (and_conditional_maps(left_if_vars, right_if_vars),
38433849
or_conditional_maps(left_else_vars, right_else_vars))
38443850
elif isinstance(node, OpExpr) and node.op == 'or':
3845-
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
3846-
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
3851+
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
3852+
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
38473853

38483854
# (e1 or e2) is true if at least one of e1 or e2 is true,
38493855
# and false if both e1 and e2 are false.
38503856
return (or_conditional_maps(left_if_vars, right_if_vars),
38513857
and_conditional_maps(left_else_vars, right_else_vars))
38523858
elif isinstance(node, UnaryExpr) and node.op == 'not':
3853-
left, right = self.find_isinstance_check(node.expr)
3859+
left, right = self.find_isinstance_check_helper(node.expr)
38543860
return right, left
38553861

38563862
# Not a supported isinstance check
38573863
return {}, {}
38583864

3865+
def propagate_up_typemap_info(self,
3866+
existing_types: Mapping[Expression, Type],
3867+
new_types: TypeMap) -> TypeMap:
3868+
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
3869+
3870+
Specifically, this function accepts two mappings of expression to original types:
3871+
the original mapping (existing_types), and a new mapping (new_types) intended to
3872+
update the original.
3873+
3874+
This function iterates through new_types and attempts to use the information to try
3875+
refining any parent types that happen to be unions.
3876+
3877+
For example, suppose there are two types "A = Tuple[int, int]" and "B = Tuple[str, str]".
3878+
Next, suppose that 'new_types' specifies the expression 'foo[0]' has a refined type
3879+
of 'int' and that 'foo' was previously deduced to be of type Union[A, B].
3880+
3881+
Then, this function will observe that since A[0] is an int and B[0] is not, the type of
3882+
'foo' can be further refined from Union[A, B] into just B.
3883+
3884+
We perform this kind of "parent narrowing" for member lookup expressions and indexing
3885+
expressions into tuples, namedtuples, and typeddicts. We repeat this narrowing
3886+
recursively if the parent is also a "lookup expression". So for example, if we have
3887+
the expression "foo['bar'].baz[0]", we'd potentially end up refining types for the
3888+
expressions "foo", "foo['bar']", and "foo['bar'].baz".
3889+
3890+
We return the newly refined map. This map is guaranteed to be a superset of 'new_types'.
3891+
"""
3892+
if new_types is None:
3893+
return None
3894+
output_map = {}
3895+
for expr, expr_type in new_types.items():
3896+
# The original inferred type should always be present in the output map, of course
3897+
output_map[expr] = expr_type
3898+
3899+
# Next, try using this information to refine the parent types, if applicable.
3900+
new_mapping = self.refine_parent_types(existing_types, expr, expr_type)
3901+
for parent_expr, proposed_parent_type in new_mapping.items():
3902+
# We don't try inferring anything if we've already inferred something for
3903+
# the parent expression.
3904+
# TODO: Consider picking the narrower type instead of always discarding this?
3905+
if parent_expr in new_types:
3906+
continue
3907+
output_map[parent_expr] = proposed_parent_type
3908+
return output_map
3909+
3910+
def refine_parent_types(self,
3911+
existing_types: Mapping[Expression, Type],
3912+
expr: Expression,
3913+
expr_type: Type) -> Mapping[Expression, Type]:
3914+
"""Checks if the given expr is a 'lookup operation' into a union and iteratively refines
3915+
the parent types based on the 'expr_type'.
3916+
3917+
For example, if 'expr' is an expression like 'a.b.c.d', we'll potentially return refined
3918+
types for expressions 'a', 'a.b', and 'a.b.c'.
3919+
3920+
For more details about what a 'lookup operation' is and how we use the expr_type to refine
3921+
the parent types of lookup_expr, see the docstring in 'propagate_up_typemap_info'.
3922+
"""
3923+
output = {} # type: Dict[Expression, Type]
3924+
3925+
# Note: parent_expr and parent_type are progressively refined as we crawl up the
3926+
# parent lookup chain.
3927+
while True:
3928+
# First, check if this expression is one that's attempting to
3929+
# "lookup" some key in the parent type. If so, save the parent type
3930+
# and create function that will try replaying the same lookup
3931+
# operation against arbitrary types.
3932+
if isinstance(expr, MemberExpr):
3933+
parent_expr = expr.expr
3934+
parent_type = existing_types.get(parent_expr)
3935+
member_name = expr.name
3936+
3937+
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
3938+
msg_copy = self.msg.clean_copy()
3939+
msg_copy.disable_count = 0
3940+
member_type = analyze_member_access(
3941+
name=member_name,
3942+
typ=new_parent_type,
3943+
context=parent_expr,
3944+
is_lvalue=False,
3945+
is_super=False,
3946+
is_operator=False,
3947+
msg=msg_copy,
3948+
original_type=new_parent_type,
3949+
chk=self,
3950+
in_literal_context=False,
3951+
)
3952+
if msg_copy.is_errors():
3953+
return None
3954+
else:
3955+
return member_type
3956+
elif isinstance(expr, IndexExpr):
3957+
parent_expr = expr.base
3958+
parent_type = existing_types.get(parent_expr)
3959+
3960+
index_type = existing_types.get(expr.index)
3961+
if index_type is None:
3962+
return output
3963+
3964+
str_literals = try_getting_str_literals_from_type(index_type)
3965+
if str_literals is not None:
3966+
# Refactoring these two indexing replay functions is surprisingly
3967+
# tricky -- see https://github.com/python/mypy/pull/7917, which
3968+
# was blocked by https://github.com/mypyc/mypyc/issues/586
3969+
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
3970+
if not isinstance(new_parent_type, TypedDictType):
3971+
return None
3972+
try:
3973+
assert str_literals is not None
3974+
member_types = [new_parent_type.items[key] for key in str_literals]
3975+
except KeyError:
3976+
return None
3977+
return make_simplified_union(member_types)
3978+
else:
3979+
int_literals = try_getting_int_literals_from_type(index_type)
3980+
if int_literals is not None:
3981+
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
3982+
if not isinstance(new_parent_type, TupleType):
3983+
return None
3984+
try:
3985+
assert int_literals is not None
3986+
member_types = [new_parent_type.items[key] for key in int_literals]
3987+
except IndexError:
3988+
return None
3989+
return make_simplified_union(member_types)
3990+
else:
3991+
return output
3992+
else:
3993+
return output
3994+
3995+
# If we somehow didn't previously derive the parent type, abort completely
3996+
# with what we have so far: something went wrong at an earlier stage.
3997+
if parent_type is None:
3998+
return output
3999+
4000+
# We currently only try refining the parent type if it's a Union.
4001+
# If not, there's no point in trying to refine any further parents
4002+
# since we have no further information we can use to refine the lookup
4003+
# chain, so we end early as an optimization.
4004+
parent_type = get_proper_type(parent_type)
4005+
if not isinstance(parent_type, UnionType):
4006+
return output
4007+
4008+
# Take each element in the parent union and replay the original lookup procedure
4009+
# to figure out which parents are compatible.
4010+
new_parent_types = []
4011+
for item in parent_type.items:
4012+
item = get_proper_type(item)
4013+
member_type = replay_lookup(item)
4014+
if member_type is None:
4015+
# We were unable to obtain the member type. So, we give up on refining this
4016+
# parent type entirely and abort.
4017+
return output
4018+
4019+
if is_overlapping_types(member_type, expr_type):
4020+
new_parent_types.append(item)
4021+
4022+
# If none of the parent types overlap (if we derived an empty union), something
4023+
# went wrong. We should never hit this case, but deriving the uninhabited type or
4024+
# reporting an error both seem unhelpful. So we abort.
4025+
if not new_parent_types:
4026+
return output
4027+
4028+
expr = parent_expr
4029+
expr_type = output[parent_expr] = make_simplified_union(new_parent_types)
4030+
4031+
return output
4032+
38594033
#
38604034
# Helpers
38614035
#

mypy/checkexpr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2709,6 +2709,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
27092709
index = e.index
27102710
left_type = get_proper_type(left_type)
27112711

2712+
# Visit the index, just to make sure we have a type for it available
2713+
self.accept(index)
2714+
27122715
if isinstance(left_type, UnionType):
27132716
original_type = original_type or left_type
27142717
return make_simplified_union([self.visit_index_with_type(typ, e,

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
'check-isinstance.test',
4747
'check-lists.test',
4848
'check-namedtuple.test',
49+
'check-narrowing.test',
4950
'check-typeddict.test',
5051
'check-type-aliases.test',
5152
'check-ignore.test',

mypy/typeops.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
since these may assume that MROs are ready.
66
"""
77

8-
from typing import cast, Optional, List, Sequence, Set, Iterable
8+
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar
9+
from typing_extensions import Type as TypingType
910
import sys
1011

1112
from mypy.types import (
1213
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
13-
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType,
14+
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType,
1415
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
1516
copy_type, TypeAliasType, TypeQuery
1617
)
1718
from mypy.nodes import (
18-
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, ARG_POS,
19+
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS,
1920
Expression, StrExpr, Var
2021
)
2122
from mypy.maptype import map_instance_to_supertype
@@ -43,6 +44,25 @@ def tuple_fallback(typ: TupleType) -> Instance:
4344
return Instance(info, [join_type_list(typ.items)])
4445

4546

47+
def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]:
48+
"""Returns the Instance fallback for this type if one exists.
49+
50+
Otherwise, returns None.
51+
"""
52+
if isinstance(typ, Instance):
53+
return typ
54+
elif isinstance(typ, TupleType):
55+
return tuple_fallback(typ)
56+
elif isinstance(typ, TypedDictType):
57+
return typ.fallback
58+
elif isinstance(typ, FunctionLike):
59+
return typ.fallback
60+
elif isinstance(typ, LiteralType):
61+
return typ.fallback
62+
else:
63+
return None
64+
65+
4666
def type_object_type_from_function(signature: FunctionLike,
4767
info: TypeInfo,
4868
def_info: TypeInfo,
@@ -481,27 +501,66 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]
481501
2. 'typ' is a LiteralType containing a string
482502
3. 'typ' is a UnionType containing only LiteralType of strings
483503
"""
484-
typ = get_proper_type(typ)
485-
486504
if isinstance(expr, StrExpr):
487505
return [expr.value]
488506

507+
# TODO: See if we can eliminate this function and call the below one directly
508+
return try_getting_str_literals_from_type(typ)
509+
510+
511+
def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]:
512+
"""If the given expression or type corresponds to a string Literal
513+
or a union of string Literals, returns a list of the underlying strings.
514+
Otherwise, returns None.
515+
516+
For example, if we had the type 'Literal["foo", "bar"]' as input, this function
517+
would return a list of strings ["foo", "bar"].
518+
"""
519+
return try_getting_literals_from_type(typ, str, "builtins.str")
520+
521+
522+
def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]:
523+
"""If the given expression or type corresponds to an int Literal
524+
or a union of int Literals, returns a list of the underlying ints.
525+
Otherwise, returns None.
526+
527+
For example, if we had the type 'Literal[1, 2, 3]' as input, this function
528+
would return a list of ints [1, 2, 3].
529+
"""
530+
return try_getting_literals_from_type(typ, int, "builtins.int")
531+
532+
533+
T = TypeVar('T')
534+
535+
536+
def try_getting_literals_from_type(typ: Type,
537+
target_literal_type: TypingType[T],
538+
target_fullname: str) -> Optional[List[T]]:
539+
"""If the given expression or type corresponds to a Literal or
540+
union of Literals where the underlying values corresponds to the given
541+
target type, returns a list of those underlying values. Otherwise,
542+
returns None.
543+
"""
544+
typ = get_proper_type(typ)
545+
489546
if isinstance(typ, Instance) and typ.last_known_value is not None:
490547
possible_literals = [typ.last_known_value] # type: List[Type]
491548
elif isinstance(typ, UnionType):
492549
possible_literals = list(typ.items)
493550
else:
494551
possible_literals = [typ]
495552

496-
strings = []
553+
literals = [] # type: List[T]
497554
for lit in get_proper_types(possible_literals):
498-
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
555+
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname:
499556
val = lit.value
500-
assert isinstance(val, str)
501-
strings.append(val)
557+
if isinstance(val, target_literal_type):
558+
literals.append(val)
559+
else:
560+
return None
502561
else:
503562
return None
504-
return strings
563+
return literals
505564

506565

507566
def get_enum_values(typ: Instance) -> List[str]:

0 commit comments

Comments
 (0)