From e7611b3feec7599c579afaf0c3ca947c673b2fca Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Fri, 9 Aug 2024 20:08:02 +0200 Subject: [PATCH] Check annotated fields in `.filter` call (#2332) --- mypy_django_plugin/django/context.py | 16 ++++++- .../transformers/orm_lookups.py | 20 +++++++-- .../managers/querysets/test_annotate.yml | 43 +++++++++++++++++-- 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index c9cbb2078..6c5301e59 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -472,11 +472,23 @@ def resolve_lookup_into_field( raise LookupsAreUnsupported() return self._resolve_field_from_parts(field_parts, model_cls) - def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType: + def resolve_lookup_expected_type( + self, ctx: MethodContext, model_cls: Type[Model], lookup: str, model_instance: Instance + ) -> MypyType: try: solved_lookup = self.solve_lookup_type(model_cls, lookup) except FieldError as exc: - ctx.api.fail(exc.args[0], ctx.context) + if ( + helpers.is_annotated_model(model_instance.type) + and model_instance.extra_attrs + and lookup in model_instance.extra_attrs.attrs + ): + return model_instance.extra_attrs.attrs[lookup] + + msg = exc.args[0] + if model_instance.extra_attrs: + msg = ", ".join((msg, *model_instance.extra_attrs.attrs.keys())) + ctx.api.fail(msg, ctx.context) return AnyType(TypeOfAny.from_error) if solved_lookup is None: diff --git a/mypy_django_plugin/transformers/orm_lookups.py b/mypy_django_plugin/transformers/orm_lookups.py index 956417a93..9860ec3f6 100644 --- a/mypy_django_plugin/transformers/orm_lookups.py +++ b/mypy_django_plugin/transformers/orm_lookups.py @@ -14,12 +14,26 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) lookup_kwargs = ctx.arg_names[1] if len(ctx.arg_names) >= 2 else [] provided_lookup_types = ctx.arg_types[1] if len(ctx.arg_types) >= 2 else [] - if not isinstance(ctx.type, Instance) or not ctx.type.args or not isinstance(ctx.type.args[0], Instance): + if ( + not isinstance(ctx.type, Instance) + or not ctx.type.args + or not isinstance(ctx.type.args[0], Instance) + or not helpers.is_model_type(ctx.type.args[0].type) + ): return ctx.default_return_type + api = helpers.get_typechecker_api(ctx) manager_info = ctx.type.type + model_type = ctx.type.args[0] model_cls_fullname = helpers.get_manager_to_model(manager_info) or ctx.type.args[0].type.fullname - model_cls = django_context.get_model_class_by_fullname(model_cls_fullname) + model_info = helpers.lookup_fully_qualified_typeinfo(api, model_cls_fullname) + if model_info is None: + return ctx.default_return_type + model_cls = ( + django_context.get_model_class_by_fullname(model_info.bases[0].type.fullname) + if helpers.is_annotated_model(model_info) + else django_context.get_model_class_by_fullname(model_cls_fullname) + ) if model_cls is None: return ctx.default_return_type @@ -33,7 +47,7 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) lookup_type: MypyType try: - lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg) + lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg, model_type) except UnregisteredModelError: lookup_type = AnyType(TypeOfAny.from_error) # Managers as provided_type is not supported yet diff --git a/tests/typecheck/managers/querysets/test_annotate.yml b/tests/typecheck/managers/querysets/test_annotate.yml index cd415db6f..c32c7235b 100644 --- a/tests/typecheck/managers/querysets/test_annotate.yml +++ b/tests/typecheck/managers/querysets/test_annotate.yml @@ -262,10 +262,7 @@ from myapp.models import User from django.db.models.expressions import F User.objects.annotate(abc=F('id')).filter(abc=1).values_list() - - # Invalid lookups are currently allowed after calling .annotate. - # It would be nice to in the future store the annotated names and use it when checking for valid lookups. - User.objects.annotate(abc=F('id')).filter(unknown_field=1).values_list() + User.objects.annotate(abc=F('id')).filter(unknown_field=1).values_list() # E: Cannot resolve keyword 'unknown_field' into field. Choices are: id, abc [misc] installed_apps: - myapp files: @@ -360,3 +357,41 @@ class Blog(models.Model): num_posts = models.IntegerField() text = models.CharField(max_length=100) + +- case: test_annotate_with_filter + main: | + from django.db import models + from myapp.models import Blog + + qs = Blog.objects.annotate(xyz=models.Count("entry")) + qs.filter(xyz=1) + qs.filter(annotate_wrong__gt=5) + + qs2 = qs.alias(alias_entries=models.Count("entry")) + qs2.filter(alias_wrong__gt=5, annotate_wrong__gt=5) + + Blog.objects.annotate().filter(xyz=1) + ( + Blog.objects.filter(xyz=1) + .annotate(xyz=models.Count("entry")) + .filter(xyz=1) + ) + out: | + main:6: error: Cannot resolve keyword 'annotate_wrong' into field. Choices are: entry, id, xyz [misc] + main:9: error: Cannot resolve keyword 'alias_wrong' into field. Choices are: entry, id, xyz [misc] + main:9: error: Cannot resolve keyword 'annotate_wrong' into field. Choices are: entry, id, xyz [misc] + main:11: error: Cannot resolve keyword 'xyz' into field. Choices are: entry, id [misc] + main:13: error: Cannot resolve keyword 'xyz' into field. Choices are: entry, id [misc] + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + + class Blog(models.Model): + pass + + class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE)