diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 7fa3244373..7798196c5d 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -18,7 +18,7 @@ from mypy.plugin import AttributeContext, ClassDefContext, DynamicClassDefContext from mypy.semanal import SemanticAnalyzer from mypy.semanal_shared import has_placeholder -from mypy.types import AnyType, CallableType, Instance, ProperType, TypeOfAny +from mypy.types import AnyType, CallableType, Instance, ProperType, TypeOfAny, TypeVarType from mypy.types import Type as MypyType from mypy.typevars import fill_typevars @@ -105,6 +105,8 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P ret_type = Instance(queryset_info, [manager_instance.args[0], manager_instance.args[0]]) variables = [] + ret_type = _manage_type_var(ret_type, queryset_info) + # Drop any 'self' argument as our manager is already initialized return method_type.copy_modified( arg_types=method_type.arg_types[1:], @@ -115,6 +117,48 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P ) +def _manage_type_var(ret_type: MypyType, queryset_info: TypeInfo) -> MypyType: + """ + Attempts to manage the concrete type corresponding to a `TypeVarType` in a queryset's base types. + If nothing is found, keep it unchanged. + + Example: + + T = TypeVar("T") + for def example_list_dict(self) -> list[dict[str, T]] + + Tries to find the corresponding `TypeVar` of `T` if it has been set. + + """ + if isinstance(ret_type, TypeVarType): + return _find_type_var(queryset_info, ret_type) or ret_type + elif isinstance(ret_type, Instance): + # Since it is an instance, recursively find the type var for all its args. + ret_type.args = tuple(_manage_type_var(item, queryset_info) for item in ret_type.args) + if isinstance(ret_type, ProperType) and hasattr(ret_type, "item"): + # For example TypeType has an item. find the type_var for this item + ret_type.item = _manage_type_var(ret_type.item, queryset_info) + if isinstance(ret_type, ProperType) and hasattr(ret_type, "items"): + # For example TypeList has items. find recursively type_var for its items + ret_type.items = [_manage_type_var(item, queryset_info) for item in ret_type.items] + return ret_type + + +def _find_type_var(queryset_info: TypeInfo, ret_type: TypeVarType) -> Optional[MypyType]: + """ + Attempts to find the concrete type corresponding to a TypeVarType in a queryset's base types. + + Example: + Suppose queryset_info is based on a generic class `QuerySet[T]` and ret_type corresponds + to `T`, then this function will return the concrete type that `T` was instantiated with. + """ + for base in queryset_info.bases: + for i, type_var in enumerate(base.type.defn.type_vars): + if type_var.fullname == ret_type.fullname: + return base.args[i] + return None + + def get_method_type_from_reverse_manager( api: TypeChecker, method_name: str, manager_type_info: TypeInfo ) -> Optional[ProperType]: diff --git a/tests/typecheck/managers/querysets/test_as_manager.yml b/tests/typecheck/managers/querysets/test_as_manager.yml index 06882368b9..fab0813fd6 100644 --- a/tests/typecheck/managers/querysets/test_as_manager.yml +++ b/tests/typecheck/managers/querysets/test_as_manager.yml @@ -146,6 +146,63 @@ class MyModel(models.Model): objects = ManagerFromModelQuerySet +- case: handles_subclasses_of_queryset + main: | + from myapp.models import MyModel + reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.example_2()) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyModel.objects.example_list()) # N: Revealed type is "builtins.list[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_type()) # N: Revealed type is "Type[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_tuple_simple()) # N: Revealed type is "Tuple[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_tuple_list()) # N: Revealed type is "builtins.tuple[myapp.models.MyModel, ...]" + reveal_type(MyModel.objects.example_tuple_double()) # N: Revealed type is "Tuple[builtins.int, myapp.models.MyModel]" + reveal_type(MyModel.objects.example_class()) # N: Revealed type is "myapp.models.Example[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_type_class()) # N: Revealed type is "Type[myapp.models.Example[myapp.models.MyModel]]" + reveal_type(MyModel.objects.example_collection()) # N: Revealed type is "typing.Collection[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_set()) # N: Revealed type is "builtins.set[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_dict()) # N: Revealed type is "builtins.dict[builtins.str, myapp.models.MyModel]" + reveal_type(MyModel.objects.example_list_dict()) # N: Revealed type is "builtins.list[builtins.dict[myapp.models.MyOtherModel, myapp.models.MyModel]]" + + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from typing import TypeVar, Generic, Collection + + from django.db import models + + _CTE = TypeVar("_CTE", bound=models.Model) + _CTE_2 = TypeVar("_CTE_2", bound=models.Model) + + class Example(Generic[_CTE]): ... + + class _MyModelQuerySet(models.QuerySet[_CTE], Generic[_CTE, _CTE_2]): + + def example(self) -> _CTE: ... + def example_2(self) -> _CTE_2: ... + def example_list(self) -> list[_CTE]: ... + def example_type(self) -> type[_CTE]: ... + def example_tuple_simple(self) -> tuple[_CTE]: ... + def example_tuple_list(self) -> tuple[_CTE, ...]: ... + def example_tuple_double(self) -> tuple[int, _CTE]: ... + def example_class(self) -> Example[_CTE]: ... + def example_type_class(self) -> type[Example[_CTE]]: ... + def example_collection(self) -> Collection[_CTE]: ... + def example_set(self) -> set[_CTE]: ... + def example_dict(self) -> dict[str, _CTE]: ... + def example_list_dict(self) -> list[dict[_CTE_2, _CTE]]: ... + + + class MyModelQuerySet(_MyModelQuerySet["MyModel", "MyOtherModel"]): + ... + + class MyModel(models.Model): + objects = MyModelQuerySet.as_manager() + + class MyOtherModel(models.Model): ... + - case: reuses_generated_type_when_called_identically_for_multiple_managers main: | from myapp.models import MyModel diff --git a/tests/typecheck/managers/querysets/test_from_queryset.yml b/tests/typecheck/managers/querysets/test_from_queryset.yml index aa152983f5..f9c5af288f 100644 --- a/tests/typecheck/managers/querysets/test_from_queryset.yml +++ b/tests/typecheck/managers/querysets/test_from_queryset.yml @@ -70,6 +70,41 @@ NewManager = BaseManager.from_queryset(ModelQuerySet) class MyModel(models.Model): objects = NewManager() +- case: handles_subclasses_of_queryset + main: | + from myapp.models import MyModel + reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/querysets.py + content: | + from typing import TypeVar, TYPE_CHECKING + + from django.db import models + from django.db.models.manager import BaseManager + if TYPE_CHECKING: + from .models import MyModel + + _CTE = TypeVar("_CTE", bound=models.Model) + + class _MyModelQuerySet(models.QuerySet[_CTE]): + + def example(self) -> _CTE: ... + + class MyModelQuerySet(_MyModelQuerySet["MyModel"]): + ... + + - path: myapp/models.py + content: | + from django.db import models + from django.db.models.manager import BaseManager + from .querysets import MyModelQuerySet + + NewManager = BaseManager.from_queryset(MyModelQuerySet) + class MyModel(models.Model): + objects = NewManager() - case: from_queryset_generated_manager_imported_from_other_module main: |