From 9339228c642b8a917046c68d7ebfabbfb03b6f09 Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Fri, 16 Aug 2024 14:28:38 +0200 Subject: [PATCH] Fix #488 -- Align GFK and content type fields --- model_bakery/baker.py | 26 ++++++++++++++++++++++---- model_bakery/content_types.py | 3 +-- tests/test_filling_fields.py | 10 ++++++++++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 7dd24da7..ccf35b26 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -519,7 +519,11 @@ def instance( auto_now_keys[k] = attrs[k] if BAKER_CONTENTTYPES and isinstance(field, GenericForeignKey): - generic_foreign_keys[k] = attrs.pop(k) + generic_foreign_keys[k] = { + "value": attrs.pop(k), + "content_type_field": field.ct_field, + "object_id_field": field.fk_field, + } instance = self.model(**attrs) if _commit: @@ -603,7 +607,7 @@ def _skip_field(self, field: Field) -> bool: # noqa: C901 ] if BAKER_CONTENTTYPES: - other_fields_to_skip.append(GenericRelation) + other_fields_to_skip.extend([GenericRelation, GenericForeignKey]) if isinstance(field, tuple(other_fields_to_skip)): return True @@ -687,8 +691,22 @@ def _handle_m2m(self, instance: Model): make(through_model, _using=self._using, **base_kwargs) def _handle_generic_foreign_keys(self, instance: Model, attrs: Dict[str, Any]): - for key, value in attrs.items(): - setattr(instance, key, value) + """Set content type and object id for GenericForeignKey fields.""" + for _field_name, data in attrs.items(): + content_type_field = data["content_type_field"] + object_id_field = data["object_id_field"] + value = data["value"] + # if the value is iterable, we assume it's a list of objects + # so we should take the next object from the iterator + if isinstance(value, collections.abc.Iterable): + value = next(value) + + setattr( + instance, + content_type_field, + contenttypes_models.ContentType.objects.get_for_model(value), + ) + setattr(instance, object_id_field, value.pk) def _remote_field( self, field: Union[ForeignKey, OneToOneField] diff --git a/model_bakery/content_types.py b/model_bakery/content_types.py index aced00be..0d6e49ba 100644 --- a/model_bakery/content_types.py +++ b/model_bakery/content_types.py @@ -7,11 +7,10 @@ __all__ = ["BAKER_CONTENTTYPES", "default_contenttypes_mapping"] if BAKER_CONTENTTYPES: - from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from . import random_gen default_contenttypes_mapping[ContentType] = random_gen.gen_content_type # a small hack to generate random object for GenericForeignKey - default_contenttypes_mapping[GenericForeignKey] = random_gen.gen_content_type + # default_contenttypes_mapping[GenericForeignKey] = random_gen.gen_content_type diff --git a/tests/test_filling_fields.py b/tests/test_filling_fields.py index d4297a0c..37c38a56 100644 --- a/tests/test_filling_fields.py +++ b/tests/test_filling_fields.py @@ -284,6 +284,16 @@ def test_filling_content_type_field(self): assert isinstance(dummy.content_type, ContentType) assert dummy.content_type.model_class() is not None + def test_filling_from_content_object(self): + from django.contrib.contenttypes.models import ContentType + + dummy = baker.make( + models.DummyGenericForeignKeyModel, + content_object=baker.make(models.Profile), + ) + assert dummy.content_type == ContentType.objects.get_for_model(models.Profile) + assert dummy.object_id == models.Profile.objects.get().pk + def test_iteratively_filling_generic_foreign_key_field(self): """ Ensures private_fields are included in ``Baker.get_fields()``.