Skip to content

Commit

Permalink
Fix #488 -- Align GFK and content type fields
Browse files Browse the repository at this point in the history
  • Loading branch information
amureki committed Aug 16, 2024
1 parent d758c7a commit 9339228
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
26 changes: 22 additions & 4 deletions model_bakery/baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,11 @@ def instance(
auto_now_keys[k] = attrs[k]

if BAKER_CONTENTTYPES and isinstance(field, GenericForeignKey):
generic_foreign_keys[k] = attrs.pop(k)
generic_foreign_keys[k] = {
"value": attrs.pop(k),
"content_type_field": field.ct_field,
"object_id_field": field.fk_field,
}

instance = self.model(**attrs)
if _commit:
Expand Down Expand Up @@ -603,7 +607,7 @@ def _skip_field(self, field: Field) -> bool: # noqa: C901
]

if BAKER_CONTENTTYPES:
other_fields_to_skip.append(GenericRelation)
other_fields_to_skip.extend([GenericRelation, GenericForeignKey])

if isinstance(field, tuple(other_fields_to_skip)):
return True
Expand Down Expand Up @@ -687,8 +691,22 @@ def _handle_m2m(self, instance: Model):
make(through_model, _using=self._using, **base_kwargs)

def _handle_generic_foreign_keys(self, instance: Model, attrs: Dict[str, Any]):
for key, value in attrs.items():
setattr(instance, key, value)
"""Set content type and object id for GenericForeignKey fields."""
for _field_name, data in attrs.items():
content_type_field = data["content_type_field"]
object_id_field = data["object_id_field"]
value = data["value"]
# if the value is iterable, we assume it's a list of objects
# so we should take the next object from the iterator
if isinstance(value, collections.abc.Iterable):
value = next(value)

setattr(
instance,
content_type_field,
contenttypes_models.ContentType.objects.get_for_model(value),
)
setattr(instance, object_id_field, value.pk)

def _remote_field(
self, field: Union[ForeignKey, OneToOneField]
Expand Down
3 changes: 1 addition & 2 deletions model_bakery/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
__all__ = ["BAKER_CONTENTTYPES", "default_contenttypes_mapping"]

if BAKER_CONTENTTYPES:
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType

from . import random_gen

default_contenttypes_mapping[ContentType] = random_gen.gen_content_type
# a small hack to generate random object for GenericForeignKey
default_contenttypes_mapping[GenericForeignKey] = random_gen.gen_content_type
# default_contenttypes_mapping[GenericForeignKey] = random_gen.gen_content_type
10 changes: 10 additions & 0 deletions tests/test_filling_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ def test_filling_content_type_field(self):
assert isinstance(dummy.content_type, ContentType)
assert dummy.content_type.model_class() is not None

def test_filling_from_content_object(self):
from django.contrib.contenttypes.models import ContentType

dummy = baker.make(
models.DummyGenericForeignKeyModel,
content_object=baker.make(models.Profile),
)
assert dummy.content_type == ContentType.objects.get_for_model(models.Profile)
assert dummy.object_id == models.Profile.objects.get().pk

def test_iteratively_filling_generic_foreign_key_field(self):
"""
Ensures private_fields are included in ``Baker.get_fields()``.
Expand Down

0 comments on commit 9339228

Please sign in to comment.