diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 61e16a97..4051fa9a 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -51,7 +51,7 @@ if BAKER_CONTENTTYPES: from django.contrib.contenttypes import models as contenttypes_models - from django.contrib.contenttypes.fields import GenericRelation + from django.contrib.contenttypes.fields import GenericRelation, GenericForeignKey else: contenttypes_models = None GenericRelation = None @@ -622,7 +622,7 @@ def _skip_field(self, field: Field) -> bool: # noqa: C901 if field.name not in self.model_attrs: # noqa: SIM102 if field.name not in self.rel_fields and ( - field.null and not field.fill_optional + not field.fill_optional and field.null ): return True @@ -702,12 +702,16 @@ def generate_value(self, field: Field, commit: bool = True) -> Any: # noqa: C90 model. """ is_content_type_fk = False + is_generic_fk = False if BAKER_CONTENTTYPES: is_content_type_fk = isinstance(field, ForeignKey) and issubclass( self._remote_field(field).model, contenttypes_models.ContentType ) + is_generic_fk = isinstance(field, GenericForeignKey) + if is_generic_fk: + generator = self.type_mapping[GenericForeignKey] # we only use default unless the field is overwritten in `self.rel_fields` - if field.has_default() and field.name not in self.rel_fields: + elif field.has_default() and field.name not in self.rel_fields: if callable(field.default): return field.default() return field.default diff --git a/model_bakery/content_types.py b/model_bakery/content_types.py index d0d0c880..d530d393 100644 --- a/model_bakery/content_types.py +++ b/model_bakery/content_types.py @@ -8,7 +8,10 @@ if BAKER_CONTENTTYPES: from django.contrib.contenttypes.models import ContentType + from django.contrib.contenttypes.fields import GenericForeignKey from . import random_gen default_contenttypes_mapping[ContentType] = random_gen.gen_content_type + # a small hack to generate random object for GenericForeignKey + default_contenttypes_mapping[GenericForeignKey] = random_gen.gen_content_type diff --git a/tests/test_filling_fields.py b/tests/test_filling_fields.py index 6f25f638..d4297a0c 100644 --- a/tests/test_filling_fields.py +++ b/tests/test_filling_fields.py @@ -310,6 +310,13 @@ def test_iteratively_filling_generic_foreign_key_field(self): assert dummies[1].content_type == expected_content_type assert dummies[1].object_id == objects[1].pk + def test_with_fill_optional(self): + from django.contrib.contenttypes.models import ContentType + + dummy = baker.make(models.DummyGenericForeignKeyModel, _fill_optional=True) + assert isinstance(dummy.content_type, ContentType) + assert dummy.content_type.model_class() is not None + @pytest.mark.django_db class TestFillingForeignKeyFieldWithDefaultFunctionReturningId: