Skip to content

Commit ad6c717

Browse files
authored
Make isinstance/issubclass generate ad-hoc intersections (#8305)
This diff makes `isinstance(...)` and `issubclass(...)` try generating ad-hoc intersections of Instances when possible. For example, we previously concluded the if-branch is unreachable in the following program. This PR makes mypy infer an ad-hoc intersection instead. class A: pass class B: pass x: A if isinstance(x, B): reveal_type(x) # N: Revealed type is 'test.<subclass of "A" and "B">' If you try doing an `isinstance(...)` that legitimately is impossible due to conflicting method signatures or MRO issues, we continue to declare the branch unreachable. Passing in the `--warn-unreachable` flag will now also report an error about this: # flags: --warn-unreachable x: str # E: Subclass of "str" and "bytes" cannot exist: would have # incompatible method signatures if isinstance(x, bytes): reveal_type(x) # E: Statement is unreachable This error message has the same limitations as the other `--warn-unreachable` ones: we suppress them if the isinstance check is inside a function using TypeVars with multiple values. However, we *do* end up always inferring an intersection type when possible -- that logic is never suppressed. I initially thought we might have to suppress the new logic as well (see #3603 (comment)), but it turns out this is a non-issue in practice once you add in the check that disallows impossible intersections. For example, when I tried running this PR on the larger of our two internal codebases, I found about 25 distinct errors, all of which were legitimate and unrelated to the problem discussed in the PR. (And if we don't suppress the extra error message, we get about 100-120 errors, mostly due to tests repeatdly doing `result = blah()` followed by `assert isinstance(result, X)` where X keeps changing.)
1 parent f3c57e5 commit ad6c717

11 files changed

+834
-40
lines changed

mypy/checker.py

+158-22
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
get_proper_types, is_literal_type, TypeAliasType)
3939
from mypy.sametypes import is_same_type
4040
from mypy.messages import (
41-
MessageBuilder, make_inferred_type_note, append_invariance_notes,
41+
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
4242
format_type, format_type_bare, format_type_distinctly, SUGGESTED_TEST_FIXTURES
4343
)
4444
import mypy.checkexpr
@@ -63,7 +63,7 @@
6363
from mypy.maptype import map_instance_to_supertype
6464
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
6565
from mypy.semanal import set_callable_name, refers_to_fullname
66-
from mypy.mro import calculate_mro
66+
from mypy.mro import calculate_mro, MroError
6767
from mypy.erasetype import erase_typevars, remove_instance_last_known_values, erase_type
6868
from mypy.expandtype import expand_type, expand_type_by_instance
6969
from mypy.visitor import NodeVisitor
@@ -1963,13 +1963,15 @@ def visit_block(self, b: Block) -> None:
19631963
return
19641964
for s in b.body:
19651965
if self.binder.is_unreachable():
1966-
if (self.options.warn_unreachable
1967-
and not self.binder.is_unreachable_warning_suppressed()
1968-
and not self.is_raising_or_empty(s)):
1966+
if self.should_report_unreachable_issues() and not self.is_raising_or_empty(s):
19691967
self.msg.unreachable_statement(s)
19701968
break
19711969
self.accept(s)
19721970

1971+
def should_report_unreachable_issues(self) -> bool:
1972+
return (self.options.warn_unreachable
1973+
and not self.binder.is_unreachable_warning_suppressed())
1974+
19731975
def is_raising_or_empty(self, s: Statement) -> bool:
19741976
"""Returns 'true' if the given statement either throws an error of some kind
19751977
or is a no-op.
@@ -3636,6 +3638,100 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
36363638
self.binder.handle_continue()
36373639
return None
36383640

3641+
def make_fake_typeinfo(self,
3642+
curr_module_fullname: str,
3643+
class_gen_name: str,
3644+
class_short_name: str,
3645+
bases: List[Instance],
3646+
) -> Tuple[ClassDef, TypeInfo]:
3647+
# Build the fake ClassDef and TypeInfo together.
3648+
# The ClassDef is full of lies and doesn't actually contain a body.
3649+
# Use format_bare to generate a nice name for error messages.
3650+
# We skip fully filling out a handful of TypeInfo fields because they
3651+
# should be irrelevant for a generated type like this:
3652+
# is_protocol, protocol_members, is_abstract
3653+
cdef = ClassDef(class_short_name, Block([]))
3654+
cdef.fullname = curr_module_fullname + '.' + class_gen_name
3655+
info = TypeInfo(SymbolTable(), cdef, curr_module_fullname)
3656+
cdef.info = info
3657+
info.bases = bases
3658+
calculate_mro(info)
3659+
info.calculate_metaclass_type()
3660+
return cdef, info
3661+
3662+
def intersect_instances(self,
3663+
instances: Sequence[Instance],
3664+
ctx: Context,
3665+
) -> Optional[Instance]:
3666+
"""Try creating an ad-hoc intersection of the given instances.
3667+
3668+
Note that this function does *not* try and create a full-fledged
3669+
intersection type. Instead, it returns an instance of a new ad-hoc
3670+
subclass of the given instances.
3671+
3672+
This is mainly useful when you need a way of representing some
3673+
theoretical subclass of the instances the user may be trying to use
3674+
the generated intersection can serve as a placeholder.
3675+
3676+
This function will create a fresh subclass every time you call it,
3677+
even if you pass in the exact same arguments. So this means calling
3678+
`self.intersect_intersection([inst_1, inst_2], ctx)` twice will result
3679+
in instances of two distinct subclasses of inst_1 and inst_2.
3680+
3681+
This is by design: we want each ad-hoc intersection to be unique since
3682+
they're supposed represent some other unknown subclass.
3683+
3684+
Returns None if creating the subclass is impossible (e.g. due to
3685+
MRO errors or incompatible signatures). If we do successfully create
3686+
a subclass, its TypeInfo will automatically be added to the global scope.
3687+
"""
3688+
curr_module = self.scope.stack[0]
3689+
assert isinstance(curr_module, MypyFile)
3690+
3691+
base_classes = []
3692+
formatted_names = []
3693+
for inst in instances:
3694+
expanded = [inst]
3695+
if inst.type.is_intersection:
3696+
expanded = inst.type.bases
3697+
3698+
for expanded_inst in expanded:
3699+
base_classes.append(expanded_inst)
3700+
formatted_names.append(format_type_bare(expanded_inst))
3701+
3702+
pretty_names_list = pretty_seq(format_type_distinctly(*base_classes, bare=True), "and")
3703+
short_name = '<subclass of {}>'.format(pretty_names_list)
3704+
full_name = gen_unique_name(short_name, curr_module.names)
3705+
3706+
old_msg = self.msg
3707+
new_msg = self.msg.clean_copy()
3708+
self.msg = new_msg
3709+
try:
3710+
cdef, info = self.make_fake_typeinfo(
3711+
curr_module.fullname,
3712+
full_name,
3713+
short_name,
3714+
base_classes,
3715+
)
3716+
self.check_multiple_inheritance(info)
3717+
info.is_intersection = True
3718+
except MroError:
3719+
if self.should_report_unreachable_issues():
3720+
old_msg.impossible_intersection(
3721+
pretty_names_list, "inconsistent method resolution order", ctx)
3722+
return None
3723+
finally:
3724+
self.msg = old_msg
3725+
3726+
if new_msg.is_errors():
3727+
if self.should_report_unreachable_issues():
3728+
self.msg.impossible_intersection(
3729+
pretty_names_list, "incompatible method signatures", ctx)
3730+
return None
3731+
3732+
curr_module.names[full_name] = SymbolTableNode(GDEF, info)
3733+
return Instance(info, [])
3734+
36393735
def intersect_instance_callable(self, typ: Instance, callable_type: CallableType) -> Instance:
36403736
"""Creates a fake type that represents the intersection of an Instance and a CallableType.
36413737
@@ -3650,20 +3746,9 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType
36503746
gen_name = gen_unique_name("<callable subtype of {}>".format(typ.type.name),
36513747
cur_module.names)
36523748

3653-
# Build the fake ClassDef and TypeInfo together.
3654-
# The ClassDef is full of lies and doesn't actually contain a body.
3655-
# Use format_bare to generate a nice name for error messages.
3656-
# We skip fully filling out a handful of TypeInfo fields because they
3657-
# should be irrelevant for a generated type like this:
3658-
# is_protocol, protocol_members, is_abstract
3749+
# Synthesize a fake TypeInfo
36593750
short_name = format_type_bare(typ)
3660-
cdef = ClassDef(short_name, Block([]))
3661-
cdef.fullname = cur_module.fullname + '.' + gen_name
3662-
info = TypeInfo(SymbolTable(), cdef, cur_module.fullname)
3663-
cdef.info = info
3664-
info.bases = [typ]
3665-
calculate_mro(info)
3666-
info.calculate_metaclass_type()
3751+
cdef, info = self.make_fake_typeinfo(cur_module.fullname, gen_name, short_name, [typ])
36673752

36683753
# Build up a fake FuncDef so we can populate the symbol table.
36693754
func_def = FuncDef('__call__', [], Block([]), callable_type)
@@ -3828,9 +3913,11 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
38283913
return {}, {}
38293914
expr = node.args[0]
38303915
if literal(expr) == LITERAL_TYPE:
3831-
vartype = type_map[expr]
3832-
type = get_isinstance_type(node.args[1], type_map)
3833-
return conditional_type_map(expr, vartype, type)
3916+
return self.conditional_type_map_with_intersection(
3917+
expr,
3918+
type_map[expr],
3919+
get_isinstance_type(node.args[1], type_map),
3920+
)
38343921
elif refers_to_fullname(node.callee, 'builtins.issubclass'):
38353922
if len(node.args) != 2: # the error will be reported elsewhere
38363923
return {}, {}
@@ -4309,6 +4396,10 @@ def refine_identity_comparison_expression(self,
43094396

43104397
if enum_name is not None:
43114398
expr_type = try_expanding_enum_to_union(expr_type, enum_name)
4399+
4400+
# We intentionally use 'conditional_type_map' directly here instead of
4401+
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
4402+
# intersections when working with pure instances.
43124403
partial_type_maps.append(conditional_type_map(expr, expr_type, target_type))
43134404

43144405
return reduce_conditional_maps(partial_type_maps)
@@ -4726,10 +4817,55 @@ def infer_issubclass_maps(self, node: CallExpr,
47264817
# Any other object whose type we don't know precisely
47274818
# for example, Any or a custom metaclass.
47284819
return {}, {} # unknown type
4729-
yes_map, no_map = conditional_type_map(expr, vartype, type)
4820+
yes_map, no_map = self.conditional_type_map_with_intersection(expr, vartype, type)
47304821
yes_map, no_map = map(convert_to_typetype, (yes_map, no_map))
47314822
return yes_map, no_map
47324823

4824+
def conditional_type_map_with_intersection(self,
4825+
expr: Expression,
4826+
expr_type: Type,
4827+
type_ranges: Optional[List[TypeRange]],
4828+
) -> Tuple[TypeMap, TypeMap]:
4829+
# For some reason, doing "yes_map, no_map = conditional_type_maps(...)"
4830+
# doesn't work: mypyc will decide that 'yes_map' is of type None if we try.
4831+
initial_maps = conditional_type_map(expr, expr_type, type_ranges)
4832+
yes_map = initial_maps[0] # type: TypeMap
4833+
no_map = initial_maps[1] # type: TypeMap
4834+
4835+
if yes_map is not None or type_ranges is None:
4836+
return yes_map, no_map
4837+
4838+
# If conditions_type_map was unable to successfully narrow the expr_type
4839+
# using the type_ranges and concluded if-branch is unreachable, we try
4840+
# computing it again using a different algorithm that tries to generate
4841+
# an ad-hoc intersection between the expr_type and the type_ranges.
4842+
expr_type = get_proper_type(expr_type)
4843+
if isinstance(expr_type, UnionType):
4844+
possible_expr_types = get_proper_types(expr_type.relevant_items())
4845+
else:
4846+
possible_expr_types = [expr_type]
4847+
4848+
possible_target_types = []
4849+
for tr in type_ranges:
4850+
item = get_proper_type(tr.item)
4851+
if not isinstance(item, Instance) or tr.is_upper_bound:
4852+
return yes_map, no_map
4853+
possible_target_types.append(item)
4854+
4855+
out = []
4856+
for v in possible_expr_types:
4857+
if not isinstance(v, Instance):
4858+
return yes_map, no_map
4859+
for t in possible_target_types:
4860+
intersection = self.intersect_instances([v, t], expr)
4861+
if intersection is None:
4862+
continue
4863+
out.append(intersection)
4864+
if len(out) == 0:
4865+
return None, {}
4866+
new_yes_type = make_simplified_union(out)
4867+
return {expr: new_yes_type}, {}
4868+
47334869

47344870
def conditional_type_map(expr: Expression,
47354871
current_type: Optional[Type],

mypy/messages.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,11 @@ def has_no_attr(self,
290290
if matches:
291291
self.fail(
292292
'{} has no attribute "{}"; maybe {}?{}'.format(
293-
format_type(original_type), member, pretty_or(matches), extra),
293+
format_type(original_type),
294+
member,
295+
pretty_seq(matches, "or"),
296+
extra,
297+
),
294298
context,
295299
code=codes.ATTR_DEFINED)
296300
failed = True
@@ -623,7 +627,7 @@ def unexpected_keyword_argument(self, callee: CallableType, name: str, arg_type:
623627
if not matches:
624628
matches = best_matches(name, not_matching_type_args)
625629
if matches:
626-
msg += "; did you mean {}?".format(pretty_or(matches[:3]))
630+
msg += "; did you mean {}?".format(pretty_seq(matches[:3], "or"))
627631
self.fail(msg, context, code=codes.CALL_ARG)
628632
module = find_defining_module(self.modules, callee)
629633
if module:
@@ -1265,6 +1269,15 @@ def redundant_expr(self, description: str, truthiness: bool, context: Context) -
12651269
self.fail("{} is always {}".format(description, str(truthiness).lower()),
12661270
context, code=codes.UNREACHABLE)
12671271

1272+
def impossible_intersection(self,
1273+
formatted_base_class_list: str,
1274+
reason: str,
1275+
context: Context,
1276+
) -> None:
1277+
template = "Subclass of {} cannot exist: would have {}"
1278+
self.fail(template.format(formatted_base_class_list, reason), context,
1279+
code=codes.UNREACHABLE)
1280+
12681281
def report_protocol_problems(self,
12691282
subtype: Union[Instance, TupleType, TypedDictType],
12701283
supertype: Instance,
@@ -1997,13 +2010,14 @@ def best_matches(current: str, options: Iterable[str]) -> List[str]:
19972010
reverse=True, key=lambda v: (ratios[v], v))
19982011

19992012

2000-
def pretty_or(args: List[str]) -> str:
2013+
def pretty_seq(args: Sequence[str], conjunction: str) -> str:
20012014
quoted = ['"' + a + '"' for a in args]
20022015
if len(quoted) == 1:
20032016
return quoted[0]
20042017
if len(quoted) == 2:
2005-
return "{} or {}".format(quoted[0], quoted[1])
2006-
return ", ".join(quoted[:-1]) + ", or " + quoted[-1]
2018+
return "{} {} {}".format(quoted[0], conjunction, quoted[1])
2019+
last_sep = ", " + conjunction + " "
2020+
return ", ".join(quoted[:-1]) + last_sep + quoted[-1]
20072021

20082022

20092023
def append_invariance_notes(notes: List[str], arg_type: Instance,

mypy/nodes.py

+4
Original file line numberDiff line numberDiff line change
@@ -2379,13 +2379,17 @@ class is generic then it will be a type constructor of higher kind.
23792379
# Is this a newtype type?
23802380
is_newtype = False
23812381

2382+
# Is this a synthesized intersection type?
2383+
is_intersection = False
2384+
23822385
# This is a dictionary that will be serialized and un-serialized as is.
23832386
# It is useful for plugins to add their data to save in the cache.
23842387
metadata = None # type: Dict[str, JsonDict]
23852388

23862389
FLAGS = [
23872390
'is_abstract', 'is_enum', 'fallback_to_any', 'is_named_tuple',
23882391
'is_newtype', 'is_protocol', 'runtime_protocol', 'is_final',
2392+
'is_intersection',
23892393
] # type: Final[List[str]]
23902394

23912395
def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> None:

mypy/semanal.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
from mypy.typevars import fill_typevars
8282
from mypy.visitor import NodeVisitor
8383
from mypy.errors import Errors, report_internal_error
84-
from mypy.messages import best_matches, MessageBuilder, pretty_or, SUGGESTED_TEST_FIXTURES
84+
from mypy.messages import best_matches, MessageBuilder, pretty_seq, SUGGESTED_TEST_FIXTURES
8585
from mypy.errorcodes import ErrorCode
8686
from mypy import message_registry, errorcodes as codes
8787
from mypy.types import (
@@ -1802,7 +1802,7 @@ def report_missing_module_attribute(self, import_id: str, source_id: str, import
18021802
alternatives = set(module.names.keys()).difference({source_id})
18031803
matches = best_matches(source_id, alternatives)[:3]
18041804
if matches:
1805-
suggestion = "; maybe {}?".format(pretty_or(matches))
1805+
suggestion = "; maybe {}?".format(pretty_seq(matches, "or"))
18061806
message += "{}".format(suggestion)
18071807
self.fail(message, context, code=codes.ATTR_DEFINED)
18081808
self.add_unknown_imported_symbol(imported_id, context)

0 commit comments

Comments
 (0)