Commit c68ddc9

Force enum literals to simplify when inferring unions
While working on overhauling #7169, I discovered that simply "deconstructing" enums into unions leads to some false positives in some real-world code. This is an existing problem, but became more prominent as I worked on improving type inference in the above PR.

Here's a simplified example of one such problem I ran into:

```
from enum import Enum

class Foo(Enum):
    A = 1
    B = 2

class Wrapper:
    def __init__(self, x: bool, y: Foo) -> None:
        if x:
            if y is Foo.A:
                # 'y' is of type Literal[Foo.A] here
                pass
            else:
                # ...and of type Literal[Foo.B] here
                pass
            # We join these two types after the if/else to end up with
            # Literal[Foo.A, Foo.B]
            self.y = y
        else:
            # ...and so this fails! 'Foo' is not considered a subtype of
            # 'Literal[Foo.A, Foo.B]'
            self.y = y
```

I considered three different ways of fixing this:

1. Modify our various type comparison operations (`is_same`, `is_subtype`, `is_proper_subtype`, etc...) to consider `Foo` and `Literal[Foo.A, Foo.B]` equivalent.
2. Modify the 'join' logic so that when we join enum literals, we check and see if we can merge them back into the original class, undoing the "deconstruction".
3. Modify the `make_simplified_union` logic to do the reconstruction instead.

I rejected the first two options. The first approach is the most sound one, but seemed complicated to implement. We have a lot of different type comparison operations, and attempting to modify them all seems error-prone. I also didn't really like the idea of having two equally valid representations of the same type, and would rather push mypy to always standardize on one, just from a usability point of view.

The second option seemed workable but limited to me. Modifying join would fix the specific example above, but I wasn't confident that was the only place we'd need to patch.

So I went with modifying `make_simplified_union` instead.
The main disadvantage of this approach is that we still get false positives when working with Unions that come directly from the semantic analysis phase. For example, we still get an error with the following program:

    x: Literal[Foo.A, Foo.B]
    y: Foo

    # Error, we still think 'x' is of type 'Literal[Foo.A, Foo.B]'
    x = y

But I think this is an acceptable tradeoff for now: I can't imagine too many people running into this. But if they do, we can always explore finding a way of simplifying unions after the semantic analysis phase or bite the bullet and implement approach 1.
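The equivalence this commit relies on, that a union of every member of an enum describes the same set of values as the enum itself, can be illustrated at runtime. The following is only a sketch in plain Python: the helper `covers_whole_enum` is a hypothetical name and is not part of mypy, and it operates on runtime enum members rather than on mypy's type objects.

```python
from enum import Enum
from typing import Iterable, Type


class Foo(Enum):
    A = 1
    B = 2


def covers_whole_enum(enum_cls: Type[Enum], members: Iterable[Enum]) -> bool:
    """Return True if `members` mentions every member of `enum_cls`.

    Hypothetical helper for illustration only. Iterating an Enum class
    yields each of its (non-alias) members exactly once, so comparing
    sets is enough to decide whether the "union" is exhaustive.
    """
    found = {m for m in members if isinstance(m, enum_cls)}
    return found == set(enum_cls)


if __name__ == "__main__":
    print(covers_whole_enum(Foo, [Foo.A, Foo.B]))  # every member listed
    print(covers_whole_enum(Foo, [Foo.A]))         # Foo.B is missing
```

Because an enum has a fixed set of members, an exhaustive list of literals loses no information when it is replaced by the enum class, which is the property the commit message describes as making the collapse safe.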
1 parent a94e649 commit c68ddc9

3 files changed: +158 -18 lines changed

mypy/checker.py

Lines changed: 4 additions & 9 deletions
@@ -50,7 +50,7 @@
 from mypy.typeops import (
     map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
     erase_def_to_union_or_bound, erase_to_union_or_bound,
-    true_only, false_only, function_type,
+    true_only, false_only, function_type, get_enum_values,
 )
 from mypy import message_registry
 from mypy.subtypes import (
@@ -4766,11 +4766,6 @@ def is_private(node_name: str) -> bool:
     return node_name.startswith('__') and not node_name.endswith('__')
 
 
-def get_enum_values(typ: Instance) -> List[str]:
-    """Return the list of values for an Enum."""
-    return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)]
-
-
 def is_singleton_type(typ: Type) -> bool:
     """Returns 'true' if this type is a "singleton type" -- if there exists
     exactly only one runtime value associated with this type.
@@ -4819,7 +4814,7 @@ class Status(Enum):
 
     if isinstance(typ, UnionType):
         items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
-        return make_simplified_union(items)
+        return UnionType.make_union(items)
     elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname:
         new_items = []
         for name, symbol in typ.type.names.items():
@@ -4834,7 +4829,7 @@ class Status(Enum):
         # only using CPython, but we might as well for the sake of full correctness.
         if sys.version_info < (3, 7):
             new_items.sort(key=lambda lit: lit.value)
-        return make_simplified_union(new_items)
+        return UnionType.make_union(new_items)
     else:
         return typ
 
@@ -4846,7 +4841,7 @@ def coerce_to_literal(typ: Type) -> ProperType:
     typ = get_proper_type(typ)
     if isinstance(typ, UnionType):
         new_items = [coerce_to_literal(item) for item in typ.items]
-        return make_simplified_union(new_items)
+        return UnionType.make_union(new_items)
     elif isinstance(typ, Instance):
         if typ.last_known_value:
             return typ.last_known_value

mypy/typeops.py

Lines changed: 57 additions & 6 deletions
@@ -5,7 +5,7 @@
 since these may assume that MROs are ready.
 """
 
-from typing import cast, Optional, List, Sequence, Set
+from typing import cast, Optional, List, Sequence, Set, Dict
 
 from mypy.types import (
     TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
@@ -14,8 +14,8 @@
     copy_type
 )
 from mypy.nodes import (
-    FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, Expression,
-    StrExpr, ARG_POS
+    FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, ARG_POS,
+    Expression, StrExpr, Var,
 )
 from mypy.maptype import map_instance_to_supertype
 from mypy.expandtype import expand_type_by_instance, expand_type
@@ -293,6 +293,11 @@ def make_simplified_union(items: Sequence[Type],
     * [int, int] -> int
     * [int, Any] -> Union[int, Any] (Any types are not simplified away!)
     * [Any, Any] -> Any
+    * [Literal[Foo.A], Literal[Foo.B]] -> Foo (assuming Foo is a enum with two variants A and B)
+
+    Note that we only collapse enum literals into the original enum when all literal variants
+    are present. Since enums are effectively final and there are a fixed number of possible
+    variants, it's safe to treat those two types as equivalent.
 
     Note: This must NOT be used during semantic analysis, since TypeInfos may not
     be fully initialized.
@@ -309,6 +314,8 @@ def make_simplified_union(items: Sequence[Type],
 
     from mypy.subtypes import is_proper_subtype
 
+    enums_found = {}  # type: Dict[str, int]
+    enum_max_members = {}  # type: Dict[str, int]
     removed = set()  # type: Set[int]
     for i, ti in enumerate(items):
         if i in removed: continue
@@ -320,13 +327,52 @@ def make_simplified_union(items: Sequence[Type],
                 removed.add(j)
                 cbt = cbt or tj.can_be_true
                 cbf = cbf or tj.can_be_false
+
         # if deleted subtypes had more general truthiness, use that
         if not ti.can_be_true and cbt:
-            items[i] = true_or_false(ti)
+            items[i] = ti = true_or_false(ti)
         elif not ti.can_be_false and cbf:
-            items[i] = true_or_false(ti)
+            items[i] = ti = true_or_false(ti)
+
+        # Keep track of all enum Literal types we encounter, in case
+        # we can coalesce them together
+        if isinstance(ti, LiteralType) and ti.is_enum_literal():
+            enum_name = ti.fallback.type.fullname()
+            if enum_name not in enum_max_members:
+                enum_max_members[enum_name] = len(get_enum_values(ti.fallback))
+            enums_found[enum_name] = enums_found.get(enum_name, 0) + 1
+        if isinstance(ti, Instance) and ti.type.is_enum:
+            enum_name = ti.type.fullname()
+            if enum_name not in enum_max_members:
+                enum_max_members[enum_name] = len(get_enum_values(ti))
+            enums_found[enum_name] = enum_max_members[enum_name]
+
+    enums_to_compress = {n for (n, c) in enums_found.items() if c >= enum_max_members[n]}
+    enums_encountered = set()  # type: Set[str]
+    simplified_set = []  # type: List[ProperType]
+    for i, item in enumerate(items):
+        if i in removed:
+            continue
+
+        # Try seeing if this is an enum or enum literal, and if it's
+        # one we should be collapsing away.
+        if isinstance(item, LiteralType):
+            instance = item.fallback  # type: Optional[Instance]
+        elif isinstance(item, Instance):
+            instance = item
+        else:
+            instance = None
+
+        if instance and instance.type.is_enum:
+            enum_name = instance.type.fullname()
+            if enum_name in enums_encountered:
+                continue
+            if enum_name in enums_to_compress:
+                simplified_set.append(instance)
+                enums_encountered.add(enum_name)
+                continue
+        simplified_set.append(item)
 
-    simplified_set = [items[i] for i in range(len(items)) if i not in removed]
     return UnionType.make_union(simplified_set, line, column)
 
 
@@ -489,3 +535,8 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]
         else:
             return None
     return strings
+
+
+def get_enum_values(typ: Instance) -> List[str]:
+    """Return the list of values for an Enum."""
+    return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)]
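The counting-then-rebuilding approach added to `make_simplified_union` can be sketched outside of mypy with a toy model. This is only an illustration under stated assumptions: runtime enum members stand in for `Literal[...]` types, enum classes stand in for `Instance` types, exact-duplicate removal stands in for the `is_proper_subtype` pass, and the function name `simplify_union` is hypothetical rather than a mypy API.

```python
from enum import Enum
from typing import Dict, List, Optional, Sequence, Set, Type


class Foo(Enum):
    A = 1
    B = 2
    C = 3


class Bar(Enum):
    X = 1
    Y = 2


def _enum_class_of(item: object) -> Optional[Type[Enum]]:
    """Return the enum class an item belongs to, or None for non-enum items."""
    if isinstance(item, Enum):
        return type(item)
    if isinstance(item, type) and issubclass(item, Enum):
        return item
    return None


def simplify_union(items: Sequence[object]) -> List[object]:
    """Toy analogue of the two-pass collapse (not mypy's real code)."""
    # Stand-in for the redundancy pass: drop exact duplicates, keep first.
    unique: List[object] = []
    for item in items:
        if item not in unique:
            unique.append(item)

    # Pass 1: count distinct members seen per enum. A whole enum class
    # counts as "all members seen".
    seen: Dict[Type[Enum], int] = {}
    for item in unique:
        if isinstance(item, Enum):
            seen[type(item)] = seen.get(type(item), 0) + 1
        else:
            cls = _enum_class_of(item)
            if cls is not None:
                seen[cls] = len(cls)
    to_collapse = {cls for cls, count in seen.items() if count >= len(cls)}

    # Pass 2: rebuild, emitting each collapsed enum once, at the position
    # of its first member.
    emitted: Set[Type[Enum]] = set()
    result: List[object] = []
    for item in unique:
        cls = _enum_class_of(item)
        if cls is not None and cls in to_collapse:
            if cls not in emitted:
                result.append(cls)
                emitted.add(cls)
            continue
        result.append(item)
    return result
```

In this sketch, an incomplete set of members such as `[Foo.A, Foo.B, 1]` is left alone, while `[1, Foo.A, Foo.B, Foo.C]` becomes `[1, Foo]`. The collapsed enum takes the position of its first member, which matches the ordering the new `testEnumUnionCompression` case expects for `x2` and `x3`.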

test-data/unit/check-enum.test

Lines changed: 97 additions & 3 deletions
@@ -629,6 +629,7 @@ elif x is Foo.C:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(x)  # No output here: this branch is unreachable
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
 
 if Foo.A is x:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -638,6 +639,7 @@ elif Foo.C is x:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(x)  # No output here: this branch is unreachable
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
 
 y: Foo
 if y is Foo.A:
@@ -648,6 +650,7 @@ elif y is Foo.C:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(y)  # No output here: this branch is unreachable
+reveal_type(y)  # N: Revealed type is '__main__.Foo'
 
 if Foo.A is y:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -657,6 +660,7 @@ elif Foo.C is y:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.C]'
 else:
     reveal_type(y)  # No output here: this branch is unreachable
+reveal_type(y)  # N: Revealed type is '__main__.Foo'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityChecksIndirect]
@@ -686,6 +690,8 @@ if y is x:
 else:
     reveal_type(x)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
+reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
 
 if x is z:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -703,6 +709,8 @@ else:
     reveal_type(x)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
     reveal_type(z)  # N: Revealed type is '__main__.Foo*'
     accepts_foo_a(z)
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
+reveal_type(z)  # N: Revealed type is '__main__.Foo*'
 
 if y is z:
     reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -718,6 +726,8 @@ if z is y:
 else:
     reveal_type(y)  # No output: this branch is unreachable
     reveal_type(z)  # No output: this branch is unreachable
+reveal_type(y)  # N: Revealed type is 'Literal[__main__.Foo.A]'
+reveal_type(z)  # N: Revealed type is '__main__.Foo*'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityNoNarrowingForUnionMessiness]
@@ -740,13 +750,17 @@ if x is y:
 else:
     reveal_type(x)  # N: Revealed type is '__main__.Foo'
     reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
+reveal_type(x)  # N: Revealed type is '__main__.Foo'
+reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
 
 if y is z:
     reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
     reveal_type(z)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
 else:
     reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
     reveal_type(z)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
+reveal_type(y)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
+reveal_type(z)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityWithNone]
@@ -764,16 +778,19 @@ if x:
     reveal_type(x)  # N: Revealed type is '__main__.Foo'
 else:
     reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
+reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
 
 if x is not None:
     reveal_type(x)  # N: Revealed type is '__main__.Foo'
 else:
     reveal_type(x)  # N: Revealed type is 'None'
+reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
 
 if x is Foo.A:
     reveal_type(x)  # N: Revealed type is 'Literal[__main__.Foo.A]'
 else:
     reveal_type(x)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
+reveal_type(x)  # N: Revealed type is 'Union[__main__.Foo, None]'
 [builtins fixtures/bool.pyi]
 
 [case testEnumReachabilityWithMultipleEnums]
@@ -793,18 +810,21 @@ if x1 is Foo.A:
     reveal_type(x1)  # N: Revealed type is 'Literal[__main__.Foo.A]'
 else:
     reveal_type(x1)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
+reveal_type(x1)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
 
 x2: Union[Foo, Bar]
 if x2 is Bar.A:
     reveal_type(x2)  # N: Revealed type is 'Literal[__main__.Bar.A]'
 else:
     reveal_type(x2)  # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
+reveal_type(x2)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
 
 x3: Union[Foo, Bar]
 if x3 is Foo.A or x3 is Bar.A:
     reveal_type(x3)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
 else:
     reveal_type(x3)  # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
+reveal_type(x3)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
 
 [builtins fixtures/bool.pyi]
 
@@ -823,7 +843,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
                         # E: Unsupported left operand type for + ("Empty") \
                         # N: Left operand is of type "Union[int, None, Empty]"
     if x is _empty:
-        reveal_type(x)  # N: Revealed type is 'Literal[__main__.Empty.token]'
+        reveal_type(x)  # N: Revealed type is '__main__.Empty'
         return 0
     elif x is None:
         reveal_type(x)  # N: Revealed type is 'None'
@@ -870,7 +890,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
                         # E: Unsupported left operand type for + ("Empty") \
                         # N: Left operand is of type "Union[int, None, Empty]"
     if x is _empty:
-        reveal_type(x)  # N: Revealed type is 'Literal[__main__.Empty.token]'
+        reveal_type(x)  # N: Revealed type is '__main__.Empty'
         return 0
     elif x is None:
         reveal_type(x)  # N: Revealed type is 'None'
@@ -899,7 +919,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
                         # E: Unsupported left operand type for + ("Empty") \
                         # N: Left operand is of type "Union[int, None, Empty]"
     if x is _empty:
-        reveal_type(x)  # N: Revealed type is 'Literal[__main__.Empty.token]'
+        reveal_type(x)  # N: Revealed type is '__main__.Empty'
         return 0
     elif x is None:
         reveal_type(x)  # N: Revealed type is 'None'
@@ -908,3 +928,77 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
         reveal_type(x)  # N: Revealed type is 'builtins.int'
         return x + 2
 [builtins fixtures/primitives.pyi]
+
+[case testEnumUnionCompression]
+from typing import Union
+from typing_extensions import Literal
+from enum import Enum
+
+class Foo(Enum):
+    A = 1
+    B = 2
+    C = 3
+
+class Bar(Enum):
+    X = 1
+    Y = 2
+
+x1: Literal[Foo.A, Foo.B, Foo.B, Foo.B, 1, None]
+assert x1 is not None
+reveal_type(x1)  # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[1]]'
+
+x2: Literal[1, Foo.A, Foo.B, Foo.C, None]
+assert x2 is not None
+reveal_type(x2)  # N: Revealed type is 'Union[Literal[1], __main__.Foo]'
+
+x3: Literal[Foo.A, Foo.B, 1, Foo.C, Foo.C, Foo.C, None]
+assert x3 is not None
+reveal_type(x3)  # N: Revealed type is 'Union[__main__.Foo, Literal[1]]'
+
+x4: Literal[Foo.A, Foo.B, Foo.C, Foo.C, Foo.C, None]
+assert x4 is not None
+reveal_type(x4)  # N: Revealed type is '__main__.Foo'
+
+x5: Union[Literal[Foo.A], Foo, None]
+assert x5 is not None
+reveal_type(x5)  # N: Revealed type is '__main__.Foo'
+
+x6: Literal[Foo.A, Bar.X, Foo.B, Bar.Y, Foo.C, None]
+assert x6 is not None
+reveal_type(x6)  # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
+
+# TODO: We should really simplify this down into just 'Bar' as well.
+no_forcing: Literal[Bar.X, Bar.X, Bar.Y]
+reveal_type(no_forcing)  # N: Revealed type is 'Union[Literal[__main__.Bar.X], Literal[__main__.Bar.X], Literal[__main__.Bar.Y]]'
+
+[case testEnumUnionCompressionAssignment]
+from typing_extensions import Literal
+from enum import Enum
+
+class Foo(Enum):
+    A = 1
+    B = 2
+
+class Wrapper1:
+    def __init__(self, x: object, y: Foo) -> None:
+        if x:
+            if y is Foo.A:
+                pass
+            else:
+                pass
+            self.y = y
+        else:
+            self.y = y
+        reveal_type(self.y)  # N: Revealed type is '__main__.Foo'
+
+class Wrapper2:
+    def __init__(self, x: object, y: Foo) -> None:
+        if x:
+            self.y = y
+        else:
+            if y is Foo.A:
+                pass
+            else:
+                pass
+            self.y = y
+        reveal_type(self.y)  # N: Revealed type is '__main__.Foo'
