From 6d886e9c46e892369d71c75dd676839b9eebf9be Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 2 Oct 2018 12:44:53 +0100 Subject: [PATCH] Improve usage of outer context for inference (#5699) Fixes #4872 Fixes #3876 Fixes #2678 Fixes #5199 Fixes #5493 (It also fixes a bunch of similar issues previously closed as duplicates, except one, see below). This PR fixes a problems when mypy commits to soon to using outer context for type inference. This is done by: * Postponing inference to inner (argument) context in situations where type inferred from outer (return) context doesn't satisfy bounds or constraints. * Adding a special case for situation where optional return is inferred against optional context. In such situation, unwrapping the optional is a better idea in 99% of cases. (Note: this doesn't affect type safety, only gives empirically more reasonable inferred types.) In general, instead of adding a special case, it would be better to use inner and outer context at the same time, but this a big change (see comment in code), and using the simple special case fixes majority of issues. Among reported issues, only https://github.com/python/mypy/issues/5311 will stay unfixed. --- mypy/applytype.py | 38 ++- mypy/checker.py | 14 +- mypy/checkexpr.py | 61 ++-- mypy/types.py | 17 ++ test-data/unit/check-generics.test | 2 +- test-data/unit/check-inference-context.test | 298 ++++++++++++++++++++ test-data/unit/check-overloading.test | 8 +- test-data/unit/check-typevar-values.test | 2 +- 8 files changed, 393 insertions(+), 47 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index e1d81218b2f8..afd963928eee 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -9,13 +9,17 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optional[Type]], - msg: MessageBuilder, context: Context) -> CallableType: + msg: MessageBuilder, context: Context, + skip_unsatisfied: bool = False) -> CallableType: """Apply generic type arguments to a callable type. For example, applying [int] to 'def [T] (T) -> T' results in 'def (int) -> int'. Note that each type can be None; in this case, it will not be applied. + + If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable + bound or constraints, instead of giving an error. """ tvars = callable.variables assert len(tvars) == len(orig_types) @@ -25,7 +29,9 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona for i, type in enumerate(types): assert not isinstance(type, PartialType), "Internal error: must never apply partial type" values = callable.variables[i].values - if values and type: + if type is None: + continue + if values: if isinstance(type, AnyType): continue if isinstance(type, TypeVarType) and type.values: @@ -34,15 +40,31 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona if all(any(is_same_type(v, v1) for v in values) for v1 in type.values): continue + matching = [] for value in values: if mypy.subtypes.is_subtype(type, value): - types[i] = value - break + matching.append(value) + if matching: + best = matching[0] + # If there are more than one matching value, we select the narrowest + for match in matching[1:]: + if mypy.subtypes.is_subtype(match, best): + best = match + types[i] = best else: - msg.incompatible_typevar_value(callable, type, callable.variables[i].name, context) - upper_bound = callable.variables[i].upper_bound - if type and not mypy.subtypes.is_subtype(type, upper_bound): - msg.incompatible_typevar_value(callable, type, callable.variables[i].name, context) + if skip_unsatisfied: + types[i] = None + else: + msg.incompatible_typevar_value(callable, type, callable.variables[i].name, + context) + else: + upper_bound = callable.variables[i].upper_bound + if not mypy.subtypes.is_subtype(type, upper_bound): + if skip_unsatisfied: + types[i] = None + else: + msg.incompatible_typevar_value(callable, type, callable.variables[i].name, + context) # Create a map from type variable id to target type. id_to_type = {} # type: Dict[TypeVarId, Type] diff --git a/mypy/checker.py b/mypy/checker.py index 5efdd8221827..d9fb20aa3a00 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -31,7 +31,8 @@ Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, Instance, NoneTyp, strip_type, TypeType, TypeOfAny, UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, - true_only, false_only, function_type, is_named_instance, union_items, TypeQuery + true_only, false_only, function_type, is_named_instance, union_items, TypeQuery, + is_optional, remove_optional ) from mypy.sametypes import is_same_type, is_same_types from mypy.messages import MessageBuilder, make_inferred_type_note @@ -3792,17 +3793,6 @@ def is_literal_none(n: Expression) -> bool: return isinstance(n, NameExpr) and n.fullname == 'builtins.None' -def is_optional(t: Type) -> bool: - return isinstance(t, UnionType) and any(isinstance(e, NoneTyp) for e in t.items) - - -def remove_optional(typ: Type) -> Type: - if isinstance(typ, UnionType): - return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneTyp)]) - else: - return typ - - def is_literal_not_implemented(n: Expression) -> bool: return isinstance(n, NameExpr) and n.fullname == 'builtins.NotImplemented' diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0ed9e4948d3d..95a4ecdf2c1a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,8 +18,9 @@ from mypy.types import ( Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef, TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType, - PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, true_only, - false_only, is_named_instance, function_type, callable_type, FunctionLike, StarType, + PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, + true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike, + StarType, is_optional, remove_optional, is_invariant_instance ) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, @@ -30,7 +31,7 @@ ConditionalExpr, ComparisonExpr, TempNode, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, - TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, ClassDef, Block, SymbolNode, + TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, LITERAL_TYPE, REVEAL_TYPE ) from mypy.literals import literal @@ -819,20 +820,36 @@ def infer_function_type_arguments_using_context( # valid results. erased_ctx = replace_meta_vars(ctx, ErasedType()) ret_type = callable.ret_type - if isinstance(ret_type, TypeVarType): - if ret_type.values or (not isinstance(ctx, Instance) or - not ctx.args): - # The return type is a type variable. If it has values, we can't easily restrict - # type inference to conform to the valid values. If it's unrestricted, we could - # infer a too general type for the type variable if we use context, and this could - # result in confusing and spurious type errors elsewhere. - # - # Give up and just use function arguments for type inference. As an exception, - # if the context is a generic instance type, actually use it as context, as - # this *seems* to usually be the reasonable thing to do. - # - # See also github issues #462 and #360. - ret_type = NoneTyp() + if is_optional(ret_type) and is_optional(ctx): + # If both the context and the return type are optional, unwrap the optional, + # since in 99% cases this is what a user expects. In other words, we replace + # Optional[T] <: Optional[int] + # with + # T <: int + # while the former would infer T <: Optional[int]. + ret_type = remove_optional(ret_type) + erased_ctx = remove_optional(erased_ctx) + # + # TODO: Instead of this hack and the one below, we need to use outer and + # inner contexts at the same time. This is however not easy because of two + # reasons: + # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables + # on both sides. (This is not too hard.) + # * We need to update all the inference "infrastructure", so that all + # variables in an expression are inferred at the same time. + # (And this is hard, also we need to be careful with lambdas that require + # two passes.) + if isinstance(ret_type, TypeVarType) and not is_invariant_instance(ctx): + # Another special case: the return type is a type variable. If it's unrestricted, + # we could infer a too general type for the type variable if we use context, + # and this could result in confusing and spurious type errors elsewhere. + # + # Give up and just use function arguments for type inference. As an exception, + # if the context is an invariant instance type, actually use it as context, as + # this *seems* to usually be the reasonable thing to do. + # + # See also github issues #462 and #360. + return callable.copy_modified() args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx) # Only substitute non-Uninhabited and non-erased types. new_args = [] # type: List[Optional[Type]] @@ -841,7 +858,10 @@ def infer_function_type_arguments_using_context( new_args.append(None) else: new_args.append(arg) - return self.apply_generic_arguments(callable, new_args, error_context) + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + return self.apply_generic_arguments(callable, new_args, error_context, + skip_unsatisfied=True) def infer_function_type_arguments(self, callee_type: CallableType, args: List[Expression], @@ -1609,9 +1629,10 @@ def check_arg(caller_type: Type, original_caller_type: Type, caller_kind: int, return False def apply_generic_arguments(self, callable: CallableType, types: Sequence[Optional[Type]], - context: Context) -> CallableType: + context: Context, skip_unsatisfied: bool = False) -> CallableType: """Simple wrapper around mypy.applytype.apply_generic_arguments.""" - return applytype.apply_generic_arguments(callable, types, self.msg, context) + return applytype.apply_generic_arguments(callable, types, self.msg, context, + skip_unsatisfied=skip_unsatisfied) def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type: """Visit member expression (of form e.id).""" diff --git a/mypy/types.py b/mypy/types.py index acad55a12d75..9a84cc270e2a 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1918,6 +1918,23 @@ def union_items(typ: Type) -> List[Type]: return [typ] +def is_invariant_instance(tp: Type) -> bool: + if not isinstance(tp, Instance) or not tp.args: + return False + return any(v.variance == INVARIANT for v in tp.type.defn.type_vars) + + +def is_optional(t: Type) -> bool: + return isinstance(t, UnionType) and any(isinstance(e, NoneTyp) for e in t.items) + + +def remove_optional(typ: Type) -> Type: + if isinstance(typ, UnionType): + return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneTyp)]) + else: + return typ + + names = globals().copy() # type: Final names.pop('NOT_READY', None) deserialize_map = { diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index e6fcda7ae6e2..aaa9283cd0d7 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -849,7 +849,7 @@ def fun2(v: Vec[T], scale: T) -> Vec[T]: return v reveal_type(fun1([(1, 1)])) # E: Revealed type is 'builtins.int*' -fun1(1) # E: Argument 1 to "fun1" has incompatible type "int"; expected "List[Tuple[int, int]]" +fun1(1) # E: Argument 1 to "fun1" has incompatible type "int"; expected "List[Tuple[bool, bool]]" fun1([(1, 'x')]) # E: Cannot infer type argument 1 of "fun1" reveal_type(fun2([(1, 1)], 1)) # E: Revealed type is 'builtins.list[Tuple[builtins.int*, builtins.int*]]' diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index 7b9495f62d2f..6977b3e2c465 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -920,3 +920,301 @@ class C: def f(self) -> None: g: Callable[[], int] = lambda: 1 or self.x self.x = int() + +[case testWideOuterContextSubClassBound] +from typing import TypeVar + +class A: ... +class B(A): ... + +T = TypeVar('T', bound=B) +def f(x: T) -> T: ... +def outer(x: A) -> None: ... + +outer(f(B())) +x: A = f(B()) + +[case testWideOuterContextSubClassBoundGenericReturn] +from typing import TypeVar, Iterable, List + +class A: ... +class B(A): ... + +T = TypeVar('T', bound=B) +def f(x: T) -> List[T]: ... +def outer(x: Iterable[A]) -> None: ... + +outer(f(B())) +x: Iterable[A] = f(B()) +[builtins fixtures/list.pyi] + +[case testWideOuterContextSubClassValues] +from typing import TypeVar + +class A: ... +class B(A): ... + +T = TypeVar('T', B, int) +def f(x: T) -> T: ... +def outer(x: A) -> None: ... + +outer(f(B())) +x: A = f(B()) + +[case testWideOuterContextSubClassValuesGenericReturn] +from typing import TypeVar, Iterable, List + +class A: ... +class B(A): ... + +T = TypeVar('T', B, int) +def f(x: T) -> List[T]: ... +def outer(x: Iterable[A]) -> None: ... + +outer(f(B())) +x: Iterable[A] = f(B()) +[builtins fixtures/list.pyi] + +[case testWideOuterContextSubclassBoundGeneric] +from typing import TypeVar, Generic + +S = TypeVar('S') +class A(Generic[S]): ... +class B(A[S]): ... + +T = TypeVar('T', bound=B[int]) +def f(x: T) -> T: ... +def outer(x: A[int]) -> None: ... + +y: B[int] +outer(f(y)) +x: A[int] = f(y) + +[case testWideOuterContextSubclassBoundGenericCovariant] +from typing import TypeVar, Generic + +S_co = TypeVar('S_co', covariant=True) +class A(Generic[S_co]): ... +class B(A[S_co]): ... + +T = TypeVar('T', bound=B[int]) +def f(x: T) -> T: ... +def outer(x: A[int]) -> None: ... + +y: B[int] +outer(f(y)) +x: A[int] = f(y) + +[case testWideOuterContextSubclassValuesGeneric] +from typing import TypeVar, Generic + +S = TypeVar('S') +class A(Generic[S]): ... +class B(A[S]): ... + +T = TypeVar('T', B[int], int) +def f(x: T) -> T: ... +def outer(x: A[int]) -> None: ... + +y: B[int] +outer(f(y)) +x: A[int] = f(y) + +[case testWideOuterContextSubclassValuesGenericCovariant] +from typing import TypeVar, Generic + +S_co = TypeVar('S_co', covariant=True) +class A(Generic[S_co]): ... +class B(A[S_co]): ... + +T = TypeVar('T', B[int], int) +def f(x: T) -> T: ... +def outer(x: A[int]) -> None: ... + +y: B[int] +outer(f(y)) +x: A[int] = f(y) + +[case testWideOuterContextUnionBound] +from typing import TypeVar, Union + +class A: ... +class B: ... + +T = TypeVar('T', bound=B) +def f(x: T) -> T: ... +def outer(x: Union[A, B]) -> None: ... + +outer(f(B())) +x: Union[A, B] = f(B()) + +[case testWideOuterContextUnionBoundGenericReturn] +from typing import TypeVar, Union, Iterable, List + +class A: ... +class B: ... + +T = TypeVar('T', bound=B) +def f(x: T) -> List[T]: ... +def outer(x: Iterable[Union[A, B]]) -> None: ... + +outer(f(B())) +x: Iterable[Union[A, B]] = f(B()) +[builtins fixtures/list.pyi] + +[case testWideOuterContextUnionValues] +from typing import TypeVar, Union + +class A: ... +class B: ... + +T = TypeVar('T', B, int) +def f(x: T) -> T: ... +def outer(x: Union[A, B]) -> None: ... + +outer(f(B())) +x: Union[A, B] = f(B()) + +[case testWideOuterContextUnionValuesGenericReturn] +from typing import TypeVar, Union, Iterable, List + +class A: ... +class B: ... + +T = TypeVar('T', B, int) +def f(x: T) -> List[T]: ... +def outer(x: Iterable[Union[A, B]]) -> None: ... + +outer(f(B())) +x: Iterable[Union[A, B]] = f(B()) +[builtins fixtures/list.pyi] + +[case testWideOuterContextOptional] +# flags: --strict-optional +from typing import Optional, Type, TypeVar + +class Custom: + pass + +T = TypeVar('T', bound=Custom) + +def a(x: T) -> Optional[T]: ... + +def b(x: T) -> Optional[T]: + return a(x) + +[case testWideOuterContextOptionalGenericReturn] +# flags: --strict-optional +from typing import Optional, Type, TypeVar, Iterable + +class Custom: + pass + +T = TypeVar('T', bound=Custom) + +def a(x: T) -> Iterable[Optional[T]]: ... + +def b(x: T) -> Iterable[Optional[T]]: + return a(x) + +[case testWideOuterContextOptionalMethod] +# flags: --strict-optional +from typing import Optional, Type, TypeVar + +class A: pass +class B: pass + +T = TypeVar('T', A, B) +class C: + def meth_a(self) -> Optional[A]: + return self.meth(A) + + def meth(self, cls: Type[T]) -> Optional[T]: ... + +[case testWideOuterContextValuesOverlapping] +from typing import TypeVar, List + +class A: + pass +class B(A): + pass +class C: + pass + +T = TypeVar('T', A, B, C) +def foo(xs: List[T]) -> T: ... + +S = TypeVar('S', B, C) +def bar(xs: List[S]) -> S: + foo(xs) + return xs[0] +[builtins fixtures/list.pyi] + +[case testWideOuterContextOptionalTypeVarReturn] +# flags: --strict-optional +from typing import Callable, Iterable, List, Optional, TypeVar + +class C: + x: str + +T = TypeVar('T') +def f(i: Iterable[T], c: Callable[[T], str]) -> Optional[T]: ... + +def g(l: List[C], x: str) -> Optional[C]: + def pred(c: C) -> str: + return c.x + return f(l, pred) +[builtins fixtures/list.pyi] + +[case testWideOuterContextOptionalTypeVarReturnLambda] +# flags: --strict-optional +from typing import Callable, Iterable, List, Optional, TypeVar + +class C: + x: str + +T = TypeVar('T') +def f(i: Iterable[T], c: Callable[[T], str]) -> Optional[T]: ... + +def g(l: List[C], x: str) -> Optional[C]: + return f(l, lambda c: reveal_type(c).x) # E: Revealed type is '__main__.C' +[builtins fixtures/list.pyi] + +[case testWideOuterContextEmpty] +from typing import List, TypeVar + +T = TypeVar('T', bound=int) +def f(x: List[T]) -> T: ... + +# mypy infers List[] here, and is a subtype of str +y: str = f([]) +[builtins fixtures/list.pyi] + +[case testWideOuterContextEmptyError] +from typing import List, TypeVar + +T = TypeVar('T', bound=int) +def f(x: List[T]) -> List[T]: ... + +# TODO: improve error message for such cases, see #3283 and #5706 +y: List[str] = f([]) # E: Incompatible types in assignment (expression has type "List[]", variable has type "List[str]") +[builtins fixtures/list.pyi] + +[case testWideOuterContextNoArgs] +# flags: --strict-optional +from typing import TypeVar, Optional + +T = TypeVar('T', bound=int) +def f(x: Optional[T] = None) -> T: ... + +y: str = f() + +[case testWideOuterContextNoArgsError] +# flags: --strict-optional +from typing import TypeVar, Optional, List + +T = TypeVar('T', bound=int) +def f(x: Optional[T] = None) -> List[T]: ... + +y: List[str] = f() # E: Incompatible types in assignment (expression has type "List[]", variable has type "List[str]") +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index a769f53f43b5..fb5b442b88be 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -4923,11 +4923,9 @@ reveal_type(f(g())) # E: Revealed type is 'builtins.list[builtins.int]' [builtins fixtures/list.pyi] [case testOverloadInferringArgumentsUsingContext2-skip] -# This test case ought to work, but is maybe blocked by -# https://github.com/python/mypy/issues/4872? -# -# See https://github.com/python/mypy/pull/5660#discussion_r219669409 for -# more context. +# TODO: Overloads only use outer context to infer type variables in a given overload variant, +# but never use outer context to _choose_ a better overload in ambiguous situations +# like empty containers or multiple inheritance, instead just always choosing the first one. from typing import Optional, List, overload, TypeVar T = TypeVar('T') diff --git a/test-data/unit/check-typevar-values.test b/test-data/unit/check-typevar-values.test index 43ab8f85a7eb..9ef65f51a76e 100644 --- a/test-data/unit/check-typevar-values.test +++ b/test-data/unit/check-typevar-values.test @@ -18,7 +18,7 @@ s = ['x'] o = [object()] i = f(1) s = f('') -o = f(1) # E: Value of type variable "T" of "f" cannot be "object" +o = f(1) # E: Incompatible types in assignment (expression has type "List[int]", variable has type "List[object]") [builtins fixtures/list.pyi] [case testCallGenericFunctionWithTypeVarValueRestrictionAndAnyArgs]