From 59f475470494ce5b8cbff816b1e5dafcbd10a3a3 Mon Sep 17 00:00:00 2001 From: Francesco Panico Date: Sat, 4 Mar 2023 17:59:53 +0000 Subject: [PATCH] Fixed #34362 -- Fixed FilteredRelation() crash on conditional expressions. Thanks zhu for the report and Simon Charette for reviews. --- django/db/models/sql/query.py | 45 ++++++++++++++--- tests/filtered_relation/tests.py | 85 ++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 7 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 6a9348af665d..a7839ccb4df4 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -65,20 +65,52 @@ def get_field_names_from_opts(opts): ) +def get_paths_from_expression(expr): + if isinstance(expr, F): + yield expr.name + elif hasattr(expr, "flatten"): + for child in expr.flatten(): + if isinstance(child, F): + yield child.name + elif isinstance(child, Q): + yield from get_children_from_q(child) + + def get_children_from_q(q): for child in q.children: if isinstance(child, Node): yield from get_children_from_q(child) - else: - yield child + elif isinstance(child, tuple): + lhs, rhs = child + yield lhs + if hasattr(rhs, "resolve_expression"): + yield from get_paths_from_expression(rhs) + elif hasattr(child, "resolve_expression"): + yield from get_paths_from_expression(child) def get_child_with_renamed_prefix(prefix, replacement, child): if isinstance(child, Node): return rename_prefix_from_q(prefix, replacement, child) - lhs, rhs = child - lhs = lhs.replace(prefix, replacement, 1) - return lhs, rhs + if isinstance(child, tuple): + lhs, rhs = child + lhs = lhs.replace(prefix, replacement, 1) + if not isinstance(rhs, F) and hasattr(rhs, "resolve_expression"): + rhs = get_child_with_renamed_prefix(prefix, replacement, rhs) + return lhs, rhs + + if isinstance(child, F): + child = child.copy() + child.name = child.name.replace(prefix, replacement, 1) + elif hasattr(child, "resolve_expression"): + child = child.copy() + child.set_source_expressions( + [ + get_child_with_renamed_prefix(prefix, replacement, grand_child) + for grand_child in child.get_source_expressions() + ] + ) + return child def rename_prefix_from_q(prefix, replacement, q): @@ -1618,7 +1650,6 @@ def _add_q( def add_filtered_relation(self, filtered_relation, alias): filtered_relation.alias = alias - lookups = dict(get_children_from_q(filtered_relation.condition)) relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type( filtered_relation.relation_name ) @@ -1627,7 +1658,7 @@ def add_filtered_relation(self, filtered_relation, alias): "FilteredRelation's relation_name cannot contain lookups " "(got %r)." % filtered_relation.relation_name ) - for lookup in chain(lookups): + for lookup in get_children_from_q(filtered_relation.condition): lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup) shift = 2 if not lookup_parts else 1 lookup_field_path = lookup_field_parts[:-shift] diff --git a/tests/filtered_relation/tests.py b/tests/filtered_relation/tests.py index f9735ca37148..5a21a47f369f 100644 --- a/tests/filtered_relation/tests.py +++ b/tests/filtered_relation/tests.py @@ -4,9 +4,11 @@ from django.db import connection, transaction from django.db.models import ( + BooleanField, Case, Count, DecimalField, + ExpressionWrapper, F, FilteredRelation, Q, @@ -15,6 +17,7 @@ When, ) from django.db.models.functions import Concat +from django.db.models.lookups import Exact, IStartsWith from django.test import TestCase from django.test.testcases import skipUnlessDBFeature @@ -707,6 +710,88 @@ def test_eq(self): FilteredRelation("book", condition=Q(book__title="b")), mock.ANY ) + def test_conditional_expression(self): + qs = Author.objects.annotate( + the_book=FilteredRelation("book", condition=Q(Value(False))), + ).filter(the_book__isnull=False) + self.assertSequenceEqual(qs, []) + + def test_expression_outside_relation_name(self): + qs = Author.objects.annotate( + book_editor=FilteredRelation( + "book__editor", + condition=Q( + Exact(F("book__author__name"), "Alice"), + Value(True), + book__title__startswith="Poem", + ), + ), + ).filter(book_editor__isnull=False) + self.assertSequenceEqual(qs, [self.author1]) + + def test_conditional_expression_with_case(self): + qs = Book.objects.annotate( + alice_author=FilteredRelation( + "author", + condition=Q( + Case(When(author__name="Alice", then=True), default=False), + ), + ), + ).filter(alice_author__isnull=False) + self.assertCountEqual(qs, [self.book1, self.book4]) + + def test_conditional_expression_outside_relation_name(self): + tests = [ + Q(Case(When(book__author__name="Alice", then=True), default=False)), + Q( + ExpressionWrapper( + Q(Value(True), Exact(F("book__author__name"), "Alice")), + output_field=BooleanField(), + ), + ), + ] + for condition in tests: + with self.subTest(condition=condition): + qs = Author.objects.annotate( + book_editor=FilteredRelation("book__editor", condition=condition), + ).filter(book_editor__isnull=True) + self.assertSequenceEqual(qs, [self.author2, self.author2]) + + def test_conditional_expression_with_lookup(self): + lookups = [ + Q(book__title__istartswith="poem"), + Q(IStartsWith(F("book__title"), "poem")), + ] + for condition in lookups: + with self.subTest(condition=condition): + qs = Author.objects.annotate( + poem_book=FilteredRelation("book", condition=condition) + ).filter(poem_book__isnull=False) + self.assertSequenceEqual(qs, [self.author1]) + + def test_conditional_expression_with_expressionwrapper(self): + qs = Author.objects.annotate( + poem_book=FilteredRelation( + "book", + condition=Q( + ExpressionWrapper( + Q(Exact(F("book__title"), "Poem by Alice")), + output_field=BooleanField(), + ), + ), + ), + ).filter(poem_book__isnull=False) + self.assertSequenceEqual(qs, [self.author1]) + + def test_conditional_expression_with_multiple_fields(self): + qs = Author.objects.annotate( + my_books=FilteredRelation( + "book__author", + condition=Q(Exact(F("book__author__name"), F("book__author__name"))), + ), + ).filter(my_books__isnull=True) + self.assertSequenceEqual(qs, []) + class FilteredRelationAggregationTests(TestCase): @classmethod