From 5875f03ce61b85dfd9ad34f7b871c231c358d432 Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Mon, 13 Nov 2023 05:33:25 +0100 Subject: [PATCH] Fixed #34944 -- Made GeneratedField.output_field required. Regression in f333e3513e8bdf5ffeb6eeb63021c230082e6f95. --- django/db/models/fields/generated.py | 12 +-- docs/ref/models/fields.txt | 12 ++- docs/releases/5.0.txt | 6 +- tests/admin_views/models.py | 1 + .../test_ordinary_fields.py | 28 +++++-- tests/migrations/test_operations.py | 28 +++++-- tests/model_fields/models.py | 28 +++++-- tests/model_fields/test_generatedfield.py | 81 ++++++++++++++----- tests/schema/tests.py | 8 +- 9 files changed, 150 insertions(+), 54 deletions(-) diff --git a/django/db/models/fields/generated.py b/django/db/models/fields/generated.py index abafc3ad2748..ce811789c333 100644 --- a/django/db/models/fields/generated.py +++ b/django/db/models/fields/generated.py @@ -16,7 +16,7 @@ class GeneratedField(Field): _resolved_expression = None output_field = None - def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs): + def __init__(self, *, expression, output_field, db_persist=None, **kwargs): if kwargs.setdefault("editable", False): raise ValueError("GeneratedField cannot be editable.") if not kwargs.setdefault("blank", True): @@ -29,7 +29,7 @@ def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs): raise ValueError("GeneratedField.db_persist must be True or False.") self.expression = expression - self._output_field = output_field + self.output_field = output_field self.db_persist = db_persist super().__init__(**kwargs) @@ -51,11 +51,6 @@ def contribute_to_class(self, *args, **kwargs): self._resolved_expression = self.expression.resolve_expression( self._query, allow_joins=False ) - self.output_field = ( - self._output_field - if self._output_field is not None - else self._resolved_expression.output_field - ) # Register lookups from the output_field class. for lookup_name, lookup in self.output_field.get_class_lookups().items(): self.register_lookup(lookup, lookup_name=lookup_name) @@ -150,8 +145,7 @@ def deconstruct(self): del kwargs["editable"] kwargs["db_persist"] = self.db_persist kwargs["expression"] = self.expression - if self._output_field is not None: - kwargs["output_field"] = self._output_field + kwargs["output_field"] = self.output_field return name, path, args, kwargs def get_internal_type(self): diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index 9e945b7f27a7..a29d06c00a8a 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1237,7 +1237,7 @@ when :attr:`~django.forms.Field.localize` is ``False`` or .. versionadded:: 5.0 -.. class:: GeneratedField(expression, db_persist=None, output_field=None, **kwargs) +.. class:: GeneratedField(expression, output_field, db_persist=None, **kwargs) A field that is always computed based on other fields in the model. This field is managed and updated by the database itself. Uses the ``GENERATED ALWAYS`` @@ -1259,6 +1259,10 @@ materialized view. the model (in the same database table). Generated fields cannot reference other generated fields. Database backends can impose further restrictions. +.. attribute:: GeneratedField.output_field + + A model field instance to define the field's data type. + .. attribute:: GeneratedField.db_persist Determines if the database column should occupy storage as if it were a @@ -1268,12 +1272,6 @@ materialized view. PostgreSQL only supports persisted columns. Oracle only supports virtual columns. -.. attribute:: GeneratedField.output_field - - An optional model field instance to define the field's data type. This can - be used to customize attributes like the field's collation. By default, the - output field is derived from ``expression``. - .. admonition:: Refresh the data Since the database always computed the value, the object must be reloaded diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index b60f7a6c7640..db75a6b0a39d 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -142,7 +142,11 @@ to create a field that is always computed from other fields. For example:: class Square(models.Model): side = models.IntegerField() - area = models.GeneratedField(expression=F("side") * F("side"), db_persist=True) + area = models.GeneratedField( + expression=F("side") * F("side"), + output_field=models.BigIntegerField(), + db_persist=True, + ) More options for declaring field choices ---------------------------------------- diff --git a/tests/admin_views/models.py b/tests/admin_views/models.py index 67d3ec4c86c7..bd2dc65d2e17 100644 --- a/tests/admin_views/models.py +++ b/tests/admin_views/models.py @@ -1147,6 +1147,7 @@ class Square(models.Model): area = models.GeneratedField( db_persist=True, expression=models.F("side") * models.F("side"), + output_field=models.BigIntegerField(), ) class Meta: diff --git a/tests/invalid_models_tests/test_ordinary_fields.py b/tests/invalid_models_tests/test_ordinary_fields.py index ceeb254e578d..82b431890664 100644 --- a/tests/invalid_models_tests/test_ordinary_fields.py +++ b/tests/invalid_models_tests/test_ordinary_fields.py @@ -1216,7 +1216,9 @@ def test_not_supported(self): class Model(models.Model): name = models.IntegerField() field = models.GeneratedField( - expression=models.F("name"), db_persist=db_persist + expression=models.F("name"), + output_field=models.IntegerField(), + db_persist=db_persist, ) expected_errors = [] @@ -1252,7 +1254,11 @@ class Model(models.Model): def test_not_supported_stored_required_db_features(self): class Model(models.Model): name = models.IntegerField() - field = models.GeneratedField(expression=models.F("name"), db_persist=True) + field = models.GeneratedField( + expression=models.F("name"), + output_field=models.IntegerField(), + db_persist=True, + ) class Meta: required_db_features = {"supports_stored_generated_columns"} @@ -1262,7 +1268,11 @@ class Meta: def test_not_supported_virtual_required_db_features(self): class Model(models.Model): name = models.IntegerField() - field = models.GeneratedField(expression=models.F("name"), db_persist=False) + field = models.GeneratedField( + expression=models.F("name"), + output_field=models.IntegerField(), + db_persist=False, + ) class Meta: required_db_features = {"supports_virtual_generated_columns"} @@ -1273,7 +1283,11 @@ class Meta: def test_not_supported_virtual(self): class Model(models.Model): name = models.IntegerField() - field = models.GeneratedField(expression=models.F("name"), db_persist=False) + field = models.GeneratedField( + expression=models.F("name"), + output_field=models.IntegerField(), + db_persist=False, + ) a = models.TextField() excepted_errors = ( @@ -1298,7 +1312,11 @@ class Model(models.Model): def test_not_supported_stored(self): class Model(models.Model): name = models.IntegerField() - field = models.GeneratedField(expression=models.F("name"), db_persist=True) + field = models.GeneratedField( + expression=models.F("name"), + output_field=models.IntegerField(), + db_persist=True, + ) a = models.TextField() expected_errors = ( diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 57a9086c1963..21e8ab069a2e 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -5664,10 +5664,14 @@ def assertModelsAndTables(after_db): def _test_invalid_generated_field_changes(self, db_persist): regular = models.IntegerField(default=1) generated_1 = models.GeneratedField( - expression=F("pink") + F("pink"), db_persist=db_persist + expression=F("pink") + F("pink"), + output_field=models.IntegerField(), + db_persist=db_persist, ) generated_2 = models.GeneratedField( - expression=F("pink") + F("pink") + F("pink"), db_persist=db_persist + expression=F("pink") + F("pink") + F("pink"), + output_field=models.IntegerField(), + db_persist=db_persist, ) tests = [ ("test_igfc_1", regular, generated_1), @@ -5707,12 +5711,20 @@ def test_invalid_generated_field_persistency_change(self): migrations.AddField( "Pony", "modified_pink", - models.GeneratedField(expression=F("pink"), db_persist=True), + models.GeneratedField( + expression=F("pink"), + output_field=models.IntegerField(), + db_persist=True, + ), ), migrations.AlterField( "Pony", "modified_pink", - models.GeneratedField(expression=F("pink"), db_persist=False), + models.GeneratedField( + expression=F("pink"), + output_field=models.IntegerField(), + db_persist=False, + ), ), ] msg = ( @@ -5729,7 +5741,9 @@ def _test_add_generated_field(self, db_persist): "Pony", "modified_pink", models.GeneratedField( - expression=F("pink") + F("pink"), db_persist=db_persist + expression=F("pink") + F("pink"), + output_field=models.IntegerField(), + db_persist=db_persist, ), ) project_state, new_state = self.make_test_state(app_label, operation) @@ -5760,7 +5774,9 @@ def _test_remove_generated_field(self, db_persist): "Pony", "modified_pink", models.GeneratedField( - expression=F("pink") + F("pink"), db_persist=db_persist + expression=F("pink") + F("pink"), + output_field=models.IntegerField(), + db_persist=db_persist, ), ) project_state, new_state = self.make_test_state(app_label, operation) diff --git a/tests/model_fields/models.py b/tests/model_fields/models.py index 7804c198815a..b4b7b5bd4c5f 100644 --- a/tests/model_fields/models.py +++ b/tests/model_fields/models.py @@ -485,7 +485,11 @@ class UUIDGrandchild(UUIDChild): class GeneratedModel(models.Model): a = models.IntegerField() b = models.IntegerField() - field = models.GeneratedField(expression=F("a") + F("b"), db_persist=True) + field = models.GeneratedField( + expression=F("a") + F("b"), + output_field=models.IntegerField(), + db_persist=True, + ) class Meta: required_db_features = {"supports_stored_generated_columns"} @@ -494,7 +498,11 @@ class Meta: class GeneratedModelVirtual(models.Model): a = models.IntegerField() b = models.IntegerField() - field = models.GeneratedField(expression=F("a") + F("b"), db_persist=False) + field = models.GeneratedField( + expression=F("a") + F("b"), + output_field=models.IntegerField(), + db_persist=False, + ) class Meta: required_db_features = {"supports_virtual_generated_columns"} @@ -503,6 +511,7 @@ class Meta: class GeneratedModelParams(models.Model): field = models.GeneratedField( expression=Value("Constant", output_field=models.CharField(max_length=10)), + output_field=models.CharField(max_length=10), db_persist=True, ) @@ -513,6 +522,7 @@ class Meta: class GeneratedModelParamsVirtual(models.Model): field = models.GeneratedField( expression=Value("Constant", output_field=models.CharField(max_length=10)), + output_field=models.CharField(max_length=10), db_persist=False, ) @@ -520,7 +530,7 @@ class Meta: required_db_features = {"supports_virtual_generated_columns"} -class GeneratedModelOutputField(models.Model): +class GeneratedModelOutputFieldDbCollation(models.Model): name = models.CharField(max_length=10) lower_name = models.GeneratedField( expression=Lower("name"), @@ -532,7 +542,7 @@ class Meta: required_db_features = {"supports_stored_generated_columns"} -class GeneratedModelOutputFieldVirtual(models.Model): +class GeneratedModelOutputFieldDbCollationVirtual(models.Model): name = models.CharField(max_length=10) lower_name = models.GeneratedField( expression=Lower("name"), @@ -547,7 +557,10 @@ class Meta: class GeneratedModelNull(models.Model): name = models.CharField(max_length=10, null=True) lower_name = models.GeneratedField( - expression=Lower("name"), db_persist=True, null=True + expression=Lower("name"), + output_field=models.CharField(max_length=10), + db_persist=True, + null=True, ) class Meta: @@ -557,7 +570,10 @@ class Meta: class GeneratedModelNullVirtual(models.Model): name = models.CharField(max_length=10, null=True) lower_name = models.GeneratedField( - expression=Lower("name"), db_persist=False, null=True + expression=Lower("name"), + output_field=models.CharField(max_length=10), + db_persist=False, + null=True, ) class Meta: diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py index a37e37498138..04d52f679974 100644 --- a/tests/model_fields/test_generatedfield.py +++ b/tests/model_fields/test_generatedfield.py @@ -1,5 +1,12 @@ from django.db import IntegrityError, connection -from django.db.models import F, FloatField, GeneratedField, IntegerField, Model +from django.db.models import ( + CharField, + F, + FloatField, + GeneratedField, + IntegerField, + Model, +) from django.db.models.functions import Lower from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.utils import isolate_apps @@ -8,8 +15,8 @@ GeneratedModel, GeneratedModelNull, GeneratedModelNullVirtual, - GeneratedModelOutputField, - GeneratedModelOutputFieldVirtual, + GeneratedModelOutputFieldDbCollation, + GeneratedModelOutputFieldDbCollationVirtual, GeneratedModelParams, GeneratedModelParamsVirtual, GeneratedModelVirtual, @@ -19,41 +26,77 @@ class BaseGeneratedFieldTests(SimpleTestCase): def test_editable_unsupported(self): with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."): - GeneratedField(expression=Lower("name"), editable=True, db_persist=False) + GeneratedField( + expression=Lower("name"), + output_field=CharField(max_length=255), + editable=True, + db_persist=False, + ) def test_blank_unsupported(self): with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."): - GeneratedField(expression=Lower("name"), blank=False, db_persist=False) + GeneratedField( + expression=Lower("name"), + output_field=CharField(max_length=255), + blank=False, + db_persist=False, + ) def test_default_unsupported(self): msg = "GeneratedField cannot have a default." with self.assertRaisesMessage(ValueError, msg): - GeneratedField(expression=Lower("name"), default="", db_persist=False) + GeneratedField( + expression=Lower("name"), + output_field=CharField(max_length=255), + default="", + db_persist=False, + ) def test_database_default_unsupported(self): msg = "GeneratedField cannot have a database default." with self.assertRaisesMessage(ValueError, msg): - GeneratedField(expression=Lower("name"), db_default="", db_persist=False) + GeneratedField( + expression=Lower("name"), + output_field=CharField(max_length=255), + db_default="", + db_persist=False, + ) def test_db_persist_required(self): msg = "GeneratedField.db_persist must be True or False." with self.assertRaisesMessage(ValueError, msg): - GeneratedField(expression=Lower("name")) + GeneratedField( + expression=Lower("name"), output_field=CharField(max_length=255) + ) with self.assertRaisesMessage(ValueError, msg): - GeneratedField(expression=Lower("name"), db_persist=None) + GeneratedField( + expression=Lower("name"), + output_field=CharField(max_length=255), + db_persist=None, + ) def test_deconstruct(self): - field = GeneratedField(expression=F("a") + F("b"), db_persist=True) + field = GeneratedField( + expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True + ) _, path, args, kwargs = field.deconstruct() self.assertEqual(path, "django.db.models.GeneratedField") self.assertEqual(args, []) - self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")}) + self.assertEqual(kwargs["db_persist"], True) + self.assertEqual(kwargs["expression"], F("a") + F("b")) + self.assertEqual( + kwargs["output_field"].deconstruct(), IntegerField().deconstruct() + ) @isolate_apps("model_fields") def test_get_col(self): class Square(Model): side = IntegerField() - area = GeneratedField(expression=F("side") * F("side"), db_persist=True) + area = GeneratedField( + expression=F("side") * F("side"), + output_field=IntegerField(), + db_persist=True, + ) col = Square._meta.get_field("area").get_col("alias") self.assertIsInstance(col.output_field, IntegerField) @@ -74,7 +117,9 @@ def test_cached_col(self): class Sum(Model): a = IntegerField() b = IntegerField() - total = GeneratedField(expression=F("a") + F("b"), db_persist=True) + total = GeneratedField( + expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True + ) field = Sum._meta.get_field("total") cached_col = field.cached_col @@ -165,9 +210,9 @@ def test_output_field_lookups(self): with self.assertNumQueries(0), self.assertRaises(does_not_exist): self.base_model.objects.get(field__gte=overflow_value) - def test_output_field(self): + def test_output_field_db_collation(self): collation = connection.features.test_collations["virtual"] - m = self.output_field_model.objects.create(name="NAME") + m = self.output_field_db_collation_model.objects.create(name="NAME") field = m._meta.get_field("lower_name") db_parameters = field.db_parameters(connection) self.assertEqual(db_parameters["collation"], collation) @@ -178,7 +223,7 @@ def test_output_field(self): ) def test_db_type_parameters(self): - db_type_parameters = self.output_field_model._meta.get_field( + db_type_parameters = self.output_field_db_collation_model._meta.get_field( "lower_name" ).db_type_parameters(connection) self.assertEqual(db_type_parameters["max_length"], 11) @@ -202,7 +247,7 @@ def test_nullable(self): class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): base_model = GeneratedModel nullable_model = GeneratedModelNull - output_field_model = GeneratedModelOutputField + output_field_db_collation_model = GeneratedModelOutputFieldDbCollation params_model = GeneratedModelParams @@ -210,5 +255,5 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): base_model = GeneratedModelVirtual nullable_model = GeneratedModelNullVirtual - output_field_model = GeneratedModelOutputFieldVirtual + output_field_db_collation_model = GeneratedModelOutputFieldDbCollationVirtual params_model = GeneratedModelParamsVirtual diff --git a/tests/schema/tests.py b/tests/schema/tests.py index 72f90c934b73..46d16e9fdb6b 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -829,7 +829,11 @@ def test_add_binaryfield_mediumblob(self): def test_add_generated_field_with_kt_model(self): class GeneratedFieldKTModel(Model): data = JSONField() - status = GeneratedField(expression=KT("data__status"), db_persist=True) + status = GeneratedField( + expression=KT("data__status"), + output_field=TextField(), + db_persist=True, + ) class Meta: app_label = "schema" @@ -844,7 +848,7 @@ class Meta: @isolate_apps("schema") @skipUnlessDBFeature("supports_stored_generated_columns") - def test_add_generated_field_with_output_field(self): + def test_add_generated_field(self): class GeneratedFieldOutputFieldModel(Model): price = DecimalField(max_digits=7, decimal_places=2) vat_price = GeneratedField(