Skip to content

Commit

Permalink
Add type argument to DescriptorWrapper
Browse files Browse the repository at this point in the history
This preserves the type of the wrapped descriptor (usually a field).

Maybe this is overkill, as `DescriptorWrapper` seems to only be used
as part of the `FieldTracker` implementation and is not documented
and barely tested. But technically, it is public API.
  • Loading branch information
mthuurne committed May 7, 2024
1 parent 3c3787b commit aaa645b
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions model_utils/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Protocol,
TypeVar,
cast,
overload,
)

from django.core.exceptions import FieldError
from django.db import models
Expand All @@ -16,10 +25,26 @@ class _AugmentedModel(models.Model):
_instance_initialized: bool
_deferred_fields: set[str]


T = TypeVar("T")


class Descriptor(Protocol[T]):
def __get__(self, instance: object, owner: type[object]) -> T:
...

def __set__(self, instance: object, value: T) -> None:
...


class FullDescriptor(Descriptor[T]):
def __delete__(self, instance: object) -> None:
...


DescriptorT = TypeVar("DescriptorT", bound=Descriptor[Any])
FullDescriptorT = TypeVar("FullDescriptorT", bound=FullDescriptor[Any])


class LightStateFieldFile(FieldFile):
"""
FieldFile subclass with the only aim to remove the instance from the state.
Expand Down Expand Up @@ -53,22 +78,22 @@ def lightweight_deepcopy(value: T) -> T:
return deepcopy(value)


class DescriptorWrapper:
class DescriptorWrapper(Generic[DescriptorT]):

def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str):
def __init__(self, field_name: str, descriptor: DescriptorT, tracker_attname: str):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname

@overload
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper:
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[DescriptorT]:
...

@overload
def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field:
def __get__(self, instance: models.Model, owner: type[models.Model]) -> DescriptorT:
...

def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field:
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[DescriptorT] | DescriptorT:
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
Expand All @@ -81,7 +106,7 @@ def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> D
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
return value

def __set__(self, instance: models.Model, value: models.Field) -> None:
def __set__(self, instance: models.Model, value: DescriptorT) -> None:
initialized = hasattr(instance, '_instance_initialized')
was_deferred = self.field_name in instance.get_deferred_fields()

Expand All @@ -104,23 +129,23 @@ def __set__(self, instance: models.Model, value: models.Field) -> None:
else:
instance.__dict__[self.field_name] = value

def __getattr__(self, attr: str) -> models.Field:
def __getattr__(self, attr: str) -> DescriptorT:
return getattr(self.descriptor, attr)

@staticmethod
def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]:
def cls_for_descriptor(descriptor: DescriptorT) -> type[DescriptorWrapper[DescriptorT]]:
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
return DescriptorWrapper


class FullDescriptorWrapper(DescriptorWrapper):
class FullDescriptorWrapper(DescriptorWrapper[FullDescriptorT]):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj: models.Field) -> None:
self.descriptor.__delete__(obj) # type: ignore[attr-defined]
def __delete__(self, obj: models.Model) -> None:
self.descriptor.__delete__(obj)


class FieldsContext:
Expand Down Expand Up @@ -351,7 +376,7 @@ def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
for field_name in self.fields:
descriptor: models.Field = getattr(sender, field_name)
descriptor: models.Field[Any, Any] = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
Expand Down

0 comments on commit aaa645b

Please sign in to comment.