Skip to content
This repository has been archived by the owner on Jun 23, 2020. It is now read-only.

WIP: Add default many-to-many support #1

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
32 changes: 32 additions & 0 deletions simple_history/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,35 @@ def bulk_history_create(self, objs, batch_size=None):
return self.model.objects.bulk_create(
historical_instances, batch_size=batch_size
)

bmedx marked this conversation as resolved.
Show resolved Hide resolved

class M2MHistoryDescriptor(object):
def __init__(self, model):
self.model = model

def __get__(self, instance, owner):
if instance is None:
return M2MHistoryManager(self.model)
return M2MHistoryManager(self.model, instance)


class M2MHistoryManager(models.Manager):
def __init__(self, model, instance=None):
super(M2MHistoryManager, self).__init__()
self.model = model
self.instance = instance

def get_super_queryset(self):
return super(M2MHistoryManager, self).get_queryset()

def get_queryset(self):
qs = self.get_super_queryset()

if self.instance is None:
return qs

if isinstance(self.instance._meta.pk, models.ForeignKey):
key_name = self.instance._meta.pk.name + "_id"
else:
key_name = self.instance._meta.pk.name
return self.get_super_queryset().filter(**{key_name: self.instance.pk})
108 changes: 107 additions & 1 deletion simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
import uuid
import warnings
from functools import partial

import django
import six
Expand All @@ -16,13 +17,14 @@
from django.db import models
from django.db.models import Q
from django.db.models.fields.proxy import OrderWrt
from django.db.models.signals import m2m_changed
from django.forms.models import model_to_dict
from django.urls import reverse
from django.utils.text import format_lazy
from django.utils.timezone import now
from simple_history import utils
from . import exceptions
from .manager import HistoryDescriptor
from .manager import HistoryDescriptor, M2MHistoryDescriptor
bmedx marked this conversation as resolved.
Show resolved Hide resolved
from .signals import post_create_historical_record, pre_create_historical_record

if django.VERSION < (2,):
Expand Down Expand Up @@ -60,6 +62,7 @@ def _history_user_setter(historical_instance, user):

class HistoricalRecords(object):
thread = threading.local()
m2m_models = {}

def __init__(
self,
Expand All @@ -81,6 +84,7 @@ def __init__(
history_user_setter=_history_user_setter,
related_name=None,
use_base_model_db=False,
m2m_fields=(),
):
self.user_set_verbose_name = verbose_name
self.user_related_name = user_related_name
Expand All @@ -98,6 +102,7 @@ def __init__(
self.user_setter = history_user_setter
self.related_name = related_name
self.use_base_model_db = use_base_model_db
self.m2m_fields = m2m_fields

if excluded_fields is None:
excluded_fields = []
Expand Down Expand Up @@ -153,6 +158,7 @@ def finalize(self, sender, **kwargs):
)
)
history_model = self.create_history_model(sender, inherited)

if inherited:
# Make sure history model is in same module as concrete model
module = importlib.import_module(history_model.__module__)
Expand All @@ -164,11 +170,28 @@ def finalize(self, sender, **kwargs):
# so the signal handlers can't use weak references.
models.signals.post_save.connect(self.post_save, sender=sender, weak=False)
models.signals.post_delete.connect(self.post_delete, sender=sender, weak=False)
for field in self.m2m_fields:
m2m_changed.connect(partial(self.m2m_changed, attr=field.name),
sender=field.remote_field.through, weak=False)

descriptor = HistoryDescriptor(history_model)
setattr(sender, self.manager_name, descriptor)
sender._meta.simple_history_manager_attribute = self.manager_name

for field in self.m2m_fields:
m2m_model = self.create_history_m2m_model(
history_model,
field.remote_field.through
)
self.m2m_models[field] = m2m_model

module = importlib.import_module(self.module)
setattr(module, m2m_model.__name__, m2m_model)

m2m_descriptor = M2MHistoryDescriptor(m2m_model)
setattr(sender, "historical_{}".format(field.name), m2m_descriptor)
setattr(history_model, "historical_{}".format(field.name), m2m_descriptor)

def get_history_model_name(self, model):
if not self.custom_model_name:
return "Historical{}".format(model._meta.object_name)
Expand All @@ -191,6 +214,59 @@ def get_history_model_name(self, model):
)
)

def create_history_m2m_model(self, model, through_model):
attrs = {
"__module__": self.module,
}

app_module = "%s.models" % model._meta.app_label

if model.__module__ != self.module:
# registered under different app
attrs["__module__"] = self.module
elif app_module != self.module:
# Abuse an internal API because the app registry is loading.
app = apps.app_configs[model._meta.app_label]
models_module = app.name
attrs["__module__"] = models_module

# Get the primary key to the history model this model will look up to
attrs["m2m_history_id"] = self._get_history_id_field()
attrs["history"] = models.ForeignKey(
model,
db_constraint=False,
on_delete=models.DO_NOTHING,
)

"""for field in through_model._meta.fields:
f = copy.copy(field)
if isinstance(f, models.ForeignKey):
f.__class__ = models.BigIntegerField
attrs[f.name + '_id'] = f
"""
fields = self.copy_fields(through_model)
attrs.update(fields)

# Set as the default then check for overrides
name = self.get_history_model_name(through_model)

registered_models[through_model._meta.db_table] = through_model

meta_fields = {"verbose_name": name}

if self.app:
meta_fields["app_label"] = self.app

attrs.update(Meta=type(str("Meta"), (), meta_fields))

history_model = type(str(name), (models.Model,), attrs)

return (
python_2_unicode_compatible(history_model)
if django.VERSION < (2,)
else history_model
)

def create_history_model(self, model, inherited):
"""
Creates a historical model to associate with the model provided.
Expand Down Expand Up @@ -363,6 +439,12 @@ def _get_history_related_field(self, model):
else:
return {}

def _get_many_to_many_fields(self):
fields = {}
for field in self.m2m_fields:
fields[field.name] = models.ManyToManyField(self.m2m_models[field])
return fields

def get_extra_fields(self, model, fields):
"""Return dict of extra fields added to the historical record model"""

Expand Down Expand Up @@ -441,6 +523,7 @@ def get_prev_record(self):

extra_fields.update(self._get_history_related_field(model))
extra_fields.update(self._get_history_user_fields())
# extra_fields.update(self._get_many_to_many_fields())

return extra_fields

Expand Down Expand Up @@ -475,6 +558,28 @@ def post_delete(self, instance, using=None, **kwargs):
else:
self.create_historical_record(instance, "-", using=using)

def m2m_changed(self, instance, action, attr, pk_set, reverse, **_):
if action in ('post_add', 'post_remove', 'post_clear'):
self.create_historical_record(instance, "~")

def create_historical_record_m2ms(self, history_instance, instance):
for field in self.m2m_fields:
m2m_history_model = self.m2m_models[field]
original_instance = history_instance.instance
through_model = getattr(original_instance, field.name).through

rows = through_model.objects.all()

for row in rows:
insert_row = {'history': history_instance}

for through_model_field in through_model._meta.fields:
insert_row[through_model_field.name] = getattr(
row, through_model_field.name
)

m2m_history_model.objects.create(**insert_row)

def create_historical_record(self, instance, history_type, using=None):
using = using if self.use_base_model_db else None
history_date = getattr(instance, "_history_date", now())
Expand Down Expand Up @@ -509,6 +614,7 @@ def create_historical_record(self, instance, history_type, using=None):
)

history_instance.save(using=using)
self.create_historical_record_m2ms(history_instance, instance)

post_create_historical_record.send(
sender=manager.model,
Expand Down
11 changes: 11 additions & 0 deletions simple_history/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ def get_absolute_url(self):
return reverse("poll-detail", kwargs={"pk": self.pk})


class PollWithManyToMany(models.Model):
question = models.CharField(max_length=200)
pub_date = models.DateTimeField("date published")
places = models.ManyToManyField("Place")

history = HistoricalRecords(m2m_fields=[places])

def get_absolute_url(self):
return reverse("poll-detail", kwargs={"pk": self.pk})


class CustomAttrNameForeignKey(models.ForeignKey):
def __init__(self, *args, **kwargs):
self.attr_name = kwargs.pop("attr_name", None)
Expand Down
Loading