Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid recursion when a model instance has multiple deferred audit fields #31

Merged
merged 6 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions field_audit/field_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def wrapper(cls):
cls.__init__ = _decorate_init(cls.__init__)
cls.save = _decorate_db_write(cls.save)
cls.delete = _decorate_db_write(cls.delete)
cls.refresh_from_db = _decorate_refresh_from_db(cls.refresh_from_db)
_audited_models[cls] = get_fqcn(cls) if class_path is None else class_path # noqa: E501
return cls
if not field_names:
Expand Down Expand Up @@ -136,6 +137,23 @@ def wrapper(self, *args, **kw):
return wrapper


def _decorate_refresh_from_db(func):
"""Decorates the "refresh from db" method on Model subclasses. This is
necessary to ensure that all audited fields are included in the refresh
to avoid recursively calling the refresh for deferred fields.

:param func: the "refresh from db" method to decorate
"""
@wraps(func)
def wrapper(self, using=None, fields=None, **kwargs):
if fields is not None:
fields = set(fields) | set(AuditEvent.field_names(self))
func(self, using, fields, **kwargs)

from .models import AuditEvent
return wrapper


def get_audited_models():
return _audited_models.copy()

Expand Down
16 changes: 15 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import ANY, Mock, patch

import django
from django.apps import apps
from django.conf import settings
from django.db import connection, models, transaction
from django.db.models import Case, When, Value
Expand Down Expand Up @@ -33,7 +34,6 @@
MakeAuditEventFromValuesException,
)
from .mocks import NoopAtomicTransaction

from .models import (
Aerodrome,
Aircraft,
Expand Down Expand Up @@ -376,6 +376,11 @@ class FlyByTailNumber(models.Model):
to_field="tail_number",
)

def cleanup_flyby():
del apps.all_models["tests"]["flybytailnumber"]
apps.clear_cache()

self.addCleanup(cleanup_flyby)
flyby = FlyByTailNumber(aircraft=Aircraft(tail_number="CGXII"))
self.assertEqual("CGXII", AuditEvent.get_field_value(flyby, "aircraft"))

Expand Down Expand Up @@ -1216,6 +1221,15 @@ def test_delete_audit_action_ignore_calls_super(self):
queryset.delete(audit_action=AuditAction.IGNORE)
super_meth.assert_called()

def test_delete_does_not_cause_recursion_error(self):
aircraft = Aircraft.objects.create()
object_pk = aircraft.id
aircraft = Aircraft.objects.defer("tail_number", "make_model").get(id=object_pk) # noqa: E501
aircraft.delete()
event = AuditEvent.objects.last()
assert event.is_delete is True
assert event.object_pk == object_pk


class TestAuditEventBootstrapping(TestCase):

Expand Down