From 4d024559f873c2e673232ad0bf95ee6bb9396be3 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Thu, 11 Jan 2024 09:37:01 +0200 Subject: [PATCH 1/6] Reproduce max recursion depth reached error --- tests/models.py | 72 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_models.py | 14 +++++++++ 2 files changed, 86 insertions(+) diff --git a/tests/models.py b/tests/models.py index acf2060..d1dfddf 100644 --- a/tests/models.py +++ b/tests/models.py @@ -10,6 +10,7 @@ ForeignKey, IntegerField, JSONField, + ManyToManyField ) from field_audit import audit_fields @@ -109,3 +110,74 @@ class PkAuto(Model): @audit_fields("id") class PkJson(Model): id = JSONField(primary_key=True) + + +class PromptObjectManager(AuditingManager): + pass + + +class ExperimentObjectManager(AuditingManager): + pass + + +class SourceMaterialObjectManager(AuditingManager): + pass + + +class SafetyLayerObjectManager(AuditingManager): + pass + + +class ConsentFormObjectManager(AuditingManager): + pass + + +@audit_fields("name", audit_special_queryset_writes=True) +class Prompt(Model): + objects = PromptObjectManager() + name = CharField(max_length=50) + + +@audit_fields("name", audit_special_queryset_writes=True) +class SourceMaterial(Model): + objects = SourceMaterialObjectManager() + name = CharField(max_length=50) + + +@audit_fields("name", audit_special_queryset_writes=True) +class SafetyLayer(Model): + objects = SafetyLayerObjectManager() + name = CharField(max_length=50) + + +@audit_fields("name", audit_special_queryset_writes=True) +class ConsentForm(Model): + objects = ConsentFormObjectManager() + name = CharField(max_length=50) + + +@audit_fields( + "chatbot_prompt", + "safety_layers", + "tools_enabled", + "consent_form", + audit_special_queryset_writes=True +) +class Experiment(Model): + objects = ExperimentObjectManager() + chatbot_prompt = ForeignKey( + Prompt, + on_delete=CASCADE, + related_name="experiments" + ) + safety_layers = ManyToManyField( + SafetyLayer, + related_name="experiments", + blank=True + ) + tools_enabled = BooleanField(default=False) + consent_form = ForeignKey( + ConsentForm, + on_delete=CASCADE, + related_name="experiments" + ) diff --git a/tests/test_models.py b/tests/test_models.py index b14db60..7992d58 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -43,6 +43,9 @@ PkAuto, PkJson, SimpleModel, + Experiment, + Prompt, + ConsentForm ) from .test_field_audit import override_audited_models @@ -1216,6 +1219,17 @@ def test_delete_audit_action_ignore_calls_super(self): queryset.delete(audit_action=AuditAction.IGNORE) super_meth.assert_called() + def test_delete_does_not_cause_recursion_error(self): + experiment = Experiment.objects.create( + chatbot_prompt=Prompt.objects.create(name="Test"), + consent_form=ConsentForm.objects.create(name="Test"), + ) + object_pk = experiment.id + experiment.chatbot_prompt.delete() + event = AuditEvent.objects.last() + assert event.is_delete is True + assert event.object_pk == object_pk + class TestAuditEventBootstrapping(TestCase): From 27aaedd3a8ea8790309244819de4ec2cf7edd967 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Wed, 17 Jan 2024 12:37:25 +0200 Subject: [PATCH 2/6] simplify test --- tests/models.py | 27 --------------------------- tests/test_models.py | 7 +++---- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/tests/models.py b/tests/models.py index d1dfddf..2c5614e 100644 --- a/tests/models.py +++ b/tests/models.py @@ -120,47 +120,25 @@ class ExperimentObjectManager(AuditingManager): pass -class SourceMaterialObjectManager(AuditingManager): - pass - - class SafetyLayerObjectManager(AuditingManager): pass -class ConsentFormObjectManager(AuditingManager): - pass - - @audit_fields("name", audit_special_queryset_writes=True) class Prompt(Model): objects = PromptObjectManager() name = CharField(max_length=50) -@audit_fields("name", audit_special_queryset_writes=True) -class SourceMaterial(Model): - objects = SourceMaterialObjectManager() - name = CharField(max_length=50) - - @audit_fields("name", audit_special_queryset_writes=True) class SafetyLayer(Model): objects = SafetyLayerObjectManager() name = CharField(max_length=50) -@audit_fields("name", audit_special_queryset_writes=True) -class ConsentForm(Model): - objects = ConsentFormObjectManager() - name = CharField(max_length=50) - - @audit_fields( "chatbot_prompt", - "safety_layers", "tools_enabled", - "consent_form", audit_special_queryset_writes=True ) class Experiment(Model): @@ -176,8 +154,3 @@ class Experiment(Model): blank=True ) tools_enabled = BooleanField(default=False) - consent_form = ForeignKey( - ConsentForm, - on_delete=CASCADE, - related_name="experiments" - ) diff --git a/tests/test_models.py b/tests/test_models.py index 7992d58..f6bd6a3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -45,7 +45,6 @@ SimpleModel, Experiment, Prompt, - ConsentForm ) from .test_field_audit import override_audited_models @@ -1220,12 +1219,12 @@ def test_delete_audit_action_ignore_calls_super(self): super_meth.assert_called() def test_delete_does_not_cause_recursion_error(self): + prompt = Prompt.objects.create(name="Test") experiment = Experiment.objects.create( - chatbot_prompt=Prompt.objects.create(name="Test"), - consent_form=ConsentForm.objects.create(name="Test"), + chatbot_prompt=prompt, ) object_pk = experiment.id - experiment.chatbot_prompt.delete() + prompt.delete() event = AuditEvent.objects.last() assert event.is_delete is True assert event.object_pk == object_pk From 53fab9e2038da9088a31ac823bbdb02e4f2b974c Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Wed, 17 Jan 2024 13:30:41 +0200 Subject: [PATCH 3/6] further simplify test --- tests/models.py | 44 -------------------------------------------- tests/test_models.py | 12 ++++-------- 2 files changed, 4 insertions(+), 52 deletions(-) diff --git a/tests/models.py b/tests/models.py index 2c5614e..eb0d348 100644 --- a/tests/models.py +++ b/tests/models.py @@ -110,47 +110,3 @@ class PkAuto(Model): @audit_fields("id") class PkJson(Model): id = JSONField(primary_key=True) - - -class PromptObjectManager(AuditingManager): - pass - - -class ExperimentObjectManager(AuditingManager): - pass - - -class SafetyLayerObjectManager(AuditingManager): - pass - - -@audit_fields("name", audit_special_queryset_writes=True) -class Prompt(Model): - objects = PromptObjectManager() - name = CharField(max_length=50) - - -@audit_fields("name", audit_special_queryset_writes=True) -class SafetyLayer(Model): - objects = SafetyLayerObjectManager() - name = CharField(max_length=50) - - -@audit_fields( - "chatbot_prompt", - "tools_enabled", - audit_special_queryset_writes=True -) -class Experiment(Model): - objects = ExperimentObjectManager() - chatbot_prompt = ForeignKey( - Prompt, - on_delete=CASCADE, - related_name="experiments" - ) - safety_layers = ManyToManyField( - SafetyLayer, - related_name="experiments", - blank=True - ) - tools_enabled = BooleanField(default=False) diff --git a/tests/test_models.py b/tests/test_models.py index f6bd6a3..8dd4eca 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -43,8 +43,6 @@ PkAuto, PkJson, SimpleModel, - Experiment, - Prompt, ) from .test_field_audit import override_audited_models @@ -1219,12 +1217,10 @@ def test_delete_audit_action_ignore_calls_super(self): super_meth.assert_called() def test_delete_does_not_cause_recursion_error(self): - prompt = Prompt.objects.create(name="Test") - experiment = Experiment.objects.create( - chatbot_prompt=prompt, - ) - object_pk = experiment.id - prompt.delete() + aircraft = Aircraft.objects.create() + object_pk = aircraft.id + aircraft = Aircraft.objects.defer("tail_number", "make_model").get(id=object_pk) + aircraft.delete() event = AuditEvent.objects.last() assert event.is_delete is True assert event.object_pk == object_pk From 64fc8647009bbffcdcf8e4c30233b9bd6c29ac0c Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Wed, 17 Jan 2024 13:33:16 +0200 Subject: [PATCH 4/6] decorate refresh from DB to prevent recursive call Deferred fields are loaded by calling 'refresh_from_db' which in turn calls the model init method. During init, we attempt to attach initial values for audited fields. If any of the audited fields are deferred, 'refresh_from_db' is called again. --- field_audit/field_audit.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/field_audit/field_audit.py b/field_audit/field_audit.py index f9585c7..b649335 100644 --- a/field_audit/field_audit.py +++ b/field_audit/field_audit.py @@ -62,6 +62,7 @@ def wrapper(cls): cls.__init__ = _decorate_init(cls.__init__) cls.save = _decorate_db_write(cls.save) cls.delete = _decorate_db_write(cls.delete) + cls.refresh_from_db = _decorate_refresh_from_db(cls.refresh_from_db) _audited_models[cls] = get_fqcn(cls) if class_path is None else class_path # noqa: E501 return cls if not field_names: @@ -136,6 +137,23 @@ def wrapper(self, *args, **kw): return wrapper +def _decorate_refresh_from_db(func): + """Decorates the "refresh from db" method on Model subclasses. This is necessary to ensure + that all audited fields are included in the refresh to avoid recursively calling the refresh + for deferred fields. + + :param func: the "refresh from db" method to decorate + """ + @wraps(func) + def wrapper(self, using=None, fields=None, **kwargs): + if fields is not None: + fields = set(fields) | set(AuditEvent.field_names(self)) + func(self, using, fields, **kwargs) + + from .models import AuditEvent + return wrapper + + def get_audited_models(): return _audited_models.copy() From 2e53dc5771822c185d36386e4a3b2174a475a661 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Wed, 17 Jan 2024 13:48:21 +0200 Subject: [PATCH 5/6] cleanup test model --- tests/test_models.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 8dd4eca..9b194ff 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,7 @@ from unittest.mock import ANY, Mock, patch import django +from django.apps import apps from django.conf import settings from django.db import connection, models, transaction from django.db.models import Case, When, Value @@ -376,6 +377,11 @@ class FlyByTailNumber(models.Model): to_field="tail_number", ) + def cleanup_flyby(): + del apps.all_models["tests"]["flybytailnumber"] + apps.clear_cache() + + self.addCleanup(cleanup_flyby) flyby = FlyByTailNumber(aircraft=Aircraft(tail_number="CGXII")) self.assertEqual("CGXII", AuditEvent.get_field_value(flyby, "aircraft")) @@ -1224,7 +1230,7 @@ def test_delete_does_not_cause_recursion_error(self): event = AuditEvent.objects.last() assert event.is_delete is True assert event.object_pk == object_pk - + class TestAuditEventBootstrapping(TestCase): From 32ee1b11c00660d796cbe046f428e721a9adb170 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Wed, 17 Jan 2024 14:11:41 +0200 Subject: [PATCH 6/6] lint --- field_audit/field_audit.py | 6 +++--- tests/models.py | 1 - tests/test_models.py | 3 +-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/field_audit/field_audit.py b/field_audit/field_audit.py index b649335..7d011a5 100644 --- a/field_audit/field_audit.py +++ b/field_audit/field_audit.py @@ -138,9 +138,9 @@ def wrapper(self, *args, **kw): def _decorate_refresh_from_db(func): - """Decorates the "refresh from db" method on Model subclasses. This is necessary to ensure - that all audited fields are included in the refresh to avoid recursively calling the refresh - for deferred fields. + """Decorates the "refresh from db" method on Model subclasses. This is + necessary to ensure that all audited fields are included in the refresh + to avoid recursively calling the refresh for deferred fields. :param func: the "refresh from db" method to decorate """ diff --git a/tests/models.py b/tests/models.py index eb0d348..acf2060 100644 --- a/tests/models.py +++ b/tests/models.py @@ -10,7 +10,6 @@ ForeignKey, IntegerField, JSONField, - ManyToManyField ) from field_audit import audit_fields diff --git a/tests/test_models.py b/tests/test_models.py index 9b194ff..870d7af 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -34,7 +34,6 @@ MakeAuditEventFromValuesException, ) from .mocks import NoopAtomicTransaction - from .models import ( Aerodrome, Aircraft, @@ -1225,7 +1224,7 @@ def test_delete_audit_action_ignore_calls_super(self): def test_delete_does_not_cause_recursion_error(self): aircraft = Aircraft.objects.create() object_pk = aircraft.id - aircraft = Aircraft.objects.defer("tail_number", "make_model").get(id=object_pk) + aircraft = Aircraft.objects.defer("tail_number", "make_model").get(id=object_pk) # noqa: E501 aircraft.delete() event = AuditEvent.objects.last() assert event.is_delete is True