From 3ee9603ad566aa8c3aa210c268d7e886176f8bae Mon Sep 17 00:00:00 2001 From: Matthias Schoettle Date: Sat, 14 Sep 2024 00:50:30 +0000 Subject: [PATCH] Properly handle the post delete signal in transactions If the model changes are committed in a transaction, at the time when the post delete signal is handled instance.pk returns None. Store the instance's ID beforehand so that it is available when creating the CRUDEvent. --- easyaudit/signals/crud_flows.py | 10 +++++++--- easyaudit/signals/model_signals.py | 3 +++ tests/test_main.py | 23 +++++++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/easyaudit/signals/crud_flows.py b/easyaudit/signals/crud_flows.py index a87627df..84eeabb3 100644 --- a/easyaudit/signals/crud_flows.py +++ b/easyaudit/signals/crud_flows.py @@ -35,7 +35,7 @@ def get_current_user_details(): return user_id, user_pk_as_string -def log_event(event_type, instance, object_json_repr, **kwargs): +def log_event(event_type, instance, object_id, object_json_repr, **kwargs): user_id, user_pk_as_string = get_current_user_details() with transaction.atomic(using=DATABASE_ALIAS): audit_logger.crud( @@ -43,7 +43,7 @@ def log_event(event_type, instance, object_json_repr, **kwargs): "content_type_id": ContentType.objects.get_for_model(instance).id, "datetime": timezone.now(), "event_type": event_type, - "object_id": instance.pk, + "object_id": object_id, "object_json_repr": object_json_repr or "", "object_repr": str(instance), "user_id": user_id, @@ -70,6 +70,7 @@ def pre_save_crud_flow(instance, object_json_repr, changed_fields): log_event( CRUDEvent.UPDATE, instance, + instance.pk, object_json_repr, changed_fields=changed_fields, ) @@ -82,6 +83,7 @@ def post_save_crud_flow(instance, object_json_repr): log_event( CRUDEvent.CREATE, instance, + instance.pk, object_json_repr, ) except Exception: @@ -102,6 +104,7 @@ def m2m_changed_crud_flow( # noqa: PLR0913 log_event( event_type, instance, + instance.pk, object_json_repr, changed_fields=changed_fields, ) @@ -109,11 +112,12 @@ def m2m_changed_crud_flow( # noqa: PLR0913 handle_flow_exception(instance, "pre_save") -def post_delete_crud_flow(instance, object_json_repr): +def post_delete_crud_flow(instance, object_id, object_json_repr): try: log_event( CRUDEvent.DELETE, instance, + object_id, object_json_repr, ) diff --git a/easyaudit/signals/model_signals.py b/easyaudit/signals/model_signals.py index 132aa7bc..175f72dc 100644 --- a/easyaudit/signals/model_signals.py +++ b/easyaudit/signals/model_signals.py @@ -241,10 +241,13 @@ def post_delete(sender, instance, using, **kwargs): with transaction.atomic(using=using): object_json_repr = serializers.serialize("json", [instance]) + # instance.pk returns None if the changes are performed within a transaction + object_id = instance.pk crud_flow = partial( post_delete_crud_flow, instance=instance, + object_id=object_id, object_json_repr=object_json_repr, ) if getattr(settings, "TEST", False): diff --git a/tests/test_main.py b/tests/test_main.py index b938fd2d..2b2e5724 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,6 +4,7 @@ from asgiref.sync import sync_to_async from django.contrib.contenttypes.models import ContentType from django.core import management +from django.db import transaction from django.urls import reverse from django.utils.version import get_version from pytest_django.asserts import assertInHTML @@ -189,6 +190,28 @@ def test_delete(self, model): ) assert crud_event_qs.count() == 2 + @pytest.mark.django_db(transaction=True) + def test_delete_transaction(self, model, settings): + settings.TEST = False + + with transaction.atomic(): + obj = model.objects.create() + model.objects.all().delete() + + crud_event_qs = CRUDEvent.objects.filter( + object_id=obj.id, + content_type=ContentType.objects.get_for_model(obj), + event_type=CRUDEvent.CREATE, + ) + assert crud_event_qs.count() == 1 + + crud_event_qs = CRUDEvent.objects.filter( + object_id=obj.id, + content_type=ContentType.objects.get_for_model(obj), + event_type=CRUDEvent.DELETE, + ) + assert crud_event_qs.count() == 1 + @pytest.mark.usefixtures("_audit_logger") def test_propagate_exceptions(self, model, settings): with pytest.raises(ValueError, match="Test exception"):