diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 3332ae9b2..862717236 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -33,7 +33,6 @@ DynamicClassDefContext, FunctionContext, MethodContext, - SemanticAnalyzerPluginInterface, ) from mypy.semanal import SemanticAnalyzer from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType @@ -63,9 +62,7 @@ def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata: return cast(DjangoTypeMetadata, model_info.metadata.setdefault("django", {})) -def get_django_metadata_bases( - model_info: TypeInfo, key: Literal["baseform_bases", "manager_bases", "queryset_bases"] -) -> Dict[str, int]: +def get_django_metadata_bases(model_info: TypeInfo, key: Literal["baseform_bases", "queryset_bases"]) -> Dict[str, int]: return get_django_metadata(model_info).setdefault(key, cast(Dict[str, int], {})) @@ -422,13 +419,6 @@ def add_new_sym_for_info( info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True, no_serialize=no_serialize) -def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None: - sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) - if sym is not None and isinstance(sym.node, TypeInfo): - bases = get_django_metadata_bases(sym.node, "manager_bases") - bases[fullname] = 1 - - def is_abstract_model(model: TypeInfo) -> bool: if model.fullname in fullnames.DJANGO_ABSTRACT_MODELS: return True diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index a2f52ab2e..e9f549b70 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,6 +1,6 @@ import itertools import sys -from functools import partial +from functools import cached_property, partial from typing import Any, Callable, Dict, List, Optional, Tuple, Type from mypy.build import PRI_MED, PRI_MYPY @@ -19,7 +19,6 @@ ) from mypy.types import Type as MypyType -import mypy_django_plugin.transformers.orm_lookups from mypy_django_plugin.config import DjangoPluginConfig from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.exceptions import UnregisteredModelError @@ -31,6 +30,7 @@ manytomany, manytoone, meta, + orm_lookups, querysets, request, settings, @@ -60,10 +60,6 @@ def transform_form_class(ctx: ClassDefContext) -> None: forms.make_meta_nested_class_inherit_from_any(ctx) -def add_new_manager_base_hook(ctx: ClassDefContext) -> None: - helpers.add_new_manager_base(ctx.api, ctx.cls.fullname) - - class NewSemanalDjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -83,15 +79,6 @@ def _get_current_queryset_bases(self) -> Dict[str, int]: else: return {} - def _get_current_manager_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME) - if model_sym is not None and isinstance(model_sym.node, TypeInfo): - bases = helpers.get_django_metadata_bases(model_sym.node, "manager_bases") - bases[fullnames.MANAGER_CLASS_FULLNAME] = 1 - return bases - else: - return {} - def _get_current_form_bases(self) -> Dict[str, int]: model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): @@ -165,10 +152,6 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext if fullname == "django.contrib.auth.get_user_model": return partial(settings.get_user_model_hook, django_context=self.django_context) - manager_bases = self._get_current_manager_bases() - if fullname in manager_bases: - return querysets.determine_proper_manager_type - info = self._get_typeinfo_or_none(fullname) if info: if info.has_base(fullnames.FIELD_FULLNAME): @@ -177,8 +160,26 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext if helpers.is_model_type(info): return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) + if info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): + return querysets.determine_proper_manager_type + return None + @cached_property + def manager_and_queryset_method_hooks(self) -> Dict[str, Callable[[MethodContext], MypyType]]: + typecheck_filtering_method = partial(orm_lookups.typecheck_queryset_filter, django_context=self.django_context) + return { + "values": partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context), + "values_list": partial( + querysets.extract_proper_type_queryset_values_list, django_context=self.django_context + ), + "annotate": partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context), + "create": partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context), + "filter": typecheck_filtering_method, + "get": typecheck_filtering_method, + "exclude": typecheck_filtering_method, + } + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]: class_fullname, _, method_name = fullname.rpartition(".") # Methods called very often -- short circuit for minor speed up @@ -208,38 +209,17 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M } return hooks.get(class_fullname) - manager_classes = self._get_current_manager_bases() - - if method_name == "values": + if method_name in self.manager_and_queryset_method_hooks: info = self._get_typeinfo_or_none(class_fullname) - if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes: - return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context) - - elif method_name == "values_list": - info = self._get_typeinfo_or_none(class_fullname) - if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes: - return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context) - - elif method_name == "annotate": - info = self._get_typeinfo_or_none(class_fullname) - if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes: - return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context) - + if info and helpers.has_any_of_bases( + info, [fullnames.QUERYSET_CLASS_FULLNAME, fullnames.MANAGER_CLASS_FULLNAME] + ): + return self.manager_and_queryset_method_hooks[method_name] elif method_name == "get_field": info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME): return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context) - elif method_name == "create": - # We need `BASE_MANAGER_CLASS_FULLNAME` to check abstract models. - if class_fullname in manager_classes or class_fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME: - return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) - elif method_name in {"filter", "get", "exclude"} and class_fullname in manager_classes: - return partial( - mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter, - django_context=self.django_context, - ) - return None def get_customize_class_mro_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: @@ -262,10 +242,6 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte if sym is not None and isinstance(sym.node, TypeInfo) and helpers.is_model_type(sym.node): return partial(process_model_class, django_context=self.django_context) - # Base class is a Manager class definition - if fullname in self._get_current_manager_bases(): - return add_new_manager_base_hook - # Base class is a Form class definition if fullname in self._get_current_form_bases(): return transform_form_class diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index e4be536a2..951f1507b 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -324,9 +324,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte ctx.api.defer() return - # So that the plugin will reparameterize the manager when it is constructed inside of a Model definition - helpers.add_new_manager_base(semanal_api, new_manager_info.fullname) - def register_dynamically_created_manager(fullname: str, manager_name: str, manager_base: TypeInfo) -> None: manager_base.metadata.setdefault("from_queryset_managers", {}) @@ -558,9 +555,6 @@ def create_new_manager_class_from_as_manager_method(ctx: DynamicClassDefContext) manager_base=manager_base, ) - # So that the plugin will reparameterize the manager when it is constructed inside of a Model definition - helpers.add_new_manager_base(semanal_api, new_manager_info.fullname) - # Whenever `.as_manager()` isn't called at class level, we want to ensure # that the variable is an instance of our generated manager. Instead of the return # value of `.as_manager()`. Though model argument is populated as `Any`. diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 2bb17ed08..c85c87bf1 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -934,7 +934,6 @@ def create_many_related_manager(self, model: Instance) -> None: helpers.set_many_to_many_manager_info( to=model.type, derived_from="_default_manager", manager_info=related_manager_info ) - helpers.add_new_manager_base(self.api, related_manager_info.fullname) class MetaclassAdjustments(ModelClassInitializer): diff --git a/tests/typecheck/managers/test_managers.yml b/tests/typecheck/managers/test_managers.yml index d8c2f3e38..fd360c727 100644 --- a/tests/typecheck/managers/test_managers.yml +++ b/tests/typecheck/managers/test_managers.yml @@ -653,3 +653,20 @@ def get_instance(self) -> int: pass objects = MyManager() + +- case: test_typechecks_filter_methods_of_queryset_type + main: | + from myapp.models import MyModel + MyModel.objects.filter(id=1).filter(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc] + MyModel.objects.filter(id=1).get(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc] + MyModel.objects.filter(id=1).exclude(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc] + MyModel.objects.filter(id=1).create(invalid=1) # E: Unexpected attribute "invalid" for model "MyModel" [misc] + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + + class MyModel(models.Model): ...