diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 7fa324437..721bacd98 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -1,4 +1,4 @@ -from typing import Final, Optional, Union +from typing import Final, Optional from mypy.checker import TypeChecker from mypy.nodes import ( @@ -8,6 +8,7 @@ FuncBase, FuncDef, MemberExpr, + Node, OverloadedFuncDef, RefExpr, StrExpr, @@ -18,7 +19,7 @@ from mypy.plugin import AttributeContext, ClassDefContext, DynamicClassDefContext from mypy.semanal import SemanticAnalyzer from mypy.semanal_shared import has_placeholder -from mypy.types import AnyType, CallableType, Instance, ProperType, TypeOfAny +from mypy.types import AnyType, CallableType, Instance, ProperType, TypeOfAny, TypeVarType from mypy.types import Type as MypyType from mypy.typevars import fill_typevars @@ -71,7 +72,7 @@ def get_method_type_from_dynamic_manager( queryset_info = helpers.lookup_fully_qualified_typeinfo(api, queryset_fullname) assert queryset_info is not None - def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[ProperType]: + def get_funcdef_type(definition: Node | None) -> Optional[ProperType]: # TODO: Handle @overload? if isinstance(definition, FuncBase) and not isinstance(definition, OverloadedFuncDef): return definition.type @@ -79,7 +80,10 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P return definition.func.type return None - method_type = get_funcdef_type(queryset_info.get_method(method_name)) + base_that_has_method = queryset_info.get_containing_type_info(method_name) + if base_that_has_method is None: + return None + method_type = get_funcdef_type(base_that_has_method.names[method_name].node) if method_type is None: return None @@ -104,10 +108,19 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P # only needed for pluign-generated querysets. ret_type = Instance(queryset_info, [manager_instance.args[0], manager_instance.args[0]]) variables = [] + args_types = method_type.arg_types[1:] + if _has_compatible_type_vars(base_that_has_method): + ret_type = _replace_type_var( + ret_type, base_that_has_method.defn.type_vars[0].fullname, manager_instance.args[0] + ) + args_types = [ + _replace_type_var(arg_type, base_that_has_method.defn.type_vars[0].fullname, manager_instance.args[0]) + for arg_type in args_types + ] # Drop any 'self' argument as our manager is already initialized return method_type.copy_modified( - arg_types=method_type.arg_types[1:], + arg_types=args_types, arg_kinds=method_type.arg_kinds[1:], arg_names=method_type.arg_names[1:], variables=variables, @@ -115,6 +128,91 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P ) +def _has_compatible_type_vars(type_info: TypeInfo) -> bool: + """ + Determines whether the provided 'type_info', + is a generically parameterized subclass of models.QuerySet[T], with exactly + one type variable. + + Criteria for compatibility: + 1. 'type_info' must be a generic class with exactly one type variable. + 2. All superclasses of 'type_info', up to and including models.QuerySet, + must also be generic classes with exactly one type variable. + + Examples: + + Compatible: + class _MyModelQuerySet(models.QuerySet[T]): ... + class MyModelQuerySet(_MyModelQuerySet[T_2]): + def example(self) -> T_2: ... + + Incompatible: + class MyModelQuerySet(models.QuerySet[T], Generic[T, T2]): + def example(self, a: T2) -> T_2: ... + + Returns: + True if the 'base' meets the criteria, otherwise False. + """ + args = type_info.defn.type_vars + if not args or len(args) > 1 or not isinstance(args[0], TypeVarType): + # No type var to manage, or too many + return False + type_var: MypyType | None = None + # check that for all the bases it has only one type vars + for sub_base in type_info.bases: + unic_args = list(set(sub_base.args)) + if not unic_args or len(unic_args) > 1: + # No type var for the sub_base, skipping + continue + if type_var and unic_args and type_var != unic_args[0]: + # There is two different type vars in the bases, we are not compatible + return False + type_var = unic_args[0] + if not type_var: + # No type var found in the bases. + return False + + if type_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): + # If it is a subclass of _QuerySet, it is compatible. + return True + # check that at least one base is a subclass of queryset with Generic type vars + return any(_has_compatible_type_vars(sub_base.type) for sub_base in type_info.bases) + + +def _replace_type_var(ret_type: MypyType, to_replace: str, replace_by: MypyType) -> MypyType: + """ + Substitutes a specified type variable within a Mypy type expression with an actual type. + + This function is recursive, and it operates on various kinds of Mypy types like Instance, + ProperType, etc., to deeply replace the specified type variable. + + Parameters: + - ret_type: A Mypy type expression where the substitution should occur. + - to_replace: The type variable to be replaced, specified as its full name. + - replace_by: The actual Mypy type to substitute in place of 'to_replace'. + + Example: + Given: + ret_type = "typing.Collection[T]" + to_replace = "T" + replace_by = "myapp.models.MyModel" + Result: + "typing.Collection[myapp.models.MyModel]" + """ + if isinstance(ret_type, TypeVarType) and ret_type.fullname == to_replace: + return replace_by + elif isinstance(ret_type, Instance): + # Since it is an instance, recursively find the type var for all its args. + ret_type.args = tuple(_replace_type_var(item, to_replace, replace_by) for item in ret_type.args) + if isinstance(ret_type, ProperType) and hasattr(ret_type, "item"): + # For example TypeType has an item. find the type_var for this item + ret_type.item = _replace_type_var(ret_type.item, to_replace, replace_by) + if isinstance(ret_type, ProperType) and hasattr(ret_type, "items"): + # For example TypeList has items. find recursively type_var for its items + ret_type.items = [_replace_type_var(item, to_replace, replace_by) for item in ret_type.items] + return ret_type + + def get_method_type_from_reverse_manager( api: TypeChecker, method_name: str, manager_type_info: TypeInfo ) -> Optional[ProperType]: diff --git a/tests/typecheck/managers/querysets/test_as_manager.yml b/tests/typecheck/managers/querysets/test_as_manager.yml index 06882368b..a698cca0f 100644 --- a/tests/typecheck/managers/querysets/test_as_manager.yml +++ b/tests/typecheck/managers/querysets/test_as_manager.yml @@ -146,6 +146,121 @@ class MyModel(models.Model): objects = ManagerFromModelQuerySet +- case: handles_type_var_in_subclasses_of_subclasses_of_queryset + main: | + from myapp.models import MyModel, MyOtherModel + reveal_type(MyModel.objects.example_2()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.example_2()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.override()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.override2()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.dummy_override()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.example_mixin(MyModel())) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.example_other_mixin()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyOtherModel.objects.example()) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyOtherModel.objects.example_2()) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyOtherModel.objects.override()) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyOtherModel.objects.override2()) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyOtherModel.objects.dummy_override()) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyOtherModel.objects.example_mixin(MyOtherModel())) # N: Revealed type is "myapp.models.MyOtherModel" + reveal_type(MyOtherModel.objects.example_other_mixin()) # N: Revealed type is "myapp.models.MyOtherModel" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from typing import TypeVar, Generic + from django.db import models + + T = TypeVar("T", bound=models.Model) + T_2 = TypeVar("T_2", bound=models.Model) + + class SomeMixin: + def example_mixin(self, a: T) -> T: ... + + class OtherMixin(models.QuerySet[T]): + def example_other_mixin(self) -> T: ... + + class _MyModelQuerySet(OtherMixin[T], models.QuerySet[T], Generic[T]): + def example(self) -> T: ... + def override(self) -> T: ... + def override2(self) -> T: ... + def dummy_override(self) -> int: ... + + class _MyModelQuerySet2(SomeMixin, _MyModelQuerySet[T_2]): + def example_2(self) -> T_2: ... + def override(self) -> T_2: ... + def override2(self) -> T_2: ... + def dummy_override(self) -> T_2: ... # type: ignore[override] + + class MyModelQuerySet(_MyModelQuerySet2["MyModel"]): + def override(self) -> "MyModel": ... + + class MyModel(models.Model): + objects = MyModelQuerySet.as_manager() + + class MyOtherModel(models.Model): + objects = _MyModelQuerySet2.as_manager() # type: ignore + +- case: handles_type_vars + main: | + from myapp.models import MyModel, BaseQuerySet + reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel" + reveal_type(MyModel.objects.example_list()) # N: Revealed type is "builtins.list[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_type()) # N: Revealed type is "Type[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_tuple_simple()) # N: Revealed type is "Tuple[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_tuple_list()) # N: Revealed type is "builtins.tuple[myapp.models.MyModel, ...]" + reveal_type(MyModel.objects.example_tuple_double()) # N: Revealed type is "Tuple[builtins.int, myapp.models.MyModel]" + reveal_type(MyModel.objects.example_class()) # N: Revealed type is "myapp.models.Example[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_type_class()) # N: Revealed type is "Type[myapp.models.Example[myapp.models.MyModel]]" + reveal_type(MyModel.objects.example_collection()) # N: Revealed type is "typing.Collection[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_set()) # N: Revealed type is "builtins.set[myapp.models.MyModel]" + reveal_type(MyModel.objects.example_dict()) # N: Revealed type is "builtins.dict[builtins.str, myapp.models.MyModel]" + reveal_type(MyModel.objects.example_list_dict()) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, myapp.models.MyModel]]" + class TestQuerySet(BaseQuerySet[str]): ... # E: Type argument "str" of "BaseQuerySet" must be a subtype of "Model" + reveal_type(MyModel.objects.example_t(5)) # N: Revealed type is "builtins.int" + MyModel.objects.example_arg(5, "5") # E: Argument 1 to "example_arg" of "BaseQuerySet" has incompatible type "int"; expected "MyModel" + reveal_type(MyModel.objects.example_arg(MyModel(), "5")) # N: Revealed type is "myapp.models.MyModel" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from typing import TypeVar, Generic, Collection + + from django.db import models + + _CTE = TypeVar("_CTE", bound=models.Model) + T = TypeVar("T") + + class Example(Generic[_CTE]): ... + + class BaseQuerySet(models.QuerySet[_CTE], Generic[_CTE]): + + def example(self) -> _CTE: ... + def example_list(self) -> list[_CTE]: ... + def example_type(self) -> type[_CTE]: ... + def example_tuple_simple(self) -> tuple[_CTE]: ... + def example_tuple_list(self) -> tuple[_CTE, ...]: ... + def example_tuple_double(self) -> tuple[int, _CTE]: ... + def example_class(self) -> Example[_CTE]: ... + def example_type_class(self) -> type[Example[_CTE]]: ... + def example_collection(self) -> Collection[_CTE]: ... + def example_set(self) -> set[_CTE]: ... + def example_dict(self) -> dict[str, _CTE]: ... + def example_list_dict(self) -> list[dict[str, _CTE]]: ... + def example_t(self, a: T) -> T: ... + def example_arg(self, a: _CTE, b: str) -> _CTE: ... + + + class MyModelQuerySet(BaseQuerySet["MyModel"]): + ... + + class MyModel(models.Model): + objects = MyModelQuerySet.as_manager() + - case: reuses_generated_type_when_called_identically_for_multiple_managers main: | from myapp.models import MyModel diff --git a/tests/typecheck/managers/querysets/test_from_queryset.yml b/tests/typecheck/managers/querysets/test_from_queryset.yml index aa152983f..f9c5af288 100644 --- a/tests/typecheck/managers/querysets/test_from_queryset.yml +++ b/tests/typecheck/managers/querysets/test_from_queryset.yml @@ -70,6 +70,41 @@ NewManager = BaseManager.from_queryset(ModelQuerySet) class MyModel(models.Model): objects = NewManager() +- case: handles_subclasses_of_queryset + main: | + from myapp.models import MyModel + reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/querysets.py + content: | + from typing import TypeVar, TYPE_CHECKING + + from django.db import models + from django.db.models.manager import BaseManager + if TYPE_CHECKING: + from .models import MyModel + + _CTE = TypeVar("_CTE", bound=models.Model) + + class _MyModelQuerySet(models.QuerySet[_CTE]): + + def example(self) -> _CTE: ... + + class MyModelQuerySet(_MyModelQuerySet["MyModel"]): + ... + + - path: myapp/models.py + content: | + from django.db import models + from django.db.models.manager import BaseManager + from .querysets import MyModelQuerySet + + NewManager = BaseManager.from_queryset(MyModelQuerySet) + class MyModel(models.Model): + objects = NewManager() - case: from_queryset_generated_manager_imported_from_other_module main: |