Skip to content

Commit

Permalink
feat: Support F and Concat expressions in annotate()
Browse files Browse the repository at this point in the history
Refs #735
  • Loading branch information
wookkl authored May 24, 2024
1 parent 47dcb9c commit a0aeb58
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
26 changes: 24 additions & 2 deletions modeltranslation/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.backends.utils import CursorWrapper
from django.db.models import Field, Model
from django.db.models import Field, Model, F
from django.db.models.expressions import Col
from django.db.models.functions import Concat, ConcatPair
from django.db.models.lookups import Lookup
from django.db.models.query import QuerySet, ValuesIterable
from django.db.models.utils import create_namedtuple_class
from django.utils.tree import Node

from modeltranslation._typing import Self, AutoPopulate
from modeltranslation.fields import TranslationField
from modeltranslation.thread_context import auto_populate_mode
from modeltranslation.utils import (
Expand All @@ -30,7 +32,6 @@
get_language,
resolution_order,
)
from modeltranslation._typing import Self, AutoPopulate

_C2F_CACHE: dict[tuple[type[Model], str], Field] = {}
_F2TM_CACHE: dict[type[Model], dict[str, type[Model]]] = {}
Expand Down Expand Up @@ -513,6 +514,27 @@ def dates(self, field_name: str, *args: Any, **kwargs: Any) -> Self:
new_key = rewrite_lookup_key(self.model, field_name)
return super().dates(new_key, *args, **kwargs)

def _rewrite_concat(self, concat: Concat | ConcatPair):
new_source_expressions = []
for exp in concat.source_expressions:
if isinstance(exp, (Concat, ConcatPair)):
exp = self._rewrite_concat(exp)
if isinstance(exp, F):
exp = self._rewrite_f(exp)
new_source_expressions.append(exp)
concat.set_source_expressions(new_source_expressions)
return concat

def annotate(self, *args: Any, **kwargs: Any) -> Self:
if not self._rewrite:
return super().annotate(*args, **kwargs)
for key, val in list(kwargs.items()):
if isinstance(val, models.F):
kwargs[key] = self._rewrite_f(val)
if isinstance(val, Concat):
kwargs[key] = self._rewrite_concat(val)
return super().annotate(*args, **kwargs)


class FallbackValuesIterable(ValuesIterable):
queryset: MultilingualQuerySet[Model]
Expand Down
41 changes: 40 additions & 1 deletion modeltranslation/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from django.core.management.base import CommandError
from django.db import IntegrityError
from django.db.models import CharField, Count, F, Q, TextField, Value
from django.db.models.functions import Cast
from django.db.models.functions import Cast, Concat
from django.test import TestCase, TransactionTestCase
from django.test.utils import override_settings
from django.utils.translation import get_language, override, trans_real
Expand Down Expand Up @@ -3670,6 +3670,45 @@ def test_distinct(self):
assert titles_for_en == (("title_1_en", "desc_1_en"), ("title_2_en", "desc_1_en"))
assert titles_for_de == (("title_1_de", "desc_1_de"), ("title_2_de", "desc_1_de"))

def test_annotate(self):
"""Test if annotating is language-aware."""
test = models.TestModel.objects.create(title_en="title_en", title_de="title_de")

assert "en" == get_language()
assert (
models.TestModel.objects.annotate(custom_title=F("title")).values_list(
"custom_title", flat=True
)[0]
== "title_en"
)
with override("de"):
assert (
models.TestModel.objects.annotate(custom_title=F("title")).values_list(
"custom_title", flat=True
)[0]
== "title_de"
)
assert (
models.TestModel.objects.annotate(
custom_title=Concat(F("title"), Value("value1"), Value("value2"))
).values_list("custom_title", flat=True)[0]
== "title_devalue1value2"
)
assert (
models.TestModel.objects.annotate(
custom_title=Concat(F("title"), Concat(F("title"), Value("value")))
).values_list("custom_title", flat=True)[0]
== "title_detitle_devalue"
)
models.ForeignKeyModel.objects.create(test=test)
models.ForeignKeyModel.objects.create(test=test)
assert (
models.TestModel.objects.annotate(Count("test_fks")).values_list(
"test_fks__count", flat=True
)[0]
== 2
)


class TranslationModelFormTest(ModeltranslationTestBase):
def test_fields(self):
Expand Down

0 comments on commit a0aeb58

Please sign in to comment.