Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ForeignKey queryset filters on un-swapped models #1495

Merged
merged 7 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mypy.types import AnyType, Instance, TypeOfAny, UnionType
from mypy.types import Type as MypyType

from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME

Expand Down Expand Up @@ -118,7 +119,19 @@ def get_model_fields(self, model_cls: Type[Model]) -> Iterator["Field[Any, Any]"
if isinstance(field, Field):
yield field

def get_model_foreign_keys(self, model_cls: Type[Model]) -> Iterator["ForeignKey[Any, Any]"]:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
yield field

def get_model_related_fields(self, model_cls: Type[Model]) -> Iterator["RelatedField[Any, Any]"]:
"""Get model forward relations"""
for field in model_cls._meta.get_fields():
if isinstance(field, RelatedField):
yield field

def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]:
"""Get model reverse relations"""
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignObjectRel):
yield field
Expand All @@ -127,7 +140,7 @@ def get_field_lookup_exact_type(
self, api: TypeChecker, field: Union["Field[Any, Any]", ForeignObjectRel]
) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):
related_model_cls = field.related_model
related_model_cls = self.get_field_related_model_cls(field)
primary_key_field = self.get_primary_key_field(related_model_cls)
primary_key_type = self.get_field_get_type(api, primary_key_field, method="init")

Expand Down Expand Up @@ -210,9 +223,6 @@ def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name
continue

related_model = self.get_field_related_model_cls(field)
if related_model is None:
expected_types[field_name] = AnyType(TypeOfAny.from_error)
continue

if related_model._meta.proxy_for_model is not None:
related_model = related_model._meta.proxy_for_model
Expand Down Expand Up @@ -312,8 +322,6 @@ def get_field_get_type(
is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField):
related_model_cls = self.get_field_related_model_cls(field)
if related_model_cls is None:
return AnyType(TypeOfAny.from_error)

if method in ("values", "values_list"):
primary_key_field = self.get_primary_key_field(related_model_cls)
Expand All @@ -327,9 +335,7 @@ def get_field_get_type(
else:
return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)

def get_field_related_model_cls(
self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]
) -> Optional[Type[Model]]:
def get_field_related_model_cls(self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]) -> Type[Model]:
if isinstance(field, RelatedField):
related_model_cls = field.remote_field.model
else:
Expand All @@ -341,13 +347,15 @@ def get_field_related_model_cls(
related_model_cls = field.model
elif "." not in related_model_cls:
# same file model
related_model_fullname = field.model.__module__ + "." + related_model_cls
related_model_fullname = f"{field.model.__module__}.{related_model_cls}"
related_model_cls = self.get_model_class_by_fullname(related_model_fullname)
if related_model_cls is None:
raise UnregisteredModelError
else:
try:
related_model_cls = self.apps_registry.get_model(related_model_cls)
except LookupError:
return None
except LookupError as e:
raise UnregisteredModelError from e

return related_model_cls

Expand All @@ -363,13 +371,13 @@ def _resolve_field_from_parts(

field = currently_observed_model._meta.get_field(field_part)
if isinstance(field, RelatedField):
currently_observed_model = field.related_model
currently_observed_model = self.get_field_related_model_cls(field)
model_name = currently_observed_model._meta.model_name
if model_name is not None and field_part == (model_name + "_id"):
field = self.get_primary_key_field(currently_observed_model)

if isinstance(field, ForeignObjectRel):
currently_observed_model = field.related_model
currently_observed_model = self.get_field_related_model_cls(field)

# Guaranteed by `query.solve_lookup_type` before.
assert isinstance(field, (Field, ForeignObjectRel))
Expand Down Expand Up @@ -397,9 +405,15 @@ def solve_lookup_type(
field = query.get_meta().get_field(query_parts[0])
except FieldDoesNotExist:
return None

if len(query_parts) == 1:
return [], [query_parts[0]], False
sub_query = Query(field.related_model).solve_lookup_type("__")

if not isinstance(field, (RelatedField, ForeignObjectRel)):
return None

related_model = self.get_field_related_model_cls(field)
sub_query = Query(related_model).solve_lookup_type("__".join(query_parts[1:]))
entire_query_parts = [query_parts[0], *sub_query[1]]
return sub_query[0], entire_query_parts, sub_query[2]

Expand Down
2 changes: 2 additions & 0 deletions mypy_django_plugin/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class UnregisteredModelError(Exception):
"""The requested model is not registered"""
26 changes: 13 additions & 13 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import itertools
import sys
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Type

from django.db.models.fields.related import RelatedField
from mypy.modulefinder import mypy_path
from mypy.nodes import MypyFile, TypeInfo
from mypy.options import Options
Expand All @@ -20,6 +20,7 @@
import mypy_django_plugin.transformers.orm_lookups
from mypy_django_plugin.config import DjangoPluginConfig
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
from mypy_django_plugin.transformers.functional import resolve_str_promise_attribute
Expand Down Expand Up @@ -147,23 +148,22 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
if not defined_model_classes:
return []
deps = set()

for model_class in defined_model_classes:
# forward relations
for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField):
for field in itertools.chain(
# forward relations
self.django_context.get_model_related_fields(model_class),
# reverse relations - `related_objects` is private API (according to docstring)
model_class._meta.related_objects, # type: ignore[attr-defined]
):
try:
related_model_cls = self.django_context.get_field_related_model_cls(field)
if related_model_cls is None:
continue
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))
# reverse relations
# `related_objects` is private API (according to docstring)
for relation in model_class._meta.related_objects: # type: ignore[attr-defined]
related_model_cls = self.django_context.get_field_related_model_cls(relation)
except UnregisteredModelError:
continue
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))

return list(deps) + [
# for QuerySet.annotate
self._new_dependency("django_stubs_ext"),
Expand Down
6 changes: 4 additions & 2 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,8 +60,9 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context

assert isinstance(current_field, RelatedField)

related_model_cls = django_context.get_field_related_model_cls(current_field)
if related_model_cls is None:
try:
related_model_cls = django_context.get_field_related_model_cls(current_field)
except UnregisteredModelError:
return AnyType(TypeOfAny.from_error)

default_related_field_type = set_descriptor_types_for_field(ctx)
Expand Down
64 changes: 31 additions & 33 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from django.db.models import Manager, Model
from django.db.models.fields import DateField, DateTimeField, Field
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel
from mypy.checker import TypeChecker
from mypy.nodes import ARG_STAR2, Argument, AssignmentStmt, CallExpr, Context, NameExpr, TypeInfo, Var
Expand All @@ -15,6 +14,7 @@

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.errorcodes import MANAGER_MISSING
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.transformers import fields
Expand Down Expand Up @@ -234,41 +234,41 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:

class AddRelatedModelsId(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
for field in self.django_context.get_model_foreign_keys(model_cls):
try:
related_model_cls = self.django_context.get_field_related_model_cls(field)
if related_model_cls is None:
error_context: Context = self.ctx.cls
field_sym = self.ctx.cls.info.get(field.name)
if field_sym is not None and field_sym.node is not None:
error_context = field_sym.node
self.api.fail(
f"Cannot find model {field.related_model!r} referenced in field {field.name!r}",
ctx=error_context,
)
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
continue
except UnregisteredModelError:
error_context: Context = self.ctx.cls
field_sym = self.ctx.cls.info.get(field.name)
if field_sym is not None and field_sym.node is not None:
error_context = field_sym.node
self.api.fail(
f"Cannot find model {field.related_model!r} referenced in field {field.name!r}",
ctx=error_context,
)
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
continue

if related_model_cls._meta.abstract:
continue
if related_model_cls._meta.abstract:
continue

rel_target_field = self.django_context.get_related_target_field(related_model_cls, field)
if not rel_target_field:
continue
rel_target_field = self.django_context.get_related_target_field(related_model_cls, field)
if not rel_target_field:
continue

try:
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_target_field.__class__)
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue
try:
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_target_field.__class__)
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue

is_nullable = self.django_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(
field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable
)
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))
is_nullable = self.django_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(
field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable
)
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))


class AddManagers(ModelClassInitializer):
Expand Down Expand Up @@ -448,8 +448,6 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
continue

related_model_cls = self.django_context.get_field_related_model_cls(relation)
if related_model_cls is None:
continue

try:
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
Expand Down
6 changes: 5 additions & 1 deletion mypy_django_plugin/transformers/orm_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.helpers import is_annotated_model_fullname

Expand Down Expand Up @@ -36,7 +37,10 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
if is_annotated_model_fullname(model_cls_fullname):
lookup_type = AnyType(TypeOfAny.implementation_artifact)
else:
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
try:
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
except UnregisteredModelError:
lookup_type = AnyType(TypeOfAny.from_error)
# Managers as provided_type is not supported yet
if isinstance(provided_type, Instance) and helpers.has_any_of_bases(
provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, fullnames.QUERYSET_CLASS_FULLNAME)
Expand Down
2 changes: 0 additions & 2 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def get_field_type_from_lookup(
lookup_field, ForeignObjectRel
):
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
if related_model_cls is None:
return AnyType(TypeOfAny.from_error)
lookup_field = django_context.get_primary_key_field(related_model_cls)

field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method)
Expand Down
32 changes: 32 additions & 0 deletions tests/typecheck/fields/test_related.yml
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,7 @@
- case: test_fails_if_app_label_is_unknown_in_relation_field
main: |
from installed.models import InstalledModel
InstalledModel.objects.filter(non_installed__isnull=True)
installed_apps:
- installed
files:
Expand Down Expand Up @@ -975,3 +976,34 @@

class Book(PrintedGood):
name = models.CharField()
- case: test_foreign_key_to_as_string_filter_on_abstract
main: |
from myapp.models import Book, Publisher

installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models/__init__.py
content: |
from django.db import models
from django.db.models.query import QuerySet

class Publisher(models.Model):
name = models.CharField()

class MyModel(models.Model):
pass

class PrintedGood(MyModel):
publisher = models.ForeignKey(to="myapp.Publisher", on_delete=models.CASCADE)

@property
def siblings(self) -> QuerySet['PrintedGood']:
return self.__class__.objects.filter(publisher=self.publisher)

class Meta:
abstract = True

class Book(PrintedGood):
name = models.CharField()