@@ -1188,14 +1188,14 @@ def check_overload_call(self,
1188
1188
# gives a narrower type.
1189
1189
if unioned_return :
1190
1190
returns , inferred_types = zip (* unioned_return )
1191
- # Note that we use `union_overload_matches ` instead of just returning
1191
+ # Note that we use `combine_function_signatures ` instead of just returning
1192
1192
# a union of inferred callables because for example a call
1193
1193
# Union[int -> int, str -> str](Union[int, str]) is invalid and
1194
1194
# we don't want to introduce internal inconsistencies.
1195
1195
unioned_result = (UnionType .make_simplified_union (list (returns ),
1196
1196
context .line ,
1197
1197
context .column ),
1198
- self .union_overload_matches (inferred_types ))
1198
+ self .combine_function_signatures (inferred_types ))
1199
1199
1200
1200
# Step 3: We try checking each branch one-by-one.
1201
1201
inferred_result = self .infer_overload_return_type (plausible_targets , args , arg_types ,
@@ -1492,8 +1492,8 @@ def type_overrides_set(self, exprs: Sequence[Expression],
1492
1492
for expr in exprs :
1493
1493
del self .type_overrides [expr ]
1494
1494
1495
- def union_overload_matches (self , types : Sequence [Type ]) -> Union [AnyType , CallableType ]:
1496
- """Accepts a list of overload signatures and attempts to combine them together into a
1495
+ def combine_function_signatures (self , types : Sequence [Type ]) -> Union [AnyType , CallableType ]:
1496
+ """Accepts a list of function signatures and attempts to combine them together into a
1497
1497
new CallableType consisting of the union of all of the given arguments and return types.
1498
1498
1499
1499
If there is at least one non-callable type, return Any (this can happen if there is
@@ -1507,7 +1507,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
1507
1507
return callables [0 ]
1508
1508
1509
1509
# Note: we are assuming here that if a user uses some TypeVar 'T' in
1510
- # two different overloads , they meant for that TypeVar to mean the
1510
+ # two different functions , they meant for that TypeVar to mean the
1511
1511
# same thing.
1512
1512
#
1513
1513
# This function will make sure that all instances of that TypeVar 'T'
@@ -1525,7 +1525,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
1525
1525
1526
1526
too_complex = False
1527
1527
for target in callables :
1528
- # We fall back to Callable[..., Union[<returns>]] if the overloads do not have
1528
+ # We fall back to Callable[..., Union[<returns>]] if the functions do not have
1529
1529
# the exact same signature. The only exception is if one arg is optional and
1530
1530
# the other is positional: in that case, we continue unioning (and expect a
1531
1531
# positional arg).
@@ -1820,19 +1820,12 @@ def check_op_reversible(self,
1820
1820
left_expr : Expression ,
1821
1821
right_type : Type ,
1822
1822
right_expr : Expression ,
1823
- context : Context ) -> Tuple [Type , Type ]:
1824
- # Note: this kludge exists mostly to maintain compatibility with
1825
- # existing error messages. Apparently, if the left-hand-side is a
1826
- # union and we have a type mismatch, we print out a special,
1827
- # abbreviated error message. (See messages.unsupported_operand_types).
1828
- unions_present = isinstance (left_type , UnionType )
1829
-
1823
+ context : Context ,
1824
+ msg : MessageBuilder ) -> Tuple [Type , Type ]:
1830
1825
def make_local_errors () -> MessageBuilder :
1831
1826
"""Creates a new MessageBuilder object."""
1832
- local_errors = self . msg .clean_copy ()
1827
+ local_errors = msg .clean_copy ()
1833
1828
local_errors .disable_count = 0
1834
- if unions_present :
1835
- local_errors .disable_type_names += 1
1836
1829
return local_errors
1837
1830
1838
1831
def lookup_operator (op_name : str , base_type : Type ) -> Optional [Type ]:
@@ -2006,30 +1999,101 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
2006
1999
# TODO: Remove this extra case
2007
2000
return result
2008
2001
2009
- self . msg .add_errors (errors [0 ])
2002
+ msg .add_errors (errors [0 ])
2010
2003
if len (results ) == 1 :
2011
2004
return results [0 ]
2012
2005
else :
2013
2006
error_any = AnyType (TypeOfAny .from_error )
2014
2007
result = error_any , error_any
2015
2008
return result
2016
2009
2017
- def check_op (self , method : str , base_type : Type , arg : Expression ,
2018
- context : Context ,
2010
+ def check_op (self , method : str , base_type : Type ,
2011
+ arg : Expression , context : Context ,
2019
2012
allow_reverse : bool = False ) -> Tuple [Type , Type ]:
2020
2013
"""Type check a binary operation which maps to a method call.
2021
2014
2022
2015
Return tuple (result type, inferred operator method type).
2023
2016
"""
2024
2017
2025
2018
if allow_reverse :
2026
- return self .check_op_reversible (
2027
- op_name = method ,
2028
- left_type = base_type ,
2029
- left_expr = TempNode (base_type ),
2030
- right_type = self .accept (arg ),
2031
- right_expr = arg ,
2032
- context = context )
2019
+ left_variants = [base_type ]
2020
+ if isinstance (base_type , UnionType ):
2021
+ left_variants = [item for item in base_type .relevant_items ()]
2022
+ right_type = self .accept (arg )
2023
+
2024
+ # Step 1: We first try leaving the right arguments alone and destructure
2025
+ # just the left ones. (Mypy can sometimes perform some more precise inference
2026
+ # if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.
2027
+ msg = self .msg .clean_copy ()
2028
+ msg .disable_count = 0
2029
+ all_results = []
2030
+ all_inferred = []
2031
+
2032
+ for left_possible_type in left_variants :
2033
+ result , inferred = self .check_op_reversible (
2034
+ op_name = method ,
2035
+ left_type = left_possible_type ,
2036
+ left_expr = TempNode (left_possible_type ),
2037
+ right_type = right_type ,
2038
+ right_expr = arg ,
2039
+ context = context ,
2040
+ msg = msg )
2041
+ all_results .append (result )
2042
+ all_inferred .append (inferred )
2043
+
2044
+ if not msg .is_errors ():
2045
+ results_final = UnionType .make_simplified_union (all_results )
2046
+ inferred_final = UnionType .make_simplified_union (all_inferred )
2047
+ return results_final , inferred_final
2048
+
2049
+ # Step 2: If that fails, we try again but also destructure the right argument.
2050
+ # This is also necessary to make certain edge cases work -- see
2051
+ # testOperatorDoubleUnionInterwovenUnionAdd, for example.
2052
+
2053
+ # Note: We want to pass in the original 'arg' for 'left_expr' and 'right_expr'
2054
+ # whenever possible so that plugins and similar things can introspect on the original
2055
+ # node if possible.
2056
+ #
2057
+ # We don't do the same for the base expression because it could lead to weird
2058
+ # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
2059
+ # TODO: Can we use `type_overrides_set()` here?
2060
+ right_variants = [(right_type , arg )]
2061
+ if isinstance (right_type , UnionType ):
2062
+ right_variants = [(item , TempNode (item )) for item in right_type .relevant_items ()]
2063
+
2064
+ msg = self .msg .clean_copy ()
2065
+ msg .disable_count = 0
2066
+ all_results = []
2067
+ all_inferred = []
2068
+
2069
+ for left_possible_type in left_variants :
2070
+ for right_possible_type , right_expr in right_variants :
2071
+ result , inferred = self .check_op_reversible (
2072
+ op_name = method ,
2073
+ left_type = left_possible_type ,
2074
+ left_expr = TempNode (left_possible_type ),
2075
+ right_type = right_possible_type ,
2076
+ right_expr = right_expr ,
2077
+ context = context ,
2078
+ msg = msg )
2079
+ all_results .append (result )
2080
+ all_inferred .append (inferred )
2081
+
2082
+ if msg .is_errors ():
2083
+ self .msg .add_errors (msg )
2084
+ if len (left_variants ) >= 2 and len (right_variants ) >= 2 :
2085
+ self .msg .warn_both_operands_are_from_unions (context )
2086
+ elif len (left_variants ) >= 2 :
2087
+ self .msg .warn_operand_was_from_union ("Left" , base_type , context )
2088
+ elif len (right_variants ) >= 2 :
2089
+ self .msg .warn_operand_was_from_union ("Right" , right_type , context )
2090
+
2091
+ # See the comment in 'check_overload_call' for more details on why
2092
+ # we call 'combine_function_signature' instead of just unioning the inferred
2093
+ # callable types.
2094
+ results_final = UnionType .make_simplified_union (all_results )
2095
+ inferred_final = self .combine_function_signatures (all_inferred )
2096
+ return results_final , inferred_final
2033
2097
else :
2034
2098
return self .check_op_local_by_name (
2035
2099
method = method ,
0 commit comments