diff --git a/docs/models.rst b/docs/models.rst index a2aac4d1..31707f4d 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -107,5 +107,22 @@ Also you can override the default uuid version. Versions 1,3,4 and 5 are now sup pass - .. _`UUIDField`: https://github.com/jazzband/django-model-utils/blob/master/docs/fields.rst#uuidfield + + +SaveSignalHandlingModel +----------------------- + +An abstract base class model to pass a parameter ``signals_to_disable`` +to ``save`` method in order to disable signals + +.. code-block:: python + + from model_utils.models import SaveSignalHandlingModel + + class SaveSignalTestModel(SaveSignalHandlingModel): + name = models.CharField(max_length=20) + + obj = SaveSignalTestModel(name='Test') + # Note: If you use `Model.objects.create`, the signals can't be disabled + obj.save(signals_to_disable=['pre_save'] # disable `pre_save` signal diff --git a/model_utils/models.py b/model_utils/models.py index 96c472b5..43c34dd1 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -2,7 +2,8 @@ import django from django.core.exceptions import ImproperlyConfigured -from django.db import models +from django.db import models, transaction, router +from django.db.models.signals import post_save, pre_save from django.utils.translation import ugettext_lazy as _ from model_utils.fields import ( @@ -159,3 +160,60 @@ class UUIDModel(models.Model): class Meta: abstract = True + + +class SaveSignalHandlingModel(models.Model): + """ + An abstract base class model to pass a parameter ``signals_to_disable`` + to ``save`` method in order to disable signals + """ + class Meta: + abstract = True + + def save(self, signals_to_disable=None, *args, **kwargs): + """ + Add an extra parameters to hold which signals to disable + If empty, nothing will change + """ + + self.signals_to_disable = signals_to_disable or [] + + super(SaveSignalHandlingModel, self).save(*args, **kwargs) + + def save_base(self, raw=False, force_insert=False, + force_update=False, using=None, update_fields=None): + """ + Copied from base class for a minor change. + This is an ugly overwriting but since Django's ``save_base`` method + does not differ between versions 1.8 and 1.10, + that way of implementing wouldn't harm the flow + """ + using = using or router.db_for_write(self.__class__, instance=self) + assert not (force_insert and (force_update or update_fields)) + assert update_fields is None or len(update_fields) > 0 + cls = origin = self.__class__ + + if cls._meta.proxy: + cls = cls._meta.concrete_model + meta = cls._meta + if not meta.auto_created and not 'pre_save' in self.signals_to_disable: + pre_save.send( + sender=origin, instance=self, raw=raw, using=using, + update_fields=update_fields, + ) + with transaction.atomic(using=using, savepoint=False): + if not raw: + self._save_parents(cls, using, update_fields) + updated = self._save_table(raw, cls, force_insert, force_update, using, update_fields) + + self._state.db = using + self._state.adding = False + + if not meta.auto_created and not 'post_save' in self.signals_to_disable: + post_save.send( + sender=origin, instance=self, created=(not updated), + update_fields=update_fields, raw=raw, using=using, + ) + + # Empty the signals in case it might be used somewhere else in future + self.signals_to_disable = [] diff --git a/tests/models.py b/tests/models.py index 6a1f822b..40e3ff69 100644 --- a/tests/models.py +++ b/tests/models.py @@ -4,6 +4,7 @@ from django.db import models from django.db.models.query_utils import DeferredAttribute from django.db.models import Manager +from django.dispatch import receiver from django.utils.encoding import python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ @@ -25,6 +26,7 @@ TimeFramedModel, TimeStampedModel, UUIDModel, + SaveSignalHandlingModel, ) from tests.fields import MutableField from tests.managers import CustomSoftDeleteManager @@ -437,3 +439,7 @@ class CustomUUIDModel(UUIDModel): class CustomNotPrimaryUUIDModel(models.Model): uuid = UUIDField(primary_key=False) + + +class SaveSignalHandlingTestModel(SaveSignalHandlingModel): + name = models.CharField(max_length=20) diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 00000000..f75efbec --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,5 @@ +def pre_save_test(instance, *args, **kwargs): + instance.pre_save_runned = True + +def post_save_test(instance, created, *args, **kwargs): + instance.post_save_runned = True diff --git a/tests/test_models/test_savesignalhandling_model.py b/tests/test_models/test_savesignalhandling_model.py new file mode 100644 index 00000000..e281f78e --- /dev/null +++ b/tests/test_models/test_savesignalhandling_model.py @@ -0,0 +1,45 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from tests.models import SaveSignalHandlingTestModel +from tests.signals import pre_save_test, post_save_test +from django.db.models.signals import pre_save, post_save + + +class SaveSignalHandlingModelTests(TestCase): + + def test_pre_save(self): + pre_save.connect(pre_save_test, sender=SaveSignalHandlingTestModel) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'pre_save_runned') + obj.name = 'Test A' + obj.save() + self.assertEqual(obj.name, 'Test A') + self.assertTrue(hasattr(obj, 'pre_save_runned')) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'pre_save_runned') + obj.name = 'Test B' + obj.save(signals_to_disable=['pre_save']) + self.assertEqual(obj.name, 'Test B') + self.assertFalse(hasattr(obj, 'pre_save_runned')) + + + def test_post_save(self): + post_save.connect(post_save_test, sender=SaveSignalHandlingTestModel) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'post_save_runned') + obj.name = 'Test A' + obj.save() + self.assertEqual(obj.name, 'Test A') + self.assertTrue(hasattr(obj, 'post_save_runned')) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'post_save_runned') + obj.name = 'Test B' + obj.save(signals_to_disable=['post_save']) + self.assertEqual(obj.name, 'Test B') + self.assertFalse(hasattr(obj, 'post_save_runned'))