From 94b6f101f7dc363a8e71593570b17527dbb9f77f Mon Sep 17 00:00:00 2001 From: Nick Pope Date: Sat, 30 Dec 2023 07:24:30 +0000 Subject: [PATCH] Fixed #29049 -- Added slicing notation to F expressions. Co-authored-by: Priyansh Saxena Co-authored-by: Niclas Olofsson Co-authored-by: David Smith Co-authored-by: Mariusz Felisiak Co-authored-by: Abhinav Yadav --- django/contrib/postgres/fields/array.py | 14 +++- django/db/models/expressions.py | 60 ++++++++++++++ django/db/models/fields/__init__.py | 15 ++++ docs/ref/models/expressions.txt | 22 ++++++ docs/releases/5.1.txt | 8 ++ tests/expressions/models.py | 4 + tests/expressions/tests.py | 101 ++++++++++++++++++++++++ tests/postgres_tests/test_array.py | 36 ++++++++- 8 files changed, 256 insertions(+), 4 deletions(-) diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index c8e8e132e01d..4171af82f9d2 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -234,6 +234,12 @@ def formfield(self, **kwargs): } ) + def slice_expression(self, expression, start, length): + # If length is not provided, don't specify an end to slice to the end + # of the array. + end = None if length is None else start + length - 1 + return SliceTransform(start, end, expression) + class ArrayRHSMixin: def __init__(self, lhs, rhs): @@ -351,9 +357,11 @@ def __init__(self, start, end, *args, **kwargs): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - if not lhs.endswith("]"): - lhs = "(%s)" % lhs - return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end) + # self.start is set to 1 if slice start is not provided. 
+ if self.end is None: + return f"({lhs})[%s:]", (*params, self.start) + else: + return f"({lhs})[%s:%s]", (*params, self.start, self.end) class SliceTransformFactory: diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index b67a2418d449..c20de5995a34 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -851,6 +851,9 @@ def __init__(self, name): def __repr__(self): return "{}({})".format(self.__class__.__name__, self.name) + def __getitem__(self, subscript): + return Sliced(self, subscript) + def resolve_expression( self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False ): @@ -925,6 +928,63 @@ def relabeled_clone(self, relabels): return self +class Sliced(F): + """ + An object that contains a slice of an F expression. + + Object resolves the column on which the slicing is applied, and then + applies the slicing if possible. + """ + + def __init__(self, obj, subscript): + super().__init__(obj.name) + self.obj = obj + if isinstance(subscript, int): + if subscript < 0: + raise ValueError("Negative indexing is not supported.") + self.start = subscript + 1 + self.length = 1 + elif isinstance(subscript, slice): + if (subscript.start is not None and subscript.start < 0) or ( + subscript.stop is not None and subscript.stop < 0 + ): + raise ValueError("Negative indexing is not supported.") + if subscript.step is not None: + raise ValueError("Step argument is not supported.") + if subscript.stop and subscript.start and subscript.stop < subscript.start: + raise ValueError("Slice stop must be greater than slice start.") + self.start = 1 if subscript.start is None else subscript.start + 1 + if subscript.stop is None: + self.length = None + else: + self.length = subscript.stop - (subscript.start or 0) + else: + raise TypeError("Argument to slice must be either int or slice instance.") + + def __repr__(self): + start = self.start - 1 + stop = None if self.length is None else start + self.length 
+ subscript = slice(start, stop) + return f"{self.__class__.__qualname__}({self.obj!r}, {subscript!r})" + + def resolve_expression( + self, + query=None, + allow_joins=True, + reuse=None, + summarize=False, + for_save=False, + ): + resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize) + if isinstance(self.obj, (OuterRef, self.__class__)): + expr = self.obj.resolve_expression( + query, allow_joins, reuse, summarize, for_save + ) + else: + expr = resolved + return resolved.output_field.slice_expression(expr, self.start, self.length) + + @deconstructible(path="django.db.models.Func") class Func(SQLiteNumericMixin, Expression): """An SQL function call.""" diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 41735d3b7f18..5186f0c414dd 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -15,6 +15,7 @@ from django.db import connection, connections, router from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import DeferredAttribute, RegisterLookupMixin +from django.db.utils import NotSupportedError from django.utils import timezone from django.utils.choices import ( BlankChoiceIterator, @@ -1143,6 +1144,10 @@ def value_from_object(self, obj): """Return the value of this field in the given model instance.""" return getattr(obj, self.attname) + def slice_expression(self, expression, start, length): + """Return a slice of this field.""" + raise NotSupportedError("This field does not support slicing.") + class BooleanField(Field): empty_strings_allowed = False @@ -1303,6 +1308,11 @@ def deconstruct(self): kwargs["db_collation"] = self.db_collation return name, path, args, kwargs + def slice_expression(self, expression, start, length): + from django.db.models.functions import Substr + + return Substr(expression, start, length) + class CommaSeparatedIntegerField(CharField): default_validators = [validators.validate_comma_separated_integer_list] @@ 
-2497,6 +2507,11 @@ def deconstruct(self): kwargs["db_collation"] = self.db_collation return name, path, args, kwargs + def slice_expression(self, expression, start, length): + from django.db.models.functions import Substr + + return Substr(expression, start, length) + class TimeField(DateTimeCheckMixin, Field): empty_strings_allowed = False diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 9d85442d9ca6..67baef7dfc26 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -183,6 +183,28 @@ the field value of each one, and saving each one back to the database:: * getting the database, rather than Python, to do work * reducing the number of queries some operations require +.. _slicing-using-f: + +Slicing ``F()`` expressions +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 5.1 + +For string-based fields, text-based fields, and +:class:`~django.contrib.postgres.fields.ArrayField`, you can use Python's +array-slicing syntax. The indices are 0-based and the ``step`` argument to +``slice`` is not supported. For example: + +.. code-block:: pycon + + >>> # Replacing a name with a substring of itself. + >>> writer = Writers.objects.get(name="Priyansh") + >>> writer.name = F("name")[1:5] + >>> writer.save() + >>> writer.refresh_from_db() + >>> writer.name + 'riya' + .. _avoiding-race-conditions-using-f: Avoiding race conditions using ``F()`` diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index cc72346eef01..b825e9be4f2a 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -184,6 +184,14 @@ Models * :meth:`.QuerySet.order_by` now supports ordering by annotation transforms such as ``JSONObject`` keys and ``ArrayAgg`` indices. 
+* :class:`F() <django.db.models.F>` and :class:`OuterRef() + <django.db.models.OuterRef>` expressions that output + :class:`~django.db.models.CharField`, :class:`~django.db.models.EmailField`, + :class:`~django.db.models.SlugField`, :class:`~django.db.models.URLField`, + :class:`~django.db.models.TextField`, or + :class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced + <slicing-using-f>`. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/expressions/models.py b/tests/expressions/models.py index 0a8a0a6584db..bd4db9050e6a 100644 --- a/tests/expressions/models.py +++ b/tests/expressions/models.py @@ -106,3 +106,7 @@ class UUIDPK(models.Model): class UUID(models.Model): uuid = models.UUIDField(null=True) uuid_fk = models.ForeignKey(UUIDPK, models.CASCADE, null=True) + + +class Text(models.Model): + name = models.TextField() diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 073e2e925838..cbb441601cd4 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -84,6 +84,7 @@ RemoteEmployee, Result, SimulationRun, + Text, Time, ) @@ -205,6 +206,100 @@ def test_update(self): ], ) + def _test_slicing_of_f_expressions(self, model): + tests = [ + (F("name")[:], "Example Inc.", "Example Inc."), + (F("name")[:7], "Example Inc.", "Example"), + (F("name")[:6][:5], "Example", "Examp"), # Nested slicing. 
+ (F("name")[0], "Examp", "E"), + (F("name")[5], "E", ""), + (F("name")[7:], "Foobar Ltd.", "Ltd."), + (F("name")[0:10], "Ltd.", "Ltd."), + (F("name")[2:7], "Test GmbH", "st Gm"), + (F("name")[1:][:3], "st Gm", "t G"), + (F("name")[2:2], "t G", ""), + ] + for expression, name, expected in tests: + with self.subTest(expression=expression, name=name, expected=expected): + obj = model.objects.get(name=name) + obj.name = expression + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.name, expected) + + def test_slicing_of_f_expressions_charfield(self): + self._test_slicing_of_f_expressions(Company) + + def test_slicing_of_f_expressions_textfield(self): + Text.objects.bulk_create( + [Text(name=company.name) for company in Company.objects.all()] + ) + self._test_slicing_of_f_expressions(Text) + + def test_slicing_of_f_expressions_with_annotate(self): + qs = Company.objects.annotate( + first_three=F("name")[:3], + after_three=F("name")[3:], + random_four=F("name")[2:5], + first_letter_slice=F("name")[:1], + first_letter_index=F("name")[0], + ) + tests = [ + ("first_three", ["Exa", "Foo", "Tes"]), + ("after_three", ["mple Inc.", "bar Ltd.", "t GmbH"]), + ("random_four", ["amp", "oba", "st "]), + ("first_letter_slice", ["E", "F", "T"]), + ("first_letter_index", ["E", "F", "T"]), + ] + for annotation, expected in tests: + with self.subTest(annotation): + self.assertCountEqual(qs.values_list(annotation, flat=True), expected) + + def test_slicing_of_f_expression_with_annotated_expression(self): + qs = Company.objects.annotate( + new_name=Case( + When(based_in_eu=True, then=Concat(Value("EU:"), F("name"))), + default=F("name"), + ), + first_two=F("new_name")[:3], + ) + self.assertCountEqual( + qs.values_list("first_two", flat=True), + ["Exa", "EU:", "Tes"], + ) + + def test_slicing_of_f_expressions_with_negative_index(self): + msg = "Negative indexing is not supported." 
+ indexes = [slice(0, -4), slice(-4, 0), slice(-4), -5] + for i in indexes: + with self.subTest(i=i), self.assertRaisesMessage(ValueError, msg): + F("name")[i] + + def test_slicing_of_f_expressions_with_slice_stop_less_than_slice_start(self): + msg = "Slice stop must be greater than slice start." + with self.assertRaisesMessage(ValueError, msg): + F("name")[4:2] + + def test_slicing_of_f_expressions_with_invalid_type(self): + msg = "Argument to slice must be either int or slice instance." + with self.assertRaisesMessage(TypeError, msg): + F("name")["error"] + + def test_slicing_of_f_expressions_with_step(self): + msg = "Step argument is not supported." + with self.assertRaisesMessage(ValueError, msg): + F("name")[::4] + + def test_slicing_of_f_unsupported_field(self): + msg = "This field does not support slicing." + with self.assertRaisesMessage(NotSupportedError, msg): + Company.objects.update(num_chairs=F("num_chairs")[:4]) + + def test_slicing_of_outerref(self): + inner = Company.objects.filter(name__startswith=OuterRef("ceo__firstname")[0]) + outer = Company.objects.filter(Exists(inner)).values_list("name", flat=True) + self.assertSequenceEqual(outer, ["Foobar Ltd."]) + def test_arithmetic(self): # We can perform arithmetic operations in expressions # Make sure we have 2 spare chairs @@ -2359,6 +2454,12 @@ def test_expressions(self): repr(Func("published", function="TO_CHAR")), "Func(F(published), function=TO_CHAR)", ) + self.assertEqual( + repr(F("published")[0:2]), "Sliced(F(published), slice(0, 2, None))" + ) + self.assertEqual( + repr(OuterRef("name")[1:5]), "Sliced(OuterRef(name), slice(1, 5, None))" + ) self.assertEqual(repr(OrderBy(Value(1))), "OrderBy(Value(1), descending=False)") self.assertEqual(repr(RawSQL("table.col", [])), "RawSQL(table.col, [])") self.assertEqual( diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 8aaa7be07773..386a0afa3a97 100644 --- a/tests/postgres_tests/test_array.py +++ 
b/tests/postgres_tests/test_array.py @@ -10,7 +10,7 @@ from django.core.exceptions import FieldError from django.core.management import call_command from django.db import IntegrityError, connection, models -from django.db.models.expressions import Exists, OuterRef, RawSQL, Value +from django.db.models.expressions import Exists, F, OuterRef, RawSQL, Value from django.db.models.functions import Cast, JSONObject, Upper from django.test import TransactionTestCase, override_settings, skipUnlessDBFeature from django.test.utils import isolate_apps @@ -594,6 +594,40 @@ def test_slice_annotation(self): [None, [1], [2], [2, 3], [20, 30]], ) + def test_slicing_of_f_expressions(self): + tests = [ + (F("field")[:2], [1, 2]), + (F("field")[2:], [3, 4]), + (F("field")[1:3], [2, 3]), + (F("field")[3], [4]), + (F("field")[:3][1:], [2, 3]), # Nested slicing. + (F("field")[:3][1], [2]), # Slice then index. + ] + for expression, expected in tests: + with self.subTest(expression=expression, expected=expected): + instance = IntegerArrayModel.objects.create(field=[1, 2, 3, 4]) + instance.field = expression + instance.save() + instance.refresh_from_db() + self.assertEqual(instance.field, expected) + + def test_slicing_of_f_expressions_with_annotate(self): + IntegerArrayModel.objects.create(field=[1, 2, 3]) + annotated = IntegerArrayModel.objects.annotate( + first_two=F("field")[:2], + after_two=F("field")[2:], + random_two=F("field")[1:3], + ).get() + self.assertEqual(annotated.first_two, [1, 2]) + self.assertEqual(annotated.after_two, [3]) + self.assertEqual(annotated.random_two, [2, 3]) + + def test_slicing_of_f_expressions_with_len(self): + queryset = NullableIntegerArrayModel.objects.annotate( + subarray=F("field")[:1] + ).filter(field__len=F("subarray__len")) + self.assertSequenceEqual(queryset, self.objs[:2]) + def test_usage_in_subquery(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(