Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check model fields on filtering methods of queryset types #2277

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
DynamicClassDefContext,
FunctionContext,
MethodContext,
SemanticAnalyzerPluginInterface,
)
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
Expand Down Expand Up @@ -63,9 +62,7 @@ def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata:
return cast(DjangoTypeMetadata, model_info.metadata.setdefault("django", {}))


def get_django_metadata_bases(
model_info: TypeInfo, key: Literal["baseform_bases", "manager_bases", "queryset_bases"]
) -> Dict[str, int]:
def get_django_metadata_bases(model_info: TypeInfo, key: Literal["baseform_bases", "queryset_bases"]) -> Dict[str, int]:
return get_django_metadata(model_info).setdefault(key, cast(Dict[str, int], {}))


Expand Down Expand Up @@ -422,13 +419,6 @@ def add_new_sym_for_info(
info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True, no_serialize=no_serialize)


def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
bases = get_django_metadata_bases(sym.node, "manager_bases")
bases[fullname] = 1


def is_abstract_model(model: TypeInfo) -> bool:
if model.fullname in fullnames.DJANGO_ABSTRACT_MODELS:
return True
Expand Down
74 changes: 25 additions & 49 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import sys
from functools import partial
from functools import cached_property, partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from mypy.build import PRI_MED, PRI_MYPY
Expand All @@ -19,7 +19,6 @@
)
from mypy.types import Type as MypyType

import mypy_django_plugin.transformers.orm_lookups
from mypy_django_plugin.config import DjangoPluginConfig
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
Expand All @@ -31,6 +30,7 @@
manytomany,
manytoone,
meta,
orm_lookups,
querysets,
request,
settings,
Expand Down Expand Up @@ -60,10 +60,6 @@ def transform_form_class(ctx: ClassDefContext) -> None:
forms.make_meta_nested_class_inherit_from_any(ctx)


def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)


class NewSemanalDjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
Expand All @@ -83,15 +79,6 @@ def _get_current_queryset_bases(self) -> Dict[str, int]:
else:
return {}

def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
bases = helpers.get_django_metadata_bases(model_sym.node, "manager_bases")
bases[fullnames.MANAGER_CLASS_FULLNAME] = 1
return bases
else:
return {}

def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
Expand Down Expand Up @@ -165,10 +152,6 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
if fullname == "django.contrib.auth.get_user_model":
return partial(settings.get_user_model_hook, django_context=self.django_context)

manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return querysets.determine_proper_manager_type

info = self._get_typeinfo_or_none(fullname)
if info:
if info.has_base(fullnames.FIELD_FULLNAME):
Expand All @@ -177,8 +160,26 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
if helpers.is_model_type(info):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)

if info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return querysets.determine_proper_manager_type

return None

@cached_property
def manager_and_queryset_method_hooks(self) -> Dict[str, Callable[[MethodContext], MypyType]]:
typecheck_filtering_method = partial(orm_lookups.typecheck_queryset_filter, django_context=self.django_context)
return {
"values": partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context),
"values_list": partial(
querysets.extract_proper_type_queryset_values_list, django_context=self.django_context
),
"annotate": partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context),
"create": partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context),
"filter": typecheck_filtering_method,
"get": typecheck_filtering_method,
"exclude": typecheck_filtering_method,
}

def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
class_fullname, _, method_name = fullname.rpartition(".")
# Methods called very often -- short circuit for minor speed up
Expand Down Expand Up @@ -208,38 +209,17 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
}
return hooks.get(class_fullname)

manager_classes = self._get_current_manager_bases()

if method_name == "values":
if method_name in self.manager_and_queryset_method_hooks:
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)

elif method_name == "values_list":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)

elif method_name == "annotate":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)

if info and helpers.has_any_of_bases(
info, [fullnames.QUERYSET_CLASS_FULLNAME, fullnames.MANAGER_CLASS_FULLNAME]
):
return self.manager_and_queryset_method_hooks[method_name]
elif method_name == "get_field":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)

elif method_name == "create":
# We need `BASE_MANAGER_CLASS_FULLNAME` to check abstract models.
if class_fullname in manager_classes or class_fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
elif method_name in {"filter", "get", "exclude"} and class_fullname in manager_classes:
return partial(
mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
django_context=self.django_context,
)

return None

def get_customize_class_mro_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
Expand All @@ -262,10 +242,6 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
if sym is not None and isinstance(sym.node, TypeInfo) and helpers.is_model_type(sym.node):
return partial(process_model_class, django_context=self.django_context)

# Base class is a Manager class definition
if fullname in self._get_current_manager_bases():
return add_new_manager_base_hook

# Base class is a Form class definition
if fullname in self._get_current_form_bases():
return transform_form_class
Expand Down
6 changes: 0 additions & 6 deletions mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
ctx.api.defer()
return

# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)


def register_dynamically_created_manager(fullname: str, manager_name: str, manager_base: TypeInfo) -> None:
manager_base.metadata.setdefault("from_queryset_managers", {})
Expand Down Expand Up @@ -558,9 +555,6 @@ def create_new_manager_class_from_as_manager_method(ctx: DynamicClassDefContext)
manager_base=manager_base,
)

# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)

# Whenever `<QuerySet>.as_manager()` isn't called at class level, we want to ensure
# that the variable is an instance of our generated manager. Instead of the return
# value of `.as_manager()`. Though model argument is populated as `Any`.
Expand Down
1 change: 0 additions & 1 deletion mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,6 @@ def create_many_related_manager(self, model: Instance) -> None:
helpers.set_many_to_many_manager_info(
to=model.type, derived_from="_default_manager", manager_info=related_manager_info
)
helpers.add_new_manager_base(self.api, related_manager_info.fullname)


class MetaclassAdjustments(ModelClassInitializer):
Expand Down
17 changes: 17 additions & 0 deletions tests/typecheck/managers/test_managers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,20 @@
def get_instance(self) -> int:
pass
objects = MyManager()

- case: test_typechecks_filter_methods_of_queryset_type
main: |
from myapp.models import MyModel
MyModel.objects.filter(id=1).filter(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc]
MyModel.objects.filter(id=1).get(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc]
MyModel.objects.filter(id=1).exclude(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc]
MyModel.objects.filter(id=1).create(invalid=1) # E: Unexpected attribute "invalid" for model "MyModel" [misc]
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class MyModel(models.Model): ...
Loading