From 4001d8dd8d06de7c8b39038e2ed25f92e0e8a043 Mon Sep 17 00:00:00 2001 From: Kouroche Bouchiat Date: Sat, 16 Dec 2023 14:22:37 +0100 Subject: [PATCH] Substitute type variables in return type of static methods `add_class_tvars` correctly instantiates type variables for class methods but not for static methods. Check if the analyzed member is a static method in `analyze_class_attribute_access` and substitute the type variable in the return type in `add_class_tvars` accordingly. Fixes #16668. --- mypy/checkmember.py | 15 +++++++++++++-- test-data/unit/check-generics.test | 13 +++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 5a4f3875ad04..c24edacf0ee1 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1063,11 +1063,14 @@ def analyze_class_attribute_access( is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or ( isinstance(node.node, FuncBase) and node.node.is_class ) + is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or ( + isinstance(node.node, FuncBase) and node.node.is_static + ) t = get_proper_type(t) if isinstance(t, FunctionLike) and is_classmethod: t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg) result = add_class_tvars( - t, isuper, is_classmethod, mx.self_type, original_vars=original_vars + t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars ) if not mx.is_lvalue: result = analyze_descriptor_access(result, mx) @@ -1177,6 +1180,7 @@ def add_class_tvars( t: ProperType, isuper: Instance | None, is_classmethod: bool, + is_staticmethod: bool, original_type: Type, original_vars: Sequence[TypeVarLikeType] | None = None, ) -> Type: @@ -1195,6 +1199,7 @@ class B(A[str]): pass isuper: Current instance mapped to the superclass where method was defined, this is usually done by map_instance_to_supertype() is_classmethod: True if this method is decorated with @classmethod + is_staticmethod: True if this method is decorated with @staticmethod original_type: The value of the type B in the expression B.foo() or the corresponding component in case of a union (this is used to bind the self-types) original_vars: Type variables of the class callable on which the method was accessed @@ -1220,6 +1225,7 @@ class B(A[str]): pass t = freshen_all_functions_type_vars(t) if is_classmethod: t = bind_self(t, original_type, is_classmethod=True) + if is_classmethod or is_staticmethod: assert isuper is not None t = expand_type_by_instance(t, isuper) freeze_all_type_vars(t) @@ -1230,7 +1236,12 @@ class B(A[str]): pass cast( CallableType, add_class_tvars( - item, isuper, is_classmethod, original_type, original_vars=original_vars + item, + isuper, + is_classmethod, + is_staticmethod, + original_type, + original_vars=original_vars, ), ) for item in t.items diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index ef3f359e4989..e2f65ed39c1e 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2297,6 +2297,19 @@ def func(x: S) -> S: return C[S].get() [builtins fixtures/classmethod.pyi] +[case testGenericStaticMethodInGenericFunction] +from typing import Generic, TypeVar +T = TypeVar('T') +S = TypeVar('S') + +class C(Generic[T]): + @staticmethod + def get() -> T: ... + +def func(x: S) -> S: + return C[S].get() +[builtins fixtures/staticmethod.pyi] + [case testMultipleAssignmentFromAnyIterable] from typing import Any class A: