Skip to content

Commit eb41417

Browse files
Michael0x2ailevkivskyi
authored andcommitted
Add support for operators with union operands (#5545)
* Add support for operators with union operands This pull request resolves #2128 -- it modifies how we check operators to add support for operations like `Union[int, float] + Union[int, float]`. This approach basically iterates over all possible variations of the left and right operands when they're unions and uses the union of the resulting inferred type as the type of the overall expression. Some implementation notes: 1. I attempting "destructuring" just the left operand, which is basically the approach proposed here: #2128 (comment) Unfortunately, I discovered it became necessary to also destructure the right operand to handle certain edge cases -- see the testOperatorDoubleUnionInterwovenUnionAdd test case. 2. This algorithm varies slightly from what we do for union math in that we don't attempt to "preserve" the union/we always destructure both operands.
1 parent bdad88a commit eb41417

11 files changed

+387
-73
lines changed

mypy/checkexpr.py

+90-26
Original file line numberDiff line numberDiff line change
@@ -1188,14 +1188,14 @@ def check_overload_call(self,
11881188
# gives a narrower type.
11891189
if unioned_return:
11901190
returns, inferred_types = zip(*unioned_return)
1191-
# Note that we use `union_overload_matches` instead of just returning
1191+
# Note that we use `combine_function_signatures` instead of just returning
11921192
# a union of inferred callables because for example a call
11931193
# Union[int -> int, str -> str](Union[int, str]) is invalid and
11941194
# we don't want to introduce internal inconsistencies.
11951195
unioned_result = (UnionType.make_simplified_union(list(returns),
11961196
context.line,
11971197
context.column),
1198-
self.union_overload_matches(inferred_types))
1198+
self.combine_function_signatures(inferred_types))
11991199

12001200
# Step 3: We try checking each branch one-by-one.
12011201
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
@@ -1492,8 +1492,8 @@ def type_overrides_set(self, exprs: Sequence[Expression],
14921492
for expr in exprs:
14931493
del self.type_overrides[expr]
14941494

1495-
def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
1496-
"""Accepts a list of overload signatures and attempts to combine them together into a
1495+
def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
1496+
"""Accepts a list of function signatures and attempts to combine them together into a
14971497
new CallableType consisting of the union of all of the given arguments and return types.
14981498
14991499
If there is at least one non-callable type, return Any (this can happen if there is
@@ -1507,7 +1507,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
15071507
return callables[0]
15081508

15091509
# Note: we are assuming here that if a user uses some TypeVar 'T' in
1510-
# two different overloads, they meant for that TypeVar to mean the
1510+
# two different functions, they meant for that TypeVar to mean the
15111511
# same thing.
15121512
#
15131513
# This function will make sure that all instances of that TypeVar 'T'
@@ -1525,7 +1525,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
15251525

15261526
too_complex = False
15271527
for target in callables:
1528-
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
1528+
# We fall back to Callable[..., Union[<returns>]] if the functions do not have
15291529
# the exact same signature. The only exception is if one arg is optional and
15301530
# the other is positional: in that case, we continue unioning (and expect a
15311531
# positional arg).
@@ -1820,19 +1820,12 @@ def check_op_reversible(self,
18201820
left_expr: Expression,
18211821
right_type: Type,
18221822
right_expr: Expression,
1823-
context: Context) -> Tuple[Type, Type]:
1824-
# Note: this kludge exists mostly to maintain compatibility with
1825-
# existing error messages. Apparently, if the left-hand-side is a
1826-
# union and we have a type mismatch, we print out a special,
1827-
# abbreviated error message. (See messages.unsupported_operand_types).
1828-
unions_present = isinstance(left_type, UnionType)
1829-
1823+
context: Context,
1824+
msg: MessageBuilder) -> Tuple[Type, Type]:
18301825
def make_local_errors() -> MessageBuilder:
18311826
"""Creates a new MessageBuilder object."""
1832-
local_errors = self.msg.clean_copy()
1827+
local_errors = msg.clean_copy()
18331828
local_errors.disable_count = 0
1834-
if unions_present:
1835-
local_errors.disable_type_names += 1
18361829
return local_errors
18371830

18381831
def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
@@ -2006,30 +1999,101 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
20061999
# TODO: Remove this extra case
20072000
return result
20082001

2009-
self.msg.add_errors(errors[0])
2002+
msg.add_errors(errors[0])
20102003
if len(results) == 1:
20112004
return results[0]
20122005
else:
20132006
error_any = AnyType(TypeOfAny.from_error)
20142007
result = error_any, error_any
20152008
return result
20162009

2017-
def check_op(self, method: str, base_type: Type, arg: Expression,
2018-
context: Context,
2010+
def check_op(self, method: str, base_type: Type,
2011+
arg: Expression, context: Context,
20192012
allow_reverse: bool = False) -> Tuple[Type, Type]:
20202013
"""Type check a binary operation which maps to a method call.
20212014
20222015
Return tuple (result type, inferred operator method type).
20232016
"""
20242017

20252018
if allow_reverse:
2026-
return self.check_op_reversible(
2027-
op_name=method,
2028-
left_type=base_type,
2029-
left_expr=TempNode(base_type),
2030-
right_type=self.accept(arg),
2031-
right_expr=arg,
2032-
context=context)
2019+
left_variants = [base_type]
2020+
if isinstance(base_type, UnionType):
2021+
left_variants = [item for item in base_type.relevant_items()]
2022+
right_type = self.accept(arg)
2023+
2024+
# Step 1: We first try leaving the right arguments alone and destructure
2025+
# just the left ones. (Mypy can sometimes perform some more precise inference
2026+
# if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.
2027+
msg = self.msg.clean_copy()
2028+
msg.disable_count = 0
2029+
all_results = []
2030+
all_inferred = []
2031+
2032+
for left_possible_type in left_variants:
2033+
result, inferred = self.check_op_reversible(
2034+
op_name=method,
2035+
left_type=left_possible_type,
2036+
left_expr=TempNode(left_possible_type),
2037+
right_type=right_type,
2038+
right_expr=arg,
2039+
context=context,
2040+
msg=msg)
2041+
all_results.append(result)
2042+
all_inferred.append(inferred)
2043+
2044+
if not msg.is_errors():
2045+
results_final = UnionType.make_simplified_union(all_results)
2046+
inferred_final = UnionType.make_simplified_union(all_inferred)
2047+
return results_final, inferred_final
2048+
2049+
# Step 2: If that fails, we try again but also destructure the right argument.
2050+
# This is also necessary to make certain edge cases work -- see
2051+
# testOperatorDoubleUnionInterwovenUnionAdd, for example.
2052+
2053+
# Note: We want to pass in the original 'arg' for 'left_expr' and 'right_expr'
2054+
# whenever possible so that plugins and similar things can introspect on the original
2055+
# node if possible.
2056+
#
2057+
# We don't do the same for the base expression because it could lead to weird
2058+
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
2059+
# TODO: Can we use `type_overrides_set()` here?
2060+
right_variants = [(right_type, arg)]
2061+
if isinstance(right_type, UnionType):
2062+
right_variants = [(item, TempNode(item)) for item in right_type.relevant_items()]
2063+
2064+
msg = self.msg.clean_copy()
2065+
msg.disable_count = 0
2066+
all_results = []
2067+
all_inferred = []
2068+
2069+
for left_possible_type in left_variants:
2070+
for right_possible_type, right_expr in right_variants:
2071+
result, inferred = self.check_op_reversible(
2072+
op_name=method,
2073+
left_type=left_possible_type,
2074+
left_expr=TempNode(left_possible_type),
2075+
right_type=right_possible_type,
2076+
right_expr=right_expr,
2077+
context=context,
2078+
msg=msg)
2079+
all_results.append(result)
2080+
all_inferred.append(inferred)
2081+
2082+
if msg.is_errors():
2083+
self.msg.add_errors(msg)
2084+
if len(left_variants) >= 2 and len(right_variants) >= 2:
2085+
self.msg.warn_both_operands_are_from_unions(context)
2086+
elif len(left_variants) >= 2:
2087+
self.msg.warn_operand_was_from_union("Left", base_type, context)
2088+
elif len(right_variants) >= 2:
2089+
self.msg.warn_operand_was_from_union("Right", right_type, context)
2090+
2091+
# See the comment in 'check_overload_call' for more details on why
2092+
# we call 'combine_function_signature' instead of just unioning the inferred
2093+
# callable types.
2094+
results_final = UnionType.make_simplified_union(all_results)
2095+
inferred_final = self.combine_function_signatures(all_inferred)
2096+
return results_final, inferred_final
20332097
else:
20342098
return self.check_op_local_by_name(
20352099
method=method,

mypy/messages.py

+6
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,12 @@ def overloaded_signatures_ret_specific(self, index: int, context: Context) -> No
10031003
self.fail('Overloaded function implementation cannot produce return type '
10041004
'of signature {}'.format(index), context)
10051005

1006+
def warn_both_operands_are_from_unions(self, context: Context) -> None:
1007+
self.note('Both left and right operands are unions', context)
1008+
1009+
def warn_operand_was_from_union(self, side: str, original: Type, context: Context) -> None:
1010+
self.note('{} operand is of type {}'.format(side, self.format(original)), context)
1011+
10061012
def operator_method_signatures_overlap(
10071013
self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type,
10081014
forward_method: str, context: Context) -> None:

test-data/unit/check-callable.test

+6-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ from typing import Callable, Union
4646
x = 5 # type: Union[int, Callable[[], str], Callable[[], int]]
4747

4848
if callable(x):
49-
y = x() + 2 # E: Unsupported operand types for + (likely involving Union)
49+
y = x() + 2 # E: Unsupported operand types for + ("str" and "int") \
50+
# N: Left operand is of type "Union[str, int]"
5051
else:
5152
z = x + 6
5253

@@ -60,7 +61,8 @@ x = 5 # type: Union[int, str, Callable[[], str]]
6061
if callable(x):
6162
y = x() + 'test'
6263
else:
63-
z = x + 6 # E: Unsupported operand types for + (likely involving Union)
64+
z = x + 6 # E: Unsupported operand types for + ("str" and "int") \
65+
# N: Left operand is of type "Union[int, str]"
6466

6567
[builtins fixtures/callable.pyi]
6668

@@ -153,7 +155,8 @@ x = 5 # type: Union[int, Callable[[], str]]
153155
if callable(x) and x() == 'test':
154156
x()
155157
else:
156-
x + 5 # E: Unsupported left operand type for + (some union)
158+
x + 5 # E: Unsupported left operand type for + ("Callable[[], str]") \
159+
# N: Left operand is of type "Union[int, Callable[[], str]]"
157160

158161
[builtins fixtures/callable.pyi]
159162

0 commit comments

Comments
 (0)