diff --git a/mypy/build.py b/mypy/build.py index 7913eae9c6ed..2f556120d430 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -31,14 +31,12 @@ Callable, ClassVar, Dict, - Iterable, Iterator, Mapping, NamedTuple, NoReturn, Sequence, TextIO, - TypeVar, ) from typing_extensions import Final, TypeAlias as _TypeAlias @@ -47,6 +45,7 @@ import mypy.semanal_main from mypy.checker import TypeChecker from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error +from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.indirection import TypeIndirectionVisitor from mypy.messages import MessageBuilder from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable, TypeInfo @@ -3466,15 +3465,8 @@ def sorted_components( edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices} sccs = list(strongly_connected_components(vertices, edges)) # Topsort. - sccsmap = {id: frozenset(scc) for scc in sccs for id in scc} - data: dict[AbstractSet[str], set[AbstractSet[str]]] = {} - for scc in sccs: - deps: set[AbstractSet[str]] = set() - for id in scc: - deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max)) - data[frozenset(scc)] = deps res = [] - for ready in topsort(data): + for ready in topsort(prepare_sccs(sccs, edges)): # Sort the sets in ready by reversed smallest State.order. Examples: # # - If ready is [{x}, {y}], x.order == 1, y.order == 2, we get @@ -3499,100 +3491,6 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in ] -def strongly_connected_components( - vertices: AbstractSet[str], edges: dict[str, list[str]] -) -> Iterator[set[str]]: - """Compute Strongly Connected Components of a directed graph. - - Args: - vertices: the labels for the vertices - edges: for each vertex, gives the target vertices of its outgoing edges - - Returns: - An iterator yielding strongly connected components, each - represented as a set of vertices. Each input vertex will occur - exactly once; vertices not part of a SCC are returned as - singleton sets. - - From https://code.activestate.com/recipes/578507/. - """ - identified: set[str] = set() - stack: list[str] = [] - index: dict[str, int] = {} - boundaries: list[int] = [] - - def dfs(v: str) -> Iterator[set[str]]: - index[v] = len(stack) - stack.append(v) - boundaries.append(index[v]) - - for w in edges[v]: - if w not in index: - yield from dfs(w) - elif w not in identified: - while index[w] < boundaries[-1]: - boundaries.pop() - - if boundaries[-1] == index[v]: - boundaries.pop() - scc = set(stack[index[v] :]) - del stack[index[v] :] - identified.update(scc) - yield scc - - for v in vertices: - if v not in index: - yield from dfs(v) - - -T = TypeVar("T") - - -def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]: - """Topological sort. - - Args: - data: A map from vertices to all vertices that it has an edge - connecting it to. NOTE: This data structure - is modified in place -- for normalization purposes, - self-dependencies are removed and entries representing - orphans are added. - - Returns: - An iterator yielding sets of vertices that have an equivalent - ordering. - - Example: - Suppose the input has the following structure: - - {A: {B, C}, B: {D}, C: {D}} - - This is normalized to: - - {A: {B, C}, B: {D}, C: {D}, D: {}} - - The algorithm will yield the following values: - - {D} - {B, C} - {A} - - From https://code.activestate.com/recipes/577413/. - """ - # TODO: Use a faster algorithm? - for k, v in data.items(): - v.discard(k) # Ignore self dependencies. - for item in set.union(*data.values()) - set(data.keys()): - data[item] = set() - while True: - ready = {item for item, dep in data.items() if not dep} - if not ready: - break - yield ready - data = {item: (dep - ready) for item, dep in data.items() if item not in ready} - assert not data, f"A cyclic dependency exists amongst {data!r}" - - def missing_stubs_file(cache_dir: str) -> str: return os.path.join(cache_dir, "missing_stubs") diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4b204a80c130..43896171eadc 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -12,7 +12,7 @@ import mypy.errorcodes as codes from mypy import applytype, erasetype, join, message_registry, nodes, operators, types from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals -from mypy.checkmember import analyze_member_access, type_object_type +from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type from mypy.checkstrformat import StringFormatterChecker from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error @@ -98,8 +98,15 @@ ) from mypy.semanal_enum import ENUM_BASES from mypy.state import state -from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members +from mypy.subtypes import ( + find_member, + is_equivalent, + is_same_type, + is_subtype, + non_method_protocol_members, +) from mypy.traverser import has_await_expression +from mypy.type_visitor import TypeTranslator from mypy.typeanal import ( check_for_explicit_any, has_any_from_unimported_type, @@ -114,6 +121,7 @@ false_only, fixup_partial_type, function_type, + get_type_vars, is_literal_type_like, make_simplified_union, simple_literal_type, @@ -146,6 +154,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarLikeType, TypeVarTupleType, TypeVarType, UninhabitedType, @@ -300,6 +309,7 @@ def __init__( # on whether current expression is a callee, to give better error messages # related to type context. self.is_callee = False + type_state.infer_polymorphic = self.chk.options.new_type_inference def reset(self) -> None: self.resolved_type = {} @@ -1791,6 +1801,51 @@ def infer_function_type_arguments( inferred_args[0] = self.named_type("builtins.str") elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) + + if self.chk.options.new_type_inference and any( + a is None + or isinstance(get_proper_type(a), UninhabitedType) + or set(get_type_vars(a)) & set(callee_type.variables) + for a in inferred_args + ): + # If the regular two-phase inference didn't work, try inferring type + # variables while allowing for polymorphic solutions, i.e. for solutions + # potentially involving free variables. + # TODO: support the similar inference for return type context. + poly_inferred_args = infer_function_type_arguments( + callee_type, + arg_types, + arg_kinds, + formal_to_actual, + context=self.argument_infer_context(), + strict=self.chk.in_checked_function(), + allow_polymorphic=True, + ) + for i, pa in enumerate(get_proper_types(poly_inferred_args)): + if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): + # Indicate that free variables should not be applied in the call below. + poly_inferred_args[i] = None + poly_callee_type = self.apply_generic_arguments( + callee_type, poly_inferred_args, context + ) + yes_vars = poly_callee_type.variables + no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} + if not set(get_type_vars(poly_callee_type)) & no_vars: + # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can + # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. + applied = apply_poly(poly_callee_type, yes_vars) + if applied is not None and poly_inferred_args != [UninhabitedType()] * len( + poly_inferred_args + ): + freeze_all_type_vars(applied) + return applied + # If it didn't work, erase free variables as , to avoid confusing errors. + inferred_args = [ + expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) + if a is not None + else None + for a in inferred_args + ] else: # In dynamically typed functions use implicit 'Any' types for # type variables. @@ -5393,6 +5448,92 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl return c.copy_modified(ret_type=new_ret_type) +def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]: + """Make free type variables generic in the type if possible. + + This will translate the type `tp` while trying to create valid bindings for + type variables `poly_tvars` while traversing the type. This follows the same rules + as we do during semantic analysis phase, examples: + * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T + * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T) + * List[T] -> None (not possible) + """ + try: + return tp.copy_modified( + arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types], + ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)), + variables=[], + ) + except PolyTranslationError: + return None + + +class PolyTranslationError(Exception): + pass + + +class PolyTranslator(TypeTranslator): + """Make free type variables generic in the type if possible. + + See docstring for apply_poly() for details. + """ + + def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: + self.poly_tvars = set(poly_tvars) + # This is a simplified version of TypeVarScope used during semantic analysis. + self.bound_tvars: set[TypeVarLikeType] = set() + self.seen_aliases: set[TypeInfo] = set() + + def visit_callable_type(self, t: CallableType) -> Type: + found_vars = set() + for arg in t.arg_types: + found_vars |= set(get_type_vars(arg)) & self.poly_tvars + + found_vars -= self.bound_tvars + self.bound_tvars |= found_vars + result = super().visit_callable_type(t) + self.bound_tvars -= found_vars + + assert isinstance(result, ProperType) and isinstance(result, CallableType) + result.variables = list(result.variables) + list(found_vars) + return result + + def visit_type_var(self, t: TypeVarType) -> Type: + if t in self.poly_tvars and t not in self.bound_tvars: + raise PolyTranslationError() + return super().visit_type_var(t) + + def visit_param_spec(self, t: ParamSpecType) -> Type: + # TODO: Support polymorphic apply for ParamSpec. + raise PolyTranslationError() + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: + # TODO: Support polymorphic apply for TypeVarTuple. + raise PolyTranslationError() + + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + if not t.args: + return t.copy_modified() + if not t.is_recursive: + return get_proper_type(t).accept(self) + # We can't handle polymorphic application for recursive generic aliases + # without risking an infinite recursion, just give up for now. + raise PolyTranslationError() + + def visit_instance(self, t: Instance) -> Type: + # There is the same problem with callback protocols as with aliases + # (callback protocols are essentially more flexible aliases to callables). + # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T]. + if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]: + if t.type in self.seen_aliases: + raise PolyTranslationError() + self.seen_aliases.add(t.type) + call = find_member("__call__", t, t, is_operator=True) + assert call is not None + return call.accept(self) + return super().visit_instance(t) + + class ArgInferSecondPassQuery(types.BoolTypeQuery): """Query whether an argument type should be inferred in the second pass. diff --git a/mypy/constraints.py b/mypy/constraints.py index 33230871b505..803b9819be6f 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -886,7 +886,30 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: param_spec = template.param_spec() if param_spec is None: # FIX verify argument counts - # FIX what if one of the functions is generic + # TODO: Erase template variables if it is generic? + if ( + type_state.infer_polymorphic + and cactual.variables + and cactual.param_spec() is None + # Technically, the correct inferred type for application of e.g. + # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic + # like U -> U, should be Callable[..., Any], but if U is a self-type, we can + # allow it to leak, to be later bound to self. A bunch of existing code + # depends on this old behaviour. + and not any(tv.id.raw_id == 0 for tv in cactual.variables) + ): + # If actual is generic, unify it with template. Note: this is + # not an ideal solution (which would be adding the generic variables + # to the constraint inference set), but it's a good first approximation, + # and this will prevent leaking these variables in the solutions. + # Note: this may infer constraints like T <: S or T <: List[S] + # that contain variables in the target. + unified = mypy.subtypes.unify_generic_callable( + cactual, template, ignore_return=True + ) + if unified is not None: + cactual = unified + res.extend(infer_constraints(cactual, template, neg_op(self.direction))) # We can't infer constraints from arguments if the template is Callable[..., T] # (with literal '...'). diff --git a/mypy/graph_utils.py b/mypy/graph_utils.py new file mode 100644 index 000000000000..399301a6b0fd --- /dev/null +++ b/mypy/graph_utils.py @@ -0,0 +1,112 @@ +"""Helpers for manipulations with graphs.""" + +from __future__ import annotations + +from typing import AbstractSet, Iterable, Iterator, TypeVar + +T = TypeVar("T") + + +def strongly_connected_components( + vertices: AbstractSet[T], edges: dict[T, list[T]] +) -> Iterator[set[T]]: + """Compute Strongly Connected Components of a directed graph. + + Args: + vertices: the labels for the vertices + edges: for each vertex, gives the target vertices of its outgoing edges + + Returns: + An iterator yielding strongly connected components, each + represented as a set of vertices. Each input vertex will occur + exactly once; vertices not part of a SCC are returned as + singleton sets. + + From https://code.activestate.com/recipes/578507/. + """ + identified: set[T] = set() + stack: list[T] = [] + index: dict[T, int] = {} + boundaries: list[int] = [] + + def dfs(v: T) -> Iterator[set[T]]: + index[v] = len(stack) + stack.append(v) + boundaries.append(index[v]) + + for w in edges[v]: + if w not in index: + yield from dfs(w) + elif w not in identified: + while index[w] < boundaries[-1]: + boundaries.pop() + + if boundaries[-1] == index[v]: + boundaries.pop() + scc = set(stack[index[v] :]) + del stack[index[v] :] + identified.update(scc) + yield scc + + for v in vertices: + if v not in index: + yield from dfs(v) + + +def prepare_sccs( + sccs: list[set[T]], edges: dict[T, list[T]] +) -> dict[AbstractSet[T], set[AbstractSet[T]]]: + """Use original edges to organize SCCs in a graph by dependencies between them.""" + sccsmap = {v: frozenset(scc) for scc in sccs for v in scc} + data: dict[AbstractSet[T], set[AbstractSet[T]]] = {} + for scc in sccs: + deps: set[AbstractSet[T]] = set() + for v in scc: + deps.update(sccsmap[x] for x in edges[v]) + data[frozenset(scc)] = deps + return data + + +def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]: + """Topological sort. + + Args: + data: A map from vertices to all vertices that it has an edge + connecting it to. NOTE: This data structure + is modified in place -- for normalization purposes, + self-dependencies are removed and entries representing + orphans are added. + + Returns: + An iterator yielding sets of vertices that have an equivalent + ordering. + + Example: + Suppose the input has the following structure: + + {A: {B, C}, B: {D}, C: {D}} + + This is normalized to: + + {A: {B, C}, B: {D}, C: {D}, D: {}} + + The algorithm will yield the following values: + + {D} + {B, C} + {A} + + From https://code.activestate.com/recipes/577413/. + """ + # TODO: Use a faster algorithm? + for k, v in data.items(): + v.discard(k) # Ignore self dependencies. + for item in set.union(*data.values()) - set(data.keys()): + data[item] = set() + while True: + ready = {item for item, dep in data.items() if not dep} + if not ready: + break + yield ready + data = {item: (dep - ready) for item, dep in data.items() if item not in ready} + assert not data, f"A cyclic dependency exists amongst {data!r}" diff --git a/mypy/infer.py b/mypy/infer.py index fbec3d7c4278..66ca4169e2ff 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -36,6 +36,7 @@ def infer_function_type_arguments( formal_to_actual: list[list[int]], context: ArgumentInferContext, strict: bool = True, + allow_polymorphic: bool = False, ) -> list[Type | None]: """Infer the type arguments of a generic function. @@ -57,7 +58,7 @@ def infer_function_type_arguments( # Solve constraints. type_vars = callee_type.type_var_ids() - return solve_constraints(type_vars, constraints, strict) + return solve_constraints(type_vars, constraints, strict, allow_polymorphic) def infer_type_arguments( diff --git a/mypy/main.py b/mypy/main.py index 81a0a045745b..b60c5b2a6bba 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -983,6 +983,11 @@ def add_invertible_flag( dest="custom_typing_module", help="Use a custom typing module", ) + internals_group.add_argument( + "--new-type-inference", + action="store_true", + help="Enable new experimental type inference algorithm", + ) internals_group.add_argument( "--disable-recursive-aliases", action="store_true", diff --git a/mypy/options.py b/mypy/options.py index 2785d2034c54..f75734124eb0 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -344,6 +344,8 @@ def __init__(self) -> None: # skip most errors after this many messages have been reported. # -1 means unlimited. self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD + # Enable new experimental type inference algorithm. + self.new_type_inference = False # Disable recursive type aliases (currently experimental) self.disable_recursive_aliases = False # Deprecated reverse version of the above, do not use. diff --git a/mypy/semanal.py b/mypy/semanal.py index 073bde661617..249a57d550b2 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -234,7 +234,6 @@ fix_instance_types, has_any_from_unimported_type, no_subscript_builtin_alias, - remove_dups, type_constructors, ) from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type @@ -277,6 +276,7 @@ get_proper_type, get_proper_types, is_named_instance, + remove_dups, ) from mypy.types_utils import is_invalid_recursive_alias, store_argument_type from mypy.typevars import fill_typevars diff --git a/mypy/solve.py b/mypy/solve.py index b8304d29c1ce..6693d66f3479 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -2,27 +2,35 @@ from __future__ import annotations -from collections import defaultdict +from typing import Iterable -from mypy.constraints import SUPERTYPE_OF, Constraint +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op +from mypy.expandtype import expand_type +from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types from mypy.meet import meet_types from mypy.subtypes import is_subtype +from mypy.typeops import get_type_vars from mypy.types import ( AnyType, ProperType, Type, TypeOfAny, TypeVarId, + TypeVarType, UninhabitedType, UnionType, get_proper_type, + remove_dups, ) from mypy.typestate import type_state def solve_constraints( - vars: list[TypeVarId], constraints: list[Constraint], strict: bool = True + vars: list[TypeVarId], + constraints: list[Constraint], + strict: bool = True, + allow_polymorphic: bool = False, ) -> list[Type | None]: """Solve type constraints. @@ -33,62 +41,355 @@ def solve_constraints( pick NoneType as the value of the type variable. If strict=False, pick AnyType. """ + if not vars: + return [] + if allow_polymorphic: + # Constraints like T :> S and S <: T are semantically the same, but they are + # represented differently. Normalize the constraint list w.r.t this equivalence. + constraints = normalize_constraints(constraints, vars) + # Collect a list of constraints for each type variable. - cmap: dict[TypeVarId, list[Constraint]] = defaultdict(list) + cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars} for con in constraints: - cmap[con.type_var].append(con) + if con.type_var in vars: + cmap[con.type_var].append(con) + + if allow_polymorphic: + solutions = solve_non_linear(vars, constraints, cmap) + else: + solutions = {} + for tv, cs in cmap.items(): + if not cs: + continue + lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] + uppers = [c.target for c in cs if c.op == SUBTYPE_OF] + solutions[tv] = solve_one(lowers, uppers, []) res: list[Type | None] = [] + for v in vars: + if v in solutions: + res.append(solutions[v]) + else: + # No constraints for type variable -- 'UninhabitedType' is the most specific type. + candidate: Type + if strict: + candidate = UninhabitedType() + candidate.ambiguous = True + else: + candidate = AnyType(TypeOfAny.special_form) + res.append(candidate) + return res + + +def solve_non_linear( + vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]] +) -> dict[TypeVarId, Type | None]: + """Solve set of constraints that may include non-linear ones, like T <: List[S]. - # Solve each type variable separately. + The whole algorithm consists of five steps: + * Propagate via linear constraints to get all possible constraints for each variable + * Find dependencies between type variables, group them in SCCs, and sort topologically + * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] + * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) + * Solve constraints iteratively starting from leafs, updating targets after each step. + """ + extra_constraints = [] for tvar in vars: - bottom: Type | None = None - top: Type | None = None - candidate: Type | None = None - - # Process each constraint separately, and calculate the lower and upper - # bounds based on constraints. Note that we assume that the constraint - # targets do not have constraint references. - for c in cmap.get(tvar, []): - if c.op == SUPERTYPE_OF: - if bottom is None: - bottom = c.target - else: - if type_state.infer_unions: - # This deviates from the general mypy semantics because - # recursive types are union-heavy in 95% of cases. - bottom = UnionType.make_union([bottom, c.target]) - else: - bottom = join_types(bottom, c.target) + extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) + extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) + constraints += remove_dups(extra_constraints) + + # Recompute constraint map after propagating. + cmap = {tv: [] for tv in vars} + for con in constraints: + if con.type_var in vars: + cmap[con.type_var].append(con) + + dmap = compute_dependencies(cmap) + sccs = list(strongly_connected_components(set(vars), dmap)) + if all(check_linear(scc, cmap) for scc in sccs): + raw_batches = list(topsort(prepare_sccs(sccs, dmap))) + leafs = raw_batches[0] + free_vars = [] + for scc in leafs: + # If all constrain targets in this SCC are type variables within the + # same SCC then the only meaningful solution we can express, is that + # each variable is equal to a new free variable. For example if we + # have T <: S, S <: U, we deduce: T = S = U = . + if all( + isinstance(c.target, TypeVarType) and c.target.id in vars + for tv in scc + for c in cmap[tv] + ): + # For convenience with current type application machinery, we randomly + # choose one of the existing type variables in SCC and designate it as free + # instead of defining a new type variable as a common solution. + # TODO: be careful about upper bounds (or values) when introducing free vars. + free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0]) + + # Flatten the SCCs that are independent, we can solve them together, + # since we don't need to update any targets in between. + batches = [] + for batch in raw_batches: + next_bc = [] + for scc in batch: + next_bc.extend(list(scc)) + batches.append(next_bc) + + solutions: dict[TypeVarId, Type | None] = {} + for flat_batch in batches: + solutions.update(solve_iteratively(flat_batch, cmap, free_vars)) + # We remove the solutions like T = T for free variables. This will indicate + # to the apply function, that they should not be touched. + # TODO: return list of free type variables explicitly, this logic is fragile + # (but if we do, we need to be careful everything works in incremental modes). + for tv in free_vars: + if tv in solutions: + del solutions[tv] + return solutions + return {} + + +def solve_iteratively( + batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] +) -> dict[TypeVarId, Type | None]: + """Solve constraints sequentially, updating constraint targets after each step. + + We solve for type variables that appear in `batch`. If a constraint target is not constant + (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in + the target F[S, ...]. This way we can gradually solve for all variables in the batch taking + one solvable variable at a time (i.e. such a variable that has at least one constant bound). + + Importantly, variables in free_vars are considered constants, so for example if we have just + one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first + designate S as free, and therefore T = List[S] is a valid solution for T. + """ + solutions = {} + relevant_constraints = [] + for tv in batch: + relevant_constraints.extend(cmap.get(tv, [])) + lowers, uppers = transitive_closure(batch, relevant_constraints) + s_batch = set(batch) + not_allowed_vars = [v for v in batch if v not in free_vars] + while s_batch: + for tv in s_batch: + if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any( + not get_vars(u, not_allowed_vars) for u in uppers[tv] + ): + solvable_tv = tv + break + else: + break + # Solve each solvable type variable separately. + s_batch.remove(solvable_tv) + result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars) + solutions[solvable_tv] = result + if result is None: + # TODO: support backtracking lower/upper bound choices + # (will require switching this function from iterative to recursive). + continue + # Update the (transitive) constraints if there is a solution. + subs = {solvable_tv: result} + lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers} + uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers} + for v in cmap: + for c in cmap[v]: + c.target = expand_type(c.target, subs) + return solutions + + +def solve_one( + lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId] +) -> Type | None: + """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" + bottom: Type | None = None + top: Type | None = None + candidate: Type | None = None + + # Process each bound separately, and calculate the lower and upper + # bounds based on constraints. Note that we assume that the constraint + # targets do not have constraint references. + for target in lowers: + # There may be multiple steps needed to solve all vars within a + # (linear) SCC. We ignore targets pointing to not yet solved vars. + if get_vars(target, not_allowed_vars): + continue + if bottom is None: + bottom = target + else: + if type_state.infer_unions: + # This deviates from the general mypy semantics because + # recursive types are union-heavy in 95% of cases. + bottom = UnionType.make_union([bottom, target]) else: - if top is None: - top = c.target - else: - top = meet_types(top, c.target) - - p_top = get_proper_type(top) - p_bottom = get_proper_type(bottom) - if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): - source_any = top if isinstance(p_top, AnyType) else bottom - assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) - res.append(AnyType(TypeOfAny.from_another_any, source_any=source_any)) + bottom = join_types(bottom, target) + + for target in uppers: + # Same as above. + if get_vars(target, not_allowed_vars): continue - elif bottom is None: - if top: - candidate = top + if top is None: + top = target + else: + top = meet_types(top, target) + + p_top = get_proper_type(top) + p_bottom = get_proper_type(bottom) + if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): + source_any = top if isinstance(p_top, AnyType) else bottom + assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) + return AnyType(TypeOfAny.from_another_any, source_any=source_any) + elif bottom is None: + if top: + candidate = top + else: + # No constraints for type variable + return None + elif top is None: + candidate = bottom + elif is_subtype(bottom, top): + candidate = bottom + else: + candidate = None + return candidate + + +def normalize_constraints( + constraints: list[Constraint], vars: list[TypeVarId] +) -> list[Constraint]: + """Normalize list of constraints (to simplify life for the non-linear solver). + + This includes two things currently: + * Complement T :> S by S <: T + * Remove strict duplicates + """ + res = constraints.copy() + for c in constraints: + if isinstance(c.target, TypeVarType): + res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var)) + return [c for c in remove_dups(constraints) if c.type_var in vars] + + +def propagate_constraints_for( + var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]] +) -> list[Constraint]: + """Propagate via linear constraints to get additional constraints for `var`. + + For example if we have constraints: + [T <: int, S <: T, S :> str] + we can add two more + [S <: int, T :> str] + """ + extra_constraints = [] + seen = set() + front = [var] + if cmap[var]: + var_def = cmap[var][0].origin_type_var + else: + return [] + while front: + tv = front.pop(0) + for c in cmap[tv]: + if ( + isinstance(c.target, TypeVarType) + and c.target.id not in seen + and c.target.id in cmap + and c.op == direction + ): + front.append(c.target.id) + seen.add(c.target.id) + elif c.op == direction: + new_c = Constraint(var_def, direction, c.target) + if new_c not in cmap[var]: + extra_constraints.append(new_c) + return extra_constraints + + +def transitive_closure( + tvars: list[TypeVarId], constraints: list[Constraint] +) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]: + """Find transitive closure for given constraints on type variables. + + Transitive closure gives maximal set of lower/upper bounds for each type variable, + such that we cannot deduce any further bounds by chaining other existing bounds. + + For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive + closure is given by: + * {} <: T <: {S, U, int} + * {T} <: S <: {U, int} + * {T, S} <: U <: {int} + """ + # TODO: merge propagate_constraints_for() into this function. + # TODO: add secondary constraints here to make the algorithm complete. + uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} + lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} + graph: set[tuple[TypeVarId, TypeVarId]] = set() + + # Prime the closure with the initial trivial values. + for c in constraints: + if isinstance(c.target, TypeVarType) and c.target.id in tvars: + if c.op == SUBTYPE_OF: + graph.add((c.type_var, c.target.id)) else: - # No constraints for type variable -- 'UninhabitedType' is the most specific type. - if strict: - candidate = UninhabitedType() - candidate.ambiguous = True - else: - candidate = AnyType(TypeOfAny.special_form) - elif top is None: - candidate = bottom - elif is_subtype(bottom, top): - candidate = bottom + graph.add((c.target.id, c.type_var)) + if c.op == SUBTYPE_OF: + uppers[c.type_var].add(c.target) else: - candidate = None - res.append(candidate) + lowers[c.type_var].add(c.target) + + # At this stage we know that constant bounds have been propagated already, so we + # only need to propagate linear constraints. + for c in constraints: + if isinstance(c.target, TypeVarType) and c.target.id in tvars: + if c.op == SUBTYPE_OF: + lower, upper = c.type_var, c.target.id + else: + lower, upper = c.target.id, c.type_var + extras = { + (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph + } + graph |= extras + for u in tvars: + if (upper, u) in graph: + lowers[u] |= lowers[lower] + for l in tvars: + if (l, lower) in graph: + uppers[l] |= uppers[upper] + return lowers, uppers + +def compute_dependencies( + cmap: dict[TypeVarId, list[Constraint]] +) -> dict[TypeVarId, list[TypeVarId]]: + """Compute dependencies between type variables induced by constraints. + + If we have a constraint like T <: List[S], we say that T depends on S, since + we will need to solve for S first before we can solve for T. + """ + res = {} + vars = list(cmap.keys()) + for tv in cmap: + deps = set() + for c in cmap[tv]: + deps |= get_vars(c.target, vars) + res[tv] = list(deps) return res + + +def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool: + """Check there are only linear constraints between type variables in SCC. + + Linear are constraints like T <: S (while T <: F[S] are non-linear). + """ + for tv in scc: + if any( + get_vars(c.target, list(scc)) and not isinstance(c.target, TypeVarType) + for c in cmap[tv] + ): + return False + return True + + +def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]: + """Find type variables for which we are solving in a target type.""" + return {tv.id for tv in get_type_vars(target)} & set(vars) diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py index ce7697142ff2..b0d148d5ae9c 100644 --- a/mypy/test/testgraph.py +++ b/mypy/test/testgraph.py @@ -5,17 +5,10 @@ import sys from typing import AbstractSet -from mypy.build import ( - BuildManager, - BuildSourceSet, - State, - order_ascc, - sorted_components, - strongly_connected_components, - topsort, -) +from mypy.build import BuildManager, BuildSourceSet, State, order_ascc, sorted_components from mypy.errors import Errors from mypy.fscache import FileSystemCache +from mypy.graph_utils import strongly_connected_components, topsort from mypy.modulefinder import SearchPaths from mypy.options import Options from mypy.plugin import Plugin diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d1e6e315b9e3..39a44a289365 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1843,19 +1843,6 @@ def set_any_tvars( return TypeAliasType(node, args, newline, newcolumn) -def remove_dups(tvars: list[T]) -> list[T]: - if len(tvars) <= 1: - return tvars - # Get unique elements in order of appearance - all_tvars: set[T] = set() - new_tvars: list[T] = [] - for t in tvars: - if t not in all_tvars: - new_tvars.append(t) - all_tvars.add(t) - return new_tvars - - def flatten_tvars(lists: list[list[T]]) -> list[T]: result: list[T] = [] for lst in lists: diff --git a/mypy/types.py b/mypy/types.py index 5fbdd385826c..e4c22da12603 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3475,6 +3475,19 @@ def callable_with_ellipsis(any_type: AnyType, ret_type: Type, fallback: Instance ) +def remove_dups(types: list[T]) -> list[T]: + if len(types) <= 1: + return types + # Get unique elements in order of appearance + all_types: set[T] = set() + new_types: list[T] = [] + for t in types: + if t not in all_types: + new_types.append(t) + all_types.add(t) + return new_types + + # This cyclic import is unfortunate, but to avoid it we would need to move away all uses # of get_proper_type() from types.py. Majority of them have been removed, but few remaining # are quite tricky to get rid of, but ultimately we want to do it at some point. diff --git a/mypy/typestate.py b/mypy/typestate.py index 9f65481e5e94..ff5933af5928 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -93,6 +93,9 @@ class TypeState: inferring: Final[list[tuple[Type, Type]]] # Whether to use joins or unions when solving constraints, see checkexpr.py for details. infer_unions: bool + # Whether to use new type inference algorithm that can infer polymorphic types. + # This is temporary and will be removed soon when new algorithm is more polished. + infer_polymorphic: bool # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing @@ -110,6 +113,7 @@ def __init__(self) -> None: self._assuming_proper = [] self.inferring = [] self.infer_unions = False + self.infer_polymorphic = False def is_assumed_subtype(self, left: Type, right: Type) -> bool: for l, r in reversed(self._assuming): @@ -311,7 +315,7 @@ def add_all_protocol_deps(self, deps: dict[str, set[str]]) -> None: def reset_global_state() -> None: """Reset most existing global state. - Currently most of it is in this module. Few exceptions are strict optional status and + Currently most of it is in this module. Few exceptions are strict optional status and functools.lru_cache. """ type_state.reset_all_subtype_caches() diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 06b80be85096..b78fd21d4817 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2733,3 +2733,250 @@ dict1: Any dict2 = {"a": C1(), **{x: C2() for x in dict1}} reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]" [builtins fixtures/dict.pyi] + +-- Type inference for generic decorators applied to generic callables +-- ------------------------------------------------------------------ + +[case testInferenceAgainstGenericCallable] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +X = TypeVar('X') +T = TypeVar('T') + +def foo(x: Callable[[int], X]) -> List[X]: + ... +def bar(x: Callable[[X], int]) -> List[X]: + ... + +def id(x: T) -> T: + ... +reveal_type(foo(id)) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(bar(id)) # N: Revealed type is "builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableNoLeak] +# flags: --new-type-inference +from typing import TypeVar, Callable + +T = TypeVar('T') + +def f(x: Callable[..., T]) -> T: + return x() + +def tpl(x: T) -> T: + return x + +# This is valid because of "..." +reveal_type(f(tpl)) # N: Revealed type is "Any" +[out] + +[case testInferenceAgainstGenericCallableChain] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +X = TypeVar('X') +T = TypeVar('T') + +def chain(f: Callable[[X], T], g: Callable[[T], int]) -> Callable[[X], int]: ... +def id(x: T) -> T: + ... +reveal_type(chain(id, id)) # N: Revealed type is "def (builtins.int) -> builtins.int" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGeneric] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: + ... +def id(x: U) -> U: + ... +reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]" + +@dec +def same(x: U) -> U: + ... +reveal_type(same) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]" +reveal_type(same(42)) # N: Revealed type is "builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericReverse] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], List[T]]) -> Callable[[S], T]: + ... +def id(x: U) -> U: + ... +reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2" + +@dec +def same(x: U) -> U: + ... +reveal_type(same) # N: Revealed type is "def [T] (builtins.list[T`4]) -> T`4" +reveal_type(same([42])) # N: Revealed type is "builtins.int" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericArg] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], T]) -> Callable[[S], T]: + ... +def test(x: U) -> List[U]: + ... +reveal_type(dec(test)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]" + +@dec +def single(x: U) -> List[U]: + ... +reveal_type(single) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]" +reveal_type(single(42)) # N: Revealed type is "builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericChain] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def comb(f: Callable[[T], S], g: Callable[[S], U]) -> Callable[[T], U]: ... +def id(x: U) -> U: + ... +reveal_type(comb(id, id)) # N: Revealed type is "def [T] (T`1) -> T`1" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericNonLinear] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: + def inner(x: S) -> List[T]: + return [f(x) for f in fs] + return inner + +# Errors caused by arg *name* mismatch are truly cryptic, but this is a known issue :/ +def id(__x: U) -> U: + ... +fs = [id, id, id] +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`5) -> builtins.list[S`5]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCurry] +# flags: --new-type-inference +from typing import Callable, List, TypeVar + +S = TypeVar("S") +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") + +def dec1(f: Callable[[T], S]) -> Callable[[], Callable[[T], S]]: ... +def dec2(f: Callable[[T, U], S]) -> Callable[[U], Callable[[T], S]]: ... + +def test1(x: V) -> V: ... +def test2(x: V, y: V) -> V: ... + +reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" +# TODO: support this situation +reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" +[builtins fixtures/paramspec.pyi] + +[case testInferenceAgainstGenericCallableGenericAlias] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +A = Callable[[S], T] +B = Callable[[S], List[T]] + +def dec(f: A[S, T]) -> B[S, T]: + ... +def id(x: U) -> U: + ... +reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableGenericProtocol] +# flags: --strict-optional --new-type-inference +from typing import TypeVar, Protocol, Generic, Optional + +T = TypeVar('T') + +class F(Protocol[T]): + def __call__(self, __x: T) -> T: ... + +def lift(f: F[T]) -> F[Optional[T]]: ... +def g(x: T) -> T: + return x + +reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericSplitOrder] +# flags: --strict-optional --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[T], S], g: Callable[[T], int]) -> Callable[[T], List[S]]: ... +def id(x: U) -> U: + ... + +reveal_type(dec(id, id)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericSplitOrderGeneric] +# flags: --strict-optional --new-type-inference +from typing import TypeVar, Callable, Tuple + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') +V = TypeVar('V') + +def dec(f: Callable[[T], S], g: Callable[[T], U]) -> Callable[[T], Tuple[S, U]]: ... +def id(x: V) -> V: + ... + +reveal_type(dec(id, id)) # N: Revealed type is "def [T] (T`1) -> Tuple[T`1, T`1]" +[builtins fixtures/tuple.pyi] + +[case testInferenceAgainstGenericEllipsisSelfSpecialCase] +# flags: --new-type-inference +from typing import Self, Callable, TypeVar + +T = TypeVar("T") +def dec(f: Callable[..., T]) -> Callable[..., T]: ... + +class C: + @dec + def test(self) -> Self: ... + +c: C +reveal_type(c.test()) # N: Revealed type is "__main__.C" diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 901e73008d56..cafcaca0a14c 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1040,6 +1040,28 @@ reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1)" reveal_type(jf(1)) # N: Revealed type is "None" [builtins fixtures/paramspec.pyi] +[case testGenericsInInferredParamspecReturn] +# flags: --new-type-inference +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +class Job(Generic[_P, _T]): + def __init__(self, target: Callable[_P, _T]) -> None: ... + def into_callable(self) -> Callable[_P, _T]: ... + +def generic_f(x: _T) -> _T: ... + +j = Job(generic_f) +reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`-1], _T`-1]" + +jf = j.into_callable() +reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1) -> _T`-1" +reveal_type(jf(1)) # N: Revealed type is "builtins.int" +[builtins fixtures/paramspec.pyi] + [case testStackedConcatenateIsIllegal] from typing_extensions import Concatenate, ParamSpec from typing import Callable @@ -1520,3 +1542,18 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ... @identity def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] + +[case testParamSpecDecoratorAppliedToGeneric] +# flags: --new-type-inference +from typing import Callable, List, TypeVar +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") +U = TypeVar("U") + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... +def test(x: U) -> U: ... +reveal_type(dec) # N: Revealed type is "def [P, T] (f: def (*P.args, **P.kwargs) -> T`-2) -> def (*P.args, **P.kwargs) -> builtins.list[T`-2]" +reveal_type(dec(test)) # N: Revealed type is "def [U] (x: U`-1) -> builtins.list[U`-1]" +[builtins fixtures/paramspec.pyi] diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index 9aa31c1ed10b..5b8c361906a8 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -1173,12 +1173,13 @@ class A: [builtins fixtures/bool.pyi] [case testAttrsFactoryBadReturn] +# flags: --new-type-inference import attr def my_factory() -> int: return 7 @attr.s class A: - x: int = attr.ib(factory=list) # E: Incompatible types in assignment (expression has type "List[T]", variable has type "int") + x: int = attr.ib(factory=list) # E: Incompatible types in assignment (expression has type "List[]", variable has type "int") y: str = attr.ib(factory=my_factory) # E: Incompatible types in assignment (expression has type "int", variable has type "str") [builtins fixtures/list.pyi] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 8c82b6843a3b..eada8cf6fa85 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -833,6 +833,7 @@ _program.py:3: error: Dict entry 1 has incompatible type "str": "str"; expected _program.py:5: error: "Dict[str, int]" has no attribute "xyz" [case testDefaultDict] +# flags: --new-type-inference import typing as t from collections import defaultdict @@ -858,11 +859,11 @@ class MyDDict(t.DefaultDict[int,T], t.Generic[T]): MyDDict(dict)['0'] MyDDict(dict)[0] [out] -_program.py:6: error: Argument 1 to "defaultdict" has incompatible type "Type[List[Any]]"; expected "Callable[[], str]" -_program.py:9: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int" -_program.py:9: error: Incompatible types in assignment (expression has type "int", target has type "str") -_program.py:19: error: Argument 1 to "tst" has incompatible type "defaultdict[str, List[]]"; expected "defaultdict[int, List[]]" -_program.py:23: error: Invalid index type "str" for "MyDDict[Dict[_KT, _VT]]"; expected type "int" +_program.py:7: error: Argument 1 to "defaultdict" has incompatible type "Type[List[Any]]"; expected "Callable[[], str]" +_program.py:10: error: Invalid index type "str" for "defaultdict[int, str]"; expected type "int" +_program.py:10: error: Incompatible types in assignment (expression has type "int", target has type "str") +_program.py:20: error: Argument 1 to "tst" has incompatible type "defaultdict[str, List[]]"; expected "defaultdict[int, List[]]" +_program.py:24: error: Invalid index type "str" for "MyDDict[Dict[, ]]"; expected type "int" [case testNoSubcriptionOfStdlibCollections] # flags: --python-version 3.6 @@ -2032,7 +2033,6 @@ from dataclasses import dataclass, replace class A: x: int - a = A(x=42) a2 = replace(a, x=42) reveal_type(a2) @@ -2040,7 +2040,47 @@ a2 = replace() a2 = replace(a, x='spam') a2 = replace(a, x=42, q=42) [out] -_testDataclassReplace.py:10: note: Revealed type is "_testDataclassReplace.A" -_testDataclassReplace.py:11: error: Too few arguments for "replace" -_testDataclassReplace.py:12: error: Argument "x" to "replace" of "A" has incompatible type "str"; expected "int" -_testDataclassReplace.py:13: error: Unexpected keyword argument "q" for "replace" of "A" +_testDataclassReplace.py:9: note: Revealed type is "_testDataclassReplace.A" +_testDataclassReplace.py:10: error: Too few arguments for "replace" +_testDataclassReplace.py:11: error: Argument "x" to "replace" of "A" has incompatible type "str"; expected "int" +_testDataclassReplace.py:12: error: Unexpected keyword argument "q" for "replace" of "A" + +[case testGenericInferenceWithTuple] +# flags: --new-type-inference +from typing import TypeVar, Callable, Tuple + +T = TypeVar("T") + +def f(x: Callable[..., T]) -> T: + return x() + +x: Tuple[str, ...] = f(tuple) +[out] + +[case testGenericInferenceWithDataclass] +# flags: --new-type-inference +from typing import Any, Collection, List +from dataclasses import dataclass, field + +class Foo: + pass + +@dataclass +class A: + items: Collection[Foo] = field(default_factory=list) +[out] + +[case testGenericInferenceWithItertools] +# flags: --new-type-inference +from typing import TypeVar, Tuple +from itertools import groupby +K = TypeVar("K") +V = TypeVar("V") + +def fst(kv: Tuple[K, V]) -> K: + k, v = kv + return k + +pairs = [(len(s), s) for s in ["one", "two", "three"]] +grouped = groupby(pairs, key=fst) +[out]