diff --git a/mypy/checker.py b/mypy/checker.py index 8528bf35248d..ae829d1157c1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5,8 +5,8 @@ from contextlib import contextmanager from typing import ( - Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Sequence, - Mapping, + Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, + Sequence, Mapping, Generic, AbstractSet ) from typing_extensions import Final @@ -27,7 +27,7 @@ is_final_node, ARG_NAMED) from mypy import nodes -from mypy.literals import literal, literal_hash +from mypy.literals import literal, literal_hash, Key from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any from mypy.types import ( Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, @@ -3842,67 +3842,101 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) elif isinstance(node, ComparisonExpr): - operand_types = [coerce_to_literal(type_map[expr]) - for expr in node.operands if expr in type_map] - - is_not = node.operators == ['is not'] - if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): - if_vars = {} # type: TypeMap - else_vars = {} # type: TypeMap - - for i, expr in enumerate(node.operands): - var_type = operand_types[i] - other_type = operand_types[1 - i] - - if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): - # This should only be true at most once: there should be - # exactly two elements in node.operands and if the 'other type' is - # a singleton type, it by definition does not need to be narrowed: - # it already has the most precise type possible so does not need to - # be narrowed/included in the output map. - # - # TODO: Generalize this to handle the case where 'other_type' is - # a union of singleton types. 
- - if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): - fallback_name = other_type.fallback.type.fullname - var_type = try_expanding_enum_to_union(var_type, fallback_name) - - target_type = [TypeRange(other_type, is_upper_bound=False)] - if_vars, else_vars = conditional_type_map(expr, var_type, target_type) - break + # Step 1: Obtain the types of each operand and whether or not we can + # narrow their types. (For example, we shouldn't try narrowing the + # types of literal string or enum expressions). + + operands = node.operands + operand_types = [] + narrowable_operand_index_to_hash = {} + for i, expr in enumerate(operands): + if expr not in type_map: + return {}, {} + expr_type = type_map[expr] + operand_types.append(expr_type) + + if (literal(expr) == LITERAL_TYPE + and not is_literal_none(expr) + and not is_literal_enum(type_map, expr)): + h = literal_hash(expr) + if h is not None: + narrowable_operand_index_to_hash[i] = h + + # Step 2: Group operands chained by either the 'is' or '==' operands + # together. For all other operands, we keep them in groups of size 2. + # So the expression: + # + # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8 + # + # ...is converted into the simplified operator list: + # + # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), + # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] + # + # We group identity/equality expressions so we can propagate information + # we discover about one operand across the entire chain. We don't bother + # handling 'is not' and '!=' chains in a special way: those are very rare + # in practice. + + simplified_operator_list = group_comparison_operands( + node.pairwise(), + narrowable_operand_index_to_hash, + {'==', 'is'}, + ) + + # Step 3: Analyze each group and infer more precise type maps for each + # assignable operand, if possible. We combine these type maps together + # in the final step. 
+ + partial_type_maps = [] + for operator, expr_indices in simplified_operator_list: + if operator in {'is', 'is not'}: + if_map, else_map = self.refine_identity_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) + elif operator in {'==', '!='}: + if_map, else_map = self.refine_equality_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) + elif operator in {'in', 'not in'}: + assert len(expr_indices) == 2 + left_index, right_index = expr_indices + if left_index not in narrowable_operand_index_to_hash: + continue + + item_type = operand_types[left_index] + collection_type = operand_types[right_index] + + # We only try and narrow away 'None' for now + if not is_optional(item_type): + continue - if is_not: - if_vars, else_vars = else_vars, if_vars - return if_vars, else_vars - # Check for `x == y` where x is of type Optional[T] and y is of type T - # or a type that overlaps with T (or vice versa). 
- elif node.operators == ['==']: - first_type = type_map[node.operands[0]] - second_type = type_map[node.operands[1]] - if is_optional(first_type) != is_optional(second_type): - if is_optional(first_type): - optional_type, comp_type = first_type, second_type - optional_expr = node.operands[0] + collection_item_type = get_proper_type(builtin_item_type(collection_type)) + if collection_item_type is None or is_optional(collection_item_type): + continue + if (isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == 'builtins.object'): + continue + if is_overlapping_erased_types(item_type, collection_item_type): + if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} else: - optional_type, comp_type = second_type, first_type - optional_expr = node.operands[1] - if is_overlapping_erased_types(optional_type, comp_type): - return {optional_expr: remove_optional(optional_type)}, {} - elif node.operators in [['in'], ['not in']]: - expr = node.operands[0] - left_type = type_map[expr] - right_type = get_proper_type(builtin_item_type(type_map[node.operands[1]])) - right_ok = right_type and (not is_optional(right_type) and - (not isinstance(right_type, Instance) or - right_type.type.fullname != 'builtins.object')) - if (right_type and right_ok and is_optional(left_type) and - literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and - is_overlapping_erased_types(left_type, right_type)): - if node.operators == ['in']: - return {expr: remove_optional(left_type)}, {} - if node.operators == ['not in']: - return {}, {expr: remove_optional(left_type)} + continue + else: + if_map = {} + else_map = {} + + if operator in {'is not', '!=', 'not in'}: + if_map, else_map = else_map, if_map + + partial_type_maps.append((if_map, else_map)) + + return reduce_partial_conditional_maps(partial_type_maps) elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ 
-4107,6 +4141,143 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: return output + def refine_identity_comparison_expression(self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produces conditional type maps refining expressions used in an identity comparison. + + The 'operands' and 'operand_types' lists should be the full list of operands used + in the overall comparison expression. The 'chain_indices' list is the list of indices + actually used within this identity comparison chain. + + So if we have the expression: + + a <= b is c is d <= e + + ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices' + would be the list [1, 2, 3]. + + The 'narrowable_operand_indices' parameter is the set of all indices we are allowed + to refine the types of: that is, all operands that will potentially be a part of + the output TypeMaps. + """ + singleton = None # type: Optional[ProperType] + possible_singleton_indices = [] + for i in chain_indices: + coerced_type = coerce_to_literal(operand_types[i]) + if not is_singleton_type(coerced_type): + continue + if singleton and not is_same_type(singleton, coerced_type): + # We have multiple disjoint singleton types. So the 'if' branch + # must be unreachable. + return None, {} + singleton = coerced_type + possible_singleton_indices.append(i) + + # There's nothing we can currently infer if none of the operands are singleton types, + # so we end early and infer nothing. + if singleton is None: + return {}, {} + + # If possible, use an unassignable expression as the singleton. + # We skip refining the type of the singleton below, so ideally we'd + # want to pick an expression we were going to skip anyways. 
+ singleton_index = -1 + for i in possible_singleton_indices: + if i not in narrowable_operand_indices: + singleton_index = i + + # But if none of the possible singletons are unassignable ones, we give up + # and arbitrarily pick the last item, mostly because other parts of the + # type narrowing logic bias towards picking the rightmost item and it'd be + # nice to stay consistent. + # + # That said, it shouldn't matter which index we pick. For example, suppose we + # have this if statement, where 'x' and 'y' both have singleton types: + # + # if x is y: + # reveal_type(x) + # reveal_type(y) + # else: + # reveal_type(x) + # reveal_type(y) + # + # At this point, 'x' and 'y' *must* have the same singleton type: we would have + # ended early in the first for-loop in this function if they weren't. + # + # So, we should always get the same result in the 'if' case no matter which + # index we pick. And while we do end up getting different results in the 'else' + # case depending on the index (e.g. if we pick 'y', then its type stays the same + # while 'x' is narrowed to '<uninhabited>'), this distinction is also moot: mypy + # currently will just mark the whole branch as unreachable if either operand is + # narrowed to <uninhabited>. + if singleton_index == -1: + singleton_index = possible_singleton_indices[-1] + + enum_name = None + if isinstance(singleton, LiteralType) and singleton.is_enum_literal(): + enum_name = singleton.fallback.type.fullname + + target_type = [TypeRange(singleton, is_upper_bound=False)] + + partial_type_maps = [] + for i in chain_indices: + # If we try refining a singleton against itself, conditional_type_map + # will end up assuming that the 'else' branch is unreachable. This is + # typically not what we want: generally the user will intend for the + # singleton type to be some fixed 'sentinel' value and will want to refine + # the other exprs against this one instead. 
+ if i == singleton_index: + continue + + # Naturally, we can't refine operands which are not permitted to be refined. + if i not in narrowable_operand_indices: + continue + + expr = operands[i] + expr_type = coerce_to_literal(operand_types[i]) + + if enum_name is not None: + expr_type = try_expanding_enum_to_union(expr_type, enum_name) + partial_type_maps.append(conditional_type_map(expr, expr_type, target_type)) + + return reduce_partial_conditional_maps(partial_type_maps) + + def refine_equality_comparison_expression(self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produces conditional type maps refining expressions used in an equality comparison. + + For more details, see the docstring of 'refine_identity_comparison_expression' up above. + The only difference is that this function is for refining equality operations + (e.g. 'a == b == c') instead of identity ('a is b is c'). + """ + non_optional_types = [] + for i in chain_indices: + typ = operand_types[i] + if not is_optional(typ): + non_optional_types.append(typ) + + # Make sure we have a mixture of optional and non-optional types. 
+ if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices): + return {}, {} + + if_map = {} + for i in narrowable_operand_indices: + expr_type = operand_types[i] + if not is_optional(expr_type): + continue + if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): + if_map[operands[i]] = remove_optional(expr_type) + + return if_map, {} + # # Helpers # @@ -4541,16 +4712,55 @@ def gen_unique_name(base: str, table: SymbolTable) -> str: def is_true_literal(n: Expression) -> bool: + """Returns true if this expression is the 'True' literal/keyword.""" return (refers_to_fullname(n, 'builtins.True') or isinstance(n, IntExpr) and n.value == 1) def is_false_literal(n: Expression) -> bool: + """Returns true if this expression is the 'False' literal/keyword.""" return (refers_to_fullname(n, 'builtins.False') or isinstance(n, IntExpr) and n.value == 0) +def is_literal_enum(type_map: Mapping[Expression, Type], n: Expression) -> bool: + """Returns true if this expression (with the given type context) is an Enum literal. + + For example, if we had an enum: + + class Foo(Enum): + A = 1 + B = 2 + + ...and if the expression 'Foo' referred to that enum within the current type context, + then the expression 'Foo.A' would be a literal enum. However, if we did 'a = Foo.A', + then the variable 'a' would *not* be a literal enum. + + We occasionally special-case expressions like 'Foo.A' and treat them as a single primitive + unit for the same reasons we sometimes treat 'True', 'False', or 'None' as a single + primitive unit. 
+ """ + if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr): + return False + + parent_type = type_map.get(n.expr) + member_type = type_map.get(n) + if member_type is None or parent_type is None: + return False + + parent_type = get_proper_type(parent_type) + member_type = coerce_to_literal(member_type) + if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType): + return False + + if not parent_type.is_type_obj(): + return False + + return member_type.is_enum_literal() and member_type.fallback.type == parent_type.type_object() + + def is_literal_none(n: Expression) -> bool: + """Returns true if this expression is the 'None' literal/keyword.""" return isinstance(n, NameExpr) and n.fullname == 'builtins.None' @@ -4641,6 +4851,76 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: return result +def or_partial_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: + """Calculate what information we can learn from the truth of (e1 or e2) + in terms of the information that we can learn from the truth of e1 and + the truth of e2. + + Unlike 'or_conditional_maps', we include an expression in the output even + if it exists in only one map: we're assuming both maps are "partial" and + contain information about only some expressions, and so we "or" together + expressions both maps have information on. + """ + + if m1 is None: + return m2 + if m2 is None: + return m1 + # The logic here is a blend between 'and_conditional_maps' + # and 'or_conditional_maps'. We use the high-level logic from the + # former to ensure all expressions make it in the output map, + # but resolve cases where both maps contain info on the same + # expr using the unioning strategy from the latter. 
+ result = m2.copy() + m2_keys = {literal_hash(n2): n2 for n2 in m2} + for n1 in m1: + n2 = m2_keys.get(literal_hash(n1)) + if n2 is None: + result[n1] = m1[n1] + else: + result[n2] = make_simplified_union([m1[n1], result[n2]]) + + return result + + +def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], + ) -> Tuple[TypeMap, TypeMap]: + """Reduces a list containing pairs of *partial* if/else TypeMaps into a single pair. + + That is, if a expression exists in only one map, we always include it in the output. + We only "and"/"or" together expressions that appear in multiple if/else maps. + + So for example, if we had the input: + + [ + ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}), + ({y: TypeIfY, shared: TypeIfShared2}, {y: TypeElseY, shared: TypeElseShared2}), + ] + + ...we'd return the output: + + ( + {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]}, + {x: TypeElseX, y: TypeElseY, shared: Union[TypeElseShared1, TypeElseShared2]}, + ) + + ...where "PseudoIntersection[X, Y] == Y" because mypy actually doesn't understand intersections + yet, so we settle for just arbitrarily picking the right expr's type. + """ + if len(type_maps) == 0: + return {}, {} + elif len(type_maps) == 1: + return type_maps[0] + else: + final_if_map, final_else_map = type_maps[0] + for if_map, else_map in type_maps[1:]: + # 'and_conditional_maps' does the same thing for both global and partial type maps, + # which is why we don't need to have an 'and_partial_conditional_maps' function. 
+ final_if_map = and_conditional_maps(final_if_map, if_map) + final_else_map = or_partial_conditional_maps(final_else_map, else_map) + return final_if_map, final_else_map + + def convert_to_typetype(type_map: TypeMap) -> TypeMap: converted_type_map = {} # type: Dict[Expression, Type] if type_map is None: @@ -5007,6 +5287,205 @@ def nothing() -> Iterator[None]: yield +TKey = TypeVar('TKey') +TValue = TypeVar('TValue') + + +class DisjointDict(Generic[TKey, TValue]): + """An variation of the union-find algorithm/data structure where instead of keeping + track of just disjoint sets, we keep track of disjoint dicts -- keep track of multiple + Set[Key] -> Set[Value] mappings, where each mapping's keys are guaranteed to be disjoint. + + This data structure is currently used exclusively by 'group_comparison_operands' below + to merge chains of '==' and 'is' comparisons when two or more chains use the same expression + in best-case O(n), where n is the number of operands. + + Specifically, the `add_mapping()` function and `items()` functions will take on average + O(k + v) and O(n) respectively, where k and v are the number of keys and values we're adding + for a given chain. Note that k <= n and v <= n. + + We hit these average/best-case scenarios for most user code: e.g. when the user has just + a single chain like 'a == b == c == d == ...' or multiple disjoint chains like + 'a==b < c==d < e==f < ...'. (Note that a naive iterative merging would be O(n^2) for + the latter case). + + In comparison, this data structure will make 'group_comparison_operands' have a worst-case + runtime of O(n*log(n)): 'add_mapping()' and 'items()' are worst-case O(k*log(n) + v) and + O(k*log(n)) respectively. This happens only in the rare case where the user keeps repeatedly + making disjoint mappings before merging them in a way that persistently dodges the path + compression optimization in '_lookup_root_id', which would end up constructing a single + tree of height log_2(n). 
This makes root lookups no longer amoritized constant time when we + finally call 'items()'. + """ + def __init__(self) -> None: + # Each key maps to a unique ID + self._key_to_id = {} # type: Dict[TKey, int] + + # Each id points to the parent id, forming a forest of upwards-pointing trees. If the + # current id already is the root, it points to itself. We gradually flatten these trees + # as we perform root lookups: eventually all nodes point directly to its root. + self._id_to_parent_id = {} # type: Dict[int, int] + + # Each root id in turn maps to the set of values. + self._root_id_to_values = {} # type: Dict[int, Set[TValue]] + + def add_mapping(self, keys: Set[TKey], values: Set[TValue]) -> None: + """Adds a 'Set[TKey] -> Set[TValue]' mapping. If there already exists a mapping + containing one or more of the given keys, we merge the input mapping with the old one. + + Note that the given set of keys must be non-empty -- otherwise, nothing happens. + """ + if len(keys) == 0: + return + + subtree_roots = [self._lookup_or_make_root_id(key) for key in keys] + new_root = subtree_roots[0] + + root_values = self._root_id_to_values[new_root] + root_values.update(values) + for subtree_root in subtree_roots[1:]: + if subtree_root == new_root or subtree_root not in self._root_id_to_values: + continue + self._id_to_parent_id[subtree_root] = new_root + root_values.update(self._root_id_to_values.pop(subtree_root)) + + def items(self) -> List[Tuple[Set[TKey], Set[TValue]]]: + """Returns all disjoint mappings in key-value pairs.""" + root_id_to_keys = {} # type: Dict[int, Set[TKey]] + for key in self._key_to_id: + root_id = self._lookup_root_id(key) + if root_id not in root_id_to_keys: + root_id_to_keys[root_id] = set() + root_id_to_keys[root_id].add(key) + + output = [] + for root_id, keys in root_id_to_keys.items(): + output.append((keys, self._root_id_to_values[root_id])) + + return output + + def _lookup_or_make_root_id(self, key: TKey) -> int: + if key in 
self._key_to_id: + return self._lookup_root_id(key) + else: + new_id = len(self._key_to_id) + self._key_to_id[key] = new_id + self._id_to_parent_id[new_id] = new_id + self._root_id_to_values[new_id] = set() + return new_id + + def _lookup_root_id(self, key: TKey) -> int: + i = self._key_to_id[key] + while i != self._id_to_parent_id[i]: + # Optimization: make keys directly point to their grandparents to speed up + # future traversals. This prevents degenerate trees of height n from forming. + new_parent = self._id_to_parent_id[self._id_to_parent_id[i]] + self._id_to_parent_id[i] = new_parent + i = new_parent + return i + + +def group_comparison_operands(pairwise_comparisons: Iterable[Tuple[str, Expression, Expression]], + operand_to_literal_hash: Mapping[int, Key], + operators_to_group: Set[str], + ) -> List[Tuple[str, List[int]]]: + """Group a series of comparison operands together chained by any operand + in the 'operators_to_group' set. All other pairwise operands are kept in + groups of size 2. + + For example, suppose we have the input comparison expression: + + x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8 + + If we get these expressions in a pairwise way (e.g. by calling ComparisionExpr's + 'pairwise()' method), we get the following as input: + + [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('<', x3, x4), + ('is', x4, x5), ('is', x5, x6), ('is not', x6, x7), ('is not', x7, x8)] + + If `operators_to_group` is the set {'==', 'is'}, this function will produce + the following "simplified operator list": + + [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), + ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] + + Note that (a) we yield *indices* to the operands rather then the operand + expressions themselves and that (b) operands used in a consecutive chain + of '==' or 'is' are grouped together. + + If two of these chains happen to contain operands with the same underlying + literal hash (e.g. 
are assignable and correspond to the same expression), + we combine those chains together. For example, if we had: + + same == x < y == same + + ...and if 'operand_to_literal_hash' contained the same values for the indices + 0 and 3, we'd produce the following output: + + [("==", [0, 1, 2, 3]), ("<", [1, 2])] + + But if the 'operand_to_literal_hash' did *not* contain an entry, we'd instead + default to returning: + + [("==", [0, 1]), ("<", [1, 2]), ("==", [2, 3])] + + This function is currently only used to assist with type-narrowing refinements + and is extracted out to a helper function so we can unit test it. + """ + groups = { + op: DisjointDict() for op in operators_to_group + } # type: Dict[str, DisjointDict[Key, int]] + + simplified_operator_list = [] # type: List[Tuple[str, List[int]]] + last_operator = None # type: Optional[str] + current_indices = set() # type: Set[int] + current_hashes = set() # type: Set[Key] + for i, (operator, left_expr, right_expr) in enumerate(pairwise_comparisons): + if last_operator is None: + last_operator = operator + + if current_indices and (operator != last_operator or operator not in operators_to_group): + # If some of the operands in the chain are assignable, defer adding it: we might + # end up needing to merge it with other chains that appear later. + if len(current_hashes) == 0: + simplified_operator_list.append((last_operator, sorted(current_indices))) + else: + groups[last_operator].add_mapping(current_hashes, current_indices) + last_operator = operator + current_indices = set() + current_hashes = set() + + # Note: 'i' corresponds to the left operand index, so 'i + 1' is the + # right operand. 
+ current_indices.add(i) + current_indices.add(i + 1) + + # We only ever want to combine operands/combine chains for these operators + if operator in operators_to_group: + left_hash = operand_to_literal_hash.get(i) + if left_hash is not None: + current_hashes.add(left_hash) + right_hash = operand_to_literal_hash.get(i + 1) + if right_hash is not None: + current_hashes.add(right_hash) + + if last_operator is not None: + if len(current_hashes) == 0: + simplified_operator_list.append((last_operator, sorted(current_indices))) + else: + groups[last_operator].add_mapping(current_hashes, current_indices) + + # Now that we know which chains happen to contain the same underlying expressions + # and can be merged together, add in this info back to the output. + for operator, disjoint_dict in groups.items(): + for keys, indices in disjoint_dict.items(): + simplified_operator_list.append((operator, sorted(indices))) + + # For stability, reorder list by the first operand index to appear + simplified_operator_list.sort(key=lambda item: item[1][0]) + return simplified_operator_list + + def is_typed_callable(c: Optional[Type]) -> bool: c = get_proper_type(c) if not c or not isinstance(c, CallableType): diff --git a/mypy/nodes.py b/mypy/nodes.py index 4ee3948fedd3..792a89a5fea4 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1750,6 +1750,13 @@ def __init__(self, operators: List[str], operands: List[Expression]) -> None: self.operands = operands self.method_types = [] + def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]: + """If this comparison expr is "a < b is c == d", yields the sequence + ("<", a, b), ("is", b, c), ("==", c, d) + """ + for i, operator in enumerate(self.operators): + yield operator, self.operands[i], self.operands[i + 1] + def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_comparison_expr(self) diff --git a/mypy/test/testinfer.py b/mypy/test/testinfer.py index 2e26e99453b8..e70d74530a99 100644 --- 
a/mypy/test/testinfer.py +++ b/mypy/test/testinfer.py @@ -1,16 +1,18 @@ """Test cases for type inference helper functions.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Set from mypy.test.helpers import Suite, assert_equal from mypy.argmap import map_actuals_to_formals -from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED +from mypy.checker import group_comparison_operands, DisjointDict +from mypy.literals import Key +from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, NameExpr from mypy.types import AnyType, TupleType, Type, TypeOfAny from mypy.test.typefixture import TypeFixture class MapActualsToFormalsSuite(Suite): - """Test cases for checkexpr.map_actuals_to_formals.""" + """Test cases for argmap.map_actuals_to_formals.""" def test_basic(self) -> None: self.assert_map([], [], []) @@ -223,3 +225,234 @@ def expand_callee_kinds(kinds_and_names: List[Union[int, Tuple[int, str]]] kinds.append(v) names.append(None) return kinds, names + + +class OperandDisjointDictSuite(Suite): + """Test cases for checker.DisjointDict, which is used for type inference with operands.""" + def new(self) -> DisjointDict[int, str]: + return DisjointDict() + + def test_independent_maps(self) -> None: + d = self.new() + d.add_mapping({0, 1}, {"group1"}) + d.add_mapping({2, 3, 4}, {"group2"}) + d.add_mapping({5, 6, 7}, {"group3"}) + + self.assertEqual(d.items(), [ + ({0, 1}, {"group1"}), + ({2, 3, 4}, {"group2"}), + ({5, 6, 7}, {"group3"}), + ]) + + def test_partial_merging(self) -> None: + d = self.new() + d.add_mapping({0, 1}, {"group1"}) + d.add_mapping({1, 2}, {"group2"}) + d.add_mapping({3, 4}, {"group3"}) + d.add_mapping({5, 0}, {"group4"}) + d.add_mapping({5, 6}, {"group5"}) + d.add_mapping({4, 7}, {"group6"}) + + self.assertEqual(d.items(), [ + ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}), + ({3, 4, 7}, {"group3", "group6"}), + ]) + + def test_full_merging(self) -> 
None: + d = self.new() + d.add_mapping({0, 1, 2}, {"a"}) + d.add_mapping({3, 4, 2}, {"b"}) + d.add_mapping({10, 11, 12}, {"c"}) + d.add_mapping({13, 14, 15}, {"d"}) + d.add_mapping({14, 10, 16}, {"e"}) + d.add_mapping({0, 10}, {"f"}) + + self.assertEqual(d.items(), [ + ({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"}), + ]) + + def test_merge_with_multiple_overlaps(self) -> None: + d = self.new() + d.add_mapping({0, 1, 2}, {"a"}) + d.add_mapping({3, 4, 5}, {"b"}) + d.add_mapping({1, 2, 4, 5}, {"c"}) + d.add_mapping({6, 1, 2, 4, 5}, {"d"}) + d.add_mapping({6, 1, 2, 4, 5}, {"e"}) + + self.assertEqual(d.items(), [ + ({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"}), + ]) + + +class OperandComparisonGroupingSuite(Suite): + """Test cases for checker.group_comparison_operands.""" + def literal_keymap(self, assignable_operands: Dict[int, NameExpr]) -> Dict[int, Key]: + output = {} # type: Dict[int, Key] + for index, expr in assignable_operands.items(): + output[index] = ('FakeExpr', expr.name) + return output + + def test_basic_cases(self) -> None: + # Note: the grouping function doesn't actually inspect the input exprs, so we + # just default to using NameExprs for simplicity. 
+ x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + x4 = NameExpr('x4') + + basic_input = [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4)] + + none_assignable = self.literal_keymap({}) + all_assignable = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4}) + + for assignable in [none_assignable, all_assignable]: + self.assertEqual( + group_comparison_operands(basic_input, assignable, set()), + [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + self.assertEqual( + group_comparison_operands(basic_input, assignable, {'=='}), + [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + self.assertEqual( + group_comparison_operands(basic_input, assignable, {'<'}), + [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + self.assertEqual( + group_comparison_operands(basic_input, assignable, {'==', '<'}), + [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + + def test_multiple_groups(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + x4 = NameExpr('x4') + x5 = NameExpr('x5') + + self.assertEqual( + group_comparison_operands( + [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x4)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('==', [0, 1, 2]), ('is', [2, 3, 4])], + ) + self.assertEqual( + group_comparison_operands( + [('==', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('==', [0, 1, 2, 3, 4])], + ) + self.assertEqual( + group_comparison_operands( + [('is', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('is', [0, 1]), ('==', [1, 2, 3, 4])], + ) + self.assertEqual( + group_comparison_operands( + [('is', x0, x1), ('is', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x5)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('is', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])], + ) 
+ + def test_multiple_groups_coalescing(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + x4 = NameExpr('x4') + + nothing_combined = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])] + everything_combined = [('==', [0, 1, 2, 3, 4, 5]), ('<', [2, 3])] + + # Note: We do 'x4 == x0' at the very end! + two_groups = [ + ('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x0), + ] + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0}), + {'=='}, + ), + everything_combined, + "All vars are assignable, everything is combined" + ) + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({1: x1, 2: x2, 3: x3, 4: x4}), + {'=='}, + ), + nothing_combined, + "x0 is unassignable, so no combining" + ) + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({0: x0, 1: x1, 3: x3, 5: x0}), + {'=='}, + ), + everything_combined, + "Some vars are unassignable but x0 is, so we combine" + ) + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({0: x0, 5: x0}), + {'=='}, + ), + everything_combined, + "All vars are unassignable but x0 is, so we combine" + ) + + def test_multiple_groups_different_operators(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + + groups = [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x0)] + keymap = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x0}) + self.assertEqual( + group_comparison_operands(groups, keymap, {'==', 'is'}), + [('==', [0, 1, 2]), ('is', [2, 3, 4])], + "Different operators can never be combined" + ) + + def test_single_pair(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + + single_comparison = [('==', x0, x1)] + expected_output = [('==', [0, 1])] + + assignable_combinations = [ + {}, {0: x0}, {1: x1}, {0: x0, 1: x1}, + ] # 
type: List[Dict[int, NameExpr]] + to_group_by = [set(), {'=='}, {'is'}] # type: List[Set[str]] + + for combo in assignable_combinations: + for operators in to_group_by: + keymap = self.literal_keymap(combo) + self.assertEqual( + group_comparison_operands(single_comparison, keymap, operators), + expected_output, + ) + + def test_empty_pair_list(self) -> None: + # This case should never occur in practice -- ComparisionExprs + # always contain at least one comparision. But in case it does... + + self.assertEqual(group_comparison_operands([], {}, set()), []) + self.assertEqual(group_comparison_operands([], {}, {'=='}), []) diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 241cd1ca049c..9d027f47192f 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -967,3 +967,153 @@ class A: self.b = Enum("x", [("foo", "bar")]) # E: Enum type as attribute is not supported reveal_type(A().b) # N: Revealed type is 'Any' + +[case testEnumReachabilityWithChaining] +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + +x: Foo +y: Foo + +if x is y is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +if x is Foo.A is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +if Foo.A is x is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: 
Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChainingDisjoint] +# flags: --warn-unreachable +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + + # Used to divide up a chained comparison into multiple identity groups + def __lt__(self, other: object) -> bool: return True + +x: Foo +y: Foo + +# No conflict +if x is Foo.A < y is Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +# The standard output when we end up inferring two disjoint facts about the same expr +if x is Foo.A and x is Foo.B: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +# ..and we get the same result if we have two disjoint groups within the same comp expr +if x is Foo.A < x is Foo.B: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChainingDirectConflict] +# flags: --warn-unreachable +from enum import Enum +from typing_extensions import Literal, Final + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Foo +if x is Foo.A is Foo.B: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed 
type is '__main__.Foo' + +literal_a: Literal[Foo.A] +literal_b: Literal[Foo.B] +if x is literal_a is literal_b: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +final_a: Final = Foo.A +final_b: Final = Foo.B +if x is final_a is final_b: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChainingBigDisjoints] +# flags: --warn-unreachable +from enum import Enum +from typing_extensions import Literal, Final + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + + def __lt__(self, other: object) -> bool: return True + +x0: Foo +x1: Foo +x2: Foo +x3: Foo +x4: Foo +x5: Foo + +if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5: + reveal_type(x0) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x2) # N: Revealed type is 'Literal[__main__.Foo.A]' + + reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]' +else: + reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + + reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' + reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' + reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test 
b/test-data/unit/check-optional.test index 4b18cb59d1a7..9c40d550699e 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -528,6 +528,28 @@ else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] +[case testInferEqualsNotOptionalWithMultipleArgs] +from typing import Optional +x: Optional[int] +y: Optional[int] +if x == y == 1: + reveal_type(x) # N: Revealed type is 'builtins.int' + reveal_type(y) # N: Revealed type is 'builtins.int' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.int, None]' + reveal_type(y) # N: Revealed type is 'Union[builtins.int, None]' + +class A: pass +a: Optional[A] +b: Optional[A] +if a == b == object(): + reveal_type(a) # N: Revealed type is '__main__.A' + reveal_type(b) # N: Revealed type is '__main__.A' +else: + reveal_type(a) # N: Revealed type is 'Union[__main__.A, None]' + reveal_type(b) # N: Revealed type is 'Union[__main__.A, None]' +[builtins fixtures/ops.pyi] + [case testWarnNoReturnWorksWithStrictOptional] # flags: --warn-no-return def f() -> None: