diff --git a/.coveragerc b/.coveragerc index 8708371a..77531508 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,8 @@ [run] include = model_utils/*.py + +[report] +exclude_also = + # Exclusive to mypy: + if TYPE_CHECKING:$ + \.\.\.$ diff --git a/model_utils/choices.py b/model_utils/choices.py index 46c3877c..ce709be3 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -1,7 +1,39 @@ +from __future__ import annotations + import copy +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload + +T = TypeVar("T") + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + # The type aliases defined here are evaluated when the django-stubs mypy plugin + # loads this module, so they must be able to execute under the lowest supported + # Python VM: + # - typing.List, typing.Tuple become obsolete in Pyton 3.9 + # - typing.Union becomes obsolete in Pyton 3.10 + from typing import List, Tuple, Union + + from django_stubs_ext import StrOrPromise + + # The type argument 'T' to 'Choices' is the database representation type. + _Double = Tuple[T, StrOrPromise] + _Triple = Tuple[T, str, StrOrPromise] + _Group = Tuple[StrOrPromise, Sequence["_Choice[T]"]] + _Choice = Union[_Double[T], _Triple[T], _Group[T]] + # Choices can only be given as a single string if 'T' is 'str'. + _GroupStr = Tuple[StrOrPromise, Sequence["_ChoiceStr"]] + _ChoiceStr = Union[str, _Double[str], _Triple[str], _GroupStr] + # Note that we only accept lists and tuples in groups, not arbitrary sequences. + # However, annotating it as such causes many problems. + + _DoubleRead = Union[_Double[T], Tuple[StrOrPromise, Iterable["_DoubleRead[T]"]]] + _DoubleCollector = List[Union[_Double[T], Tuple[StrOrPromise, "_DoubleCollector[T]"]]] + _TripleCollector = List[Union[_Triple[T], Tuple[StrOrPromise, "_TripleCollector[T]"]]] -class Choices: +class Choices(Generic[T]): """ A class to encapsulate handy functionality for lists of choices for a Django model field. @@ -41,36 +73,60 @@ class Choices: """ - def __init__(self, *choices): + @overload + def __init__(self: Choices[str], *choices: _ChoiceStr): + ... + + @overload + def __init__(self, *choices: _Choice[T]): + ... + + def __init__(self, *choices: _ChoiceStr | _Choice[T]): # list of choices expanded to triples - can include optgroups - self._triples = [] + self._triples: _TripleCollector[T] = [] # list of choices as (db, human-readable) - can include optgroups - self._doubles = [] + self._doubles: _DoubleCollector[T] = [] # dictionary mapping db representation to human-readable - self._display_map = {} + self._display_map: dict[T, StrOrPromise | list[_Triple[T]]] = {} # dictionary mapping Python identifier to db representation - self._identifier_map = {} + self._identifier_map: dict[str, T] = {} # set of db representations - self._db_values = set() + self._db_values: set[T] = set() self._process(choices) - def _store(self, triple, triple_collector, double_collector): + def _store( + self, + triple: tuple[T, str, StrOrPromise], + triple_collector: _TripleCollector[T], + double_collector: _DoubleCollector[T] + ) -> None: self._identifier_map[triple[1]] = triple[0] self._display_map[triple[0]] = triple[2] self._db_values.add(triple[0]) triple_collector.append(triple) double_collector.append((triple[0], triple[2])) - def _process(self, choices, triple_collector=None, double_collector=None): + def _process( + self, + choices: Iterable[_ChoiceStr | _Choice[T]], + triple_collector: _TripleCollector[T] | None = None, + double_collector: _DoubleCollector[T] | None = None + ) -> None: if triple_collector is None: triple_collector = self._triples if double_collector is None: double_collector = self._doubles - store = lambda c: self._store(c, triple_collector, double_collector) + def store(c: tuple[Any, str, StrOrPromise]) -> None: + self._store(c, triple_collector, double_collector) for choice in choices: + # The type inference is not very accurate here: + # - we lied in the type aliases, stating groups contain an arbitrary Sequence + # rather than only list or tuple + # - there is no way to express that _ChoiceStr is only used when T=str + # - mypy 1.9.0 doesn't narrow types based on the value of len() if isinstance(choice, (list, tuple)): if len(choice) == 3: store(choice) @@ -79,13 +135,13 @@ def _process(self, choices, triple_collector=None, double_collector=None): # option group group_name = choice[0] subchoices = choice[1] - tc = [] + tc: _TripleCollector[T] = [] triple_collector.append((group_name, tc)) - dc = [] + dc: _DoubleCollector[T] = [] double_collector.append((group_name, dc)) self._process(subchoices, tc, dc) else: - store((choice[0], choice[0], choice[1])) + store((choice[0], cast(str, choice[0]), cast('StrOrPromise', choice[1]))) else: raise ValueError( "Choices can't take a list of length %s, only 2 or 3" @@ -94,54 +150,74 @@ def _process(self, choices, triple_collector=None, double_collector=None): else: store((choice, choice, choice)) - def __len__(self): + def __len__(self) -> int: return len(self._doubles) - def __iter__(self): + def __iter__(self) -> Iterator[_DoubleRead[T]]: return iter(self._doubles) - def __reversed__(self): + def __reversed__(self) -> Iterator[_DoubleRead[T]]: return reversed(self._doubles) - def __getattr__(self, attname): + def __getattr__(self, attname: str) -> T: try: return self._identifier_map[attname] except KeyError: raise AttributeError(attname) - def __getitem__(self, key): + def __getitem__(self, key: T) -> StrOrPromise | Sequence[_Triple[T]]: return self._display_map[key] - def __add__(self, other): + @overload + def __add__(self: Choices[str], other: Choices[str] | Iterable[_ChoiceStr]) -> Choices[str]: + ... + + @overload + def __add__(self, other: Choices[T] | Iterable[_Choice[T]]) -> Choices[T]: + ... + + def __add__(self, other: Choices[Any] | Iterable[_ChoiceStr | _Choice[Any]]) -> Choices[Any]: + other_args: list[Any] if isinstance(other, self.__class__): - other = other._triples + other_args = other._triples else: - other = list(other) - return Choices(*(self._triples + other)) + other_args = list(other) + return Choices(*(self._triples + other_args)) + + @overload + def __radd__(self: Choices[str], other: Iterable[_ChoiceStr]) -> Choices[str]: + ... + + @overload + def __radd__(self, other: Iterable[_Choice[T]]) -> Choices[T]: + ... - def __radd__(self, other): + def __radd__(self, other: Iterable[_ChoiceStr] | Iterable[_Choice[T]]) -> Choices[Any]: # radd is never called for matching types, so we don't check here - other = list(other) - return Choices(*(other + self._triples)) + other_args = list(other) + # The exact type of 'other' depends on our type argument 'T', which + # is expressed in the overloading, but lost within this method body. + return Choices(*(other_args + self._triples)) # type: ignore[arg-type] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self._triples == other._triples return False - def __repr__(self): + def __repr__(self) -> str: return '{}({})'.format( self.__class__.__name__, ', '.join("%s" % repr(i) for i in self._triples) ) - def __contains__(self, item): + def __contains__(self, item: T) -> bool: return item in self._db_values - def __deepcopy__(self, memo): - return self.__class__(*copy.deepcopy(self._triples, memo)) + def __deepcopy__(self, memo: dict[int, Any] | None) -> Choices[T]: + args: list[Any] = copy.deepcopy(self._triples, memo) + return self.__class__(*args) - def subset(self, *new_identifiers): + def subset(self, *new_identifiers: str) -> Choices[T]: identifiers = set(self._identifier_map.keys()) if not identifiers.issuperset(new_identifiers): @@ -150,7 +226,8 @@ def subset(self, *new_identifiers): identifiers.symmetric_difference(new_identifiers), ) - return self.__class__(*[ + args: list[Any] = [ choice for choice in self._triples if choice[1] in new_identifiers - ]) + ] + return self.__class__(*args) diff --git a/model_utils/fields.py b/model_utils/fields.py index 69188bba..ced117cb 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -1,15 +1,27 @@ +from __future__ import annotations + import secrets import uuid +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Union from django.conf import settings from django.core.exceptions import ValidationError from django.db import models from django.utils.timezone import now +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from datetime import date, datetime + + DateTimeFieldBase = models.DateTimeField[Union[str, datetime, date], datetime] +else: + DateTimeFieldBase = models.DateTimeField + DEFAULT_CHOICES_NAME = 'STATUS' -class AutoCreatedField(models.DateTimeField): +class AutoCreatedField(DateTimeFieldBase): """ A DateTimeField that automatically populates itself at object creation. @@ -18,7 +30,7 @@ class AutoCreatedField(models.DateTimeField): """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): kwargs.setdefault('editable', False) kwargs.setdefault('default', now) super().__init__(*args, **kwargs) @@ -31,13 +43,13 @@ class AutoLastModifiedField(AutoCreatedField): By default, sets editable=False and default=datetime.now. """ - def get_default(self): + def get_default(self) -> datetime: """Return the default value for this field.""" if not hasattr(self, "_default"): - self._default = self._get_default() + self._default = super().get_default() return self._default - def pre_save(self, model_instance, add): + def pre_save(self, model_instance: models.Model, add: bool) -> datetime: value = now() if add: current_value = getattr(model_instance, self.attname, self.get_default()) @@ -68,13 +80,19 @@ class StatusField(models.CharField): South can handle this field when it freezes a model. """ - def __init__(self, *args, no_check_for_status=False, choices_name=DEFAULT_CHOICES_NAME, **kwargs): + def __init__( + self, + *args: Any, + no_check_for_status: bool = False, + choices_name: str = DEFAULT_CHOICES_NAME, + **kwargs: Any + ): kwargs.setdefault('max_length', 100) self.check_for_status = not no_check_for_status self.choices_name = choices_name super().__init__(*args, **kwargs) - def prepare_class(self, sender, **kwargs): + def prepare_class(self, sender: type[models.Model], **kwargs: Any) -> None: if not sender._meta.abstract and self.check_for_status: assert hasattr(sender, self.choices_name), \ "To use StatusField, the model '%s' must have a %s choices class attribute." \ @@ -83,7 +101,7 @@ def prepare_class(self, sender, **kwargs): if not self.has_default(): self.default = tuple(getattr(sender, self.choices_name))[0][0] # set first as default - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: models.signals.class_prepared.connect(self.prepare_class, sender=cls) # we don't set the real choices until class_prepared (so we can rely on # the STATUS class attr being available), but we need to set some dummy @@ -91,13 +109,13 @@ def contribute_to_class(self, cls, name, *args, **kwargs): self.choices = [(0, 'dummy')] super().contribute_to_class(cls, name, *args, **kwargs) - def deconstruct(self): + def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: name, path, args, kwargs = super().deconstruct() kwargs['no_check_for_status'] = True return name, path, args, kwargs -class MonitorField(models.DateTimeField): +class MonitorField(DateTimeFieldBase): """ A DateTimeField that monitors another field on the same model and sets itself to the current date/time whenever the monitored field @@ -105,30 +123,28 @@ class MonitorField(models.DateTimeField): """ - def __init__(self, *args, monitor, when=None, **kwargs): + def __init__(self, *args: Any, monitor: str, when: Iterable[Any] | None = None, **kwargs: Any): default = None if kwargs.get("null") else now kwargs.setdefault('default', default) self.monitor = monitor - if when is not None: - when = set(when) - self.when = when + self.when = None if when is None else set(when) super().__init__(*args, **kwargs) - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: self.monitor_attname = '_monitor_%s' % name models.signals.post_init.connect(self._save_initial, sender=cls) super().contribute_to_class(cls, name, *args, **kwargs) - def get_monitored_value(self, instance): + def get_monitored_value(self, instance: models.Model) -> Any: return getattr(instance, self.monitor) - def _save_initial(self, sender, instance, **kwargs): + def _save_initial(self, sender: type[models.Model], instance: models.Model, **kwargs: Any) -> None: if self.monitor in instance.get_deferred_fields(): # Fix related to issue #241 to avoid recursive error on double monitor fields return setattr(instance, self.monitor_attname, self.get_monitored_value(instance)) - def pre_save(self, model_instance, add): + def pre_save(self, model_instance: models.Model, add: bool) -> Any: value = now() previous = getattr(model_instance, self.monitor_attname, None) current = self.get_monitored_value(model_instance) @@ -138,7 +154,7 @@ def pre_save(self, model_instance, add): self._save_initial(model_instance.__class__, model_instance) return super().pre_save(model_instance, add) - def deconstruct(self): + def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: name, path, args, kwargs = super().deconstruct() kwargs['monitor'] = self.monitor if self.when is not None: @@ -152,12 +168,12 @@ def deconstruct(self): SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2) -def _excerpt_field_name(name): +def _excerpt_field_name(name: str) -> str: return '_%s_excerpt' % name -def get_excerpt(content): - excerpt = [] +def get_excerpt(content: str) -> str: + excerpt: list[str] = [] default_excerpt = [] paras_seen = 0 for line in content.splitlines(): @@ -173,7 +189,7 @@ def get_excerpt(content): class SplitText: - def __init__(self, instance, field_name, excerpt_field_name): + def __init__(self, instance: models.Model, field_name: str, excerpt_field_name: str): # instead of storing actual values store a reference to the instance # along with field names, this makes assignment possible self.instance = instance @@ -181,36 +197,36 @@ def __init__(self, instance, field_name, excerpt_field_name): self.excerpt_field_name = excerpt_field_name @property - def content(self): + def content(self) -> str: return self.instance.__dict__[self.field_name] @content.setter - def content(self, val): + def content(self, val: str) -> None: setattr(self.instance, self.field_name, val) @property - def excerpt(self): + def excerpt(self) -> str: return getattr(self.instance, self.excerpt_field_name) @property - def has_more(self): + def has_more(self) -> bool: return self.excerpt.strip() != self.content.strip() - def __str__(self): + def __str__(self) -> str: return self.content class SplitDescriptor: - def __init__(self, field): + def __init__(self, field: SplitField): self.field = field self.excerpt_field_name = _excerpt_field_name(self.field.name) - def __get__(self, instance, owner): + def __get__(self, instance: models.Model, owner: type[models.Model]) -> SplitText: if instance is None: raise AttributeError('Can only be accessed via an instance.') return SplitText(instance, self.field.name, self.excerpt_field_name) - def __set__(self, obj, value): + def __set__(self, obj: models.Model, value: SplitText | str) -> None: if isinstance(value, SplitText): obj.__dict__[self.field.name] = value.content setattr(obj, self.excerpt_field_name, value.excerpt) @@ -218,25 +234,32 @@ def __set__(self, obj, value): obj.__dict__[self.field.name] = value -class SplitField(models.TextField): - def contribute_to_class(self, cls, name, *args, **kwargs): +if TYPE_CHECKING: + _SplitFieldBase = models.TextField[Union[SplitText, str], SplitText] +else: + _SplitFieldBase = models.TextField + + +class SplitField(_SplitFieldBase): + + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: if not cls._meta.abstract: - excerpt_field = models.TextField(editable=False) + excerpt_field: models.TextField = models.TextField(editable=False) cls.add_to_class(_excerpt_field_name(name), excerpt_field) super().contribute_to_class(cls, name, *args, **kwargs) setattr(cls, self.name, SplitDescriptor(self)) - def pre_save(self, model_instance, add): - value = super().pre_save(model_instance, add) + def pre_save(self, model_instance: models.Model, add: bool) -> str: + value: SplitText = super().pre_save(model_instance, add) excerpt = get_excerpt(value.content) setattr(model_instance, _excerpt_field_name(self.attname), excerpt) return value.content - def value_to_string(self, obj): + def value_to_string(self, obj: models.Model) -> str: value = self.value_from_object(obj) return value.content - def get_prep_value(self, value): + def get_prep_value(self, value: Any) -> str: try: return value.content except AttributeError: @@ -248,7 +271,14 @@ class UUIDField(models.UUIDField): A field for storing universally unique identifiers. Use Python UUID class. """ - def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs): + def __init__( + self, + primary_key: bool = True, + version: int = 4, + editable: bool = False, + *args: Any, + **kwargs: Any + ): """ Parameters ---------- @@ -274,6 +304,7 @@ def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs) raise ValidationError( 'UUID version is not valid.') + default: Callable[..., uuid.UUID] if version == 1: default = uuid.uuid1 elif version == 3: @@ -294,7 +325,15 @@ class UrlsafeTokenField(models.CharField): A field for storing a unique token in database. """ - def __init__(self, editable=False, max_length=128, factory=None, **kwargs): + max_length: int + + def __init__( + self, + editable: bool = False, + max_length: int = 128, + factory: Callable[[int], str] | None = None, + **kwargs: Any + ): """ Parameters ---------- @@ -319,14 +358,14 @@ def __init__(self, editable=False, max_length=128, factory=None, **kwargs): super().__init__(editable=editable, max_length=max_length, **kwargs) - def get_default(self): + def get_default(self) -> str: if self._factory is not None: return self._factory(self.max_length) # generate a token of length x1.33 approx. trim up to max length token = secrets.token_urlsafe(self.max_length)[:self.max_length] return token - def deconstruct(self): + def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: name, path, args, kwargs = super().deconstruct() kwargs['factory'] = self._factory return name, path, args, kwargs diff --git a/model_utils/managers.py b/model_utils/managers.py index aaa4be84..899b9887 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload from django.core.exceptions import ObjectDoesNotExist from django.db import connection, models @@ -7,57 +10,84 @@ from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join - -class InheritanceIterable(ModelIterable): - def __iter__(self): - queryset = self.queryset - iter = ModelIterable(queryset) - if getattr(queryset, 'subclasses', False): - extras = tuple(queryset.query.extra.keys()) - # sort the subclass names longest first, - # so with 'a' and 'a__b' it goes as deep as possible - subclasses = sorted(queryset.subclasses, key=len, reverse=True) - for obj in iter: - sub_obj = None - for s in subclasses: - sub_obj = queryset._get_sub_obj_recurse(obj, s) - if sub_obj: - break - if not sub_obj: - sub_obj = obj - - if getattr(queryset, '_annotated', False): - for k in queryset._annotated: - setattr(sub_obj, k, getattr(obj, k)) - - for k in extras: +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + +if TYPE_CHECKING: + from collections.abc import Iterator + + from django.db.models.query import BaseIterable + + +def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]: + iter: ModelIterable[ModelT] = ModelIterable(queryset) + if hasattr(queryset, 'subclasses'): + assert hasattr(queryset, '_get_sub_obj_recurse') + extras = tuple(queryset.query.extra.keys()) + # sort the subclass names longest first, + # so with 'a' and 'a__b' it goes as deep as possible + subclasses = sorted(queryset.subclasses, key=len, reverse=True) + for obj in iter: + sub_obj = None + for s in subclasses: + sub_obj = queryset._get_sub_obj_recurse(obj, s) + if sub_obj: + break + if not sub_obj: + sub_obj = obj + + if hasattr(queryset, '_annotated'): + for k in queryset._annotated: setattr(sub_obj, k, getattr(obj, k)) - yield sub_obj - else: - yield from iter + for k in extras: + setattr(sub_obj, k, getattr(obj, k)) + + yield sub_obj + else: + yield from iter + + +if TYPE_CHECKING: + class InheritanceIterable(ModelIterable[ModelT]): + queryset: QuerySet[ModelT] + def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any): + ... -class InheritanceQuerySetMixin: - def __init__(self, *args, **kwargs): + def __iter__(self) -> Iterator[ModelT]: + ... + +else: + class InheritanceIterable(ModelIterable): + def __iter__(self): + return _iter_inheritance_queryset(self.queryset) + + +class InheritanceQuerySetMixin(Generic[ModelT]): + + model: type[ModelT] + subclasses: Sequence[str] + + def __init__(self, *args: object, **kwargs: object): super().__init__(*args, **kwargs) - self._iterable_class = InheritanceIterable + self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable - def select_subclasses(self, *subclasses): - calculated_subclasses = self._get_subclasses_recurse(self.model) + def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]: + model: type[ModelT] = self.model + calculated_subclasses = self._get_subclasses_recurse(model) # if none were passed in, we can just short circuit and select all if not subclasses: - subclasses = calculated_subclasses + selected_subclasses = calculated_subclasses else: - verified_subclasses = [] + verified_subclasses: list[str] = [] for subclass in subclasses: # special case for passing in the same model as the queryset # is bound against. Rather than raise an error later, we know # we can allow this through. - if subclass is self.model: + if subclass is model: continue - if not isinstance(subclass, (str,)): + if not isinstance(subclass, str): subclass = self._get_ancestors_path(subclass) if subclass in calculated_subclasses: @@ -67,38 +97,39 @@ def select_subclasses(self, *subclasses): '{!r} is not in the discovered subclasses, tried: {}'.format( subclass, ', '.join(calculated_subclasses)) ) - subclasses = verified_subclasses + selected_subclasses = verified_subclasses - if subclasses: - new_qs = self.select_related(*subclasses) - else: - new_qs = self - new_qs.subclasses = subclasses + new_qs = cast('InheritanceQuerySet[ModelT]', self) + if selected_subclasses: + new_qs = new_qs.select_related(*selected_subclasses) + new_qs.subclasses = selected_subclasses return new_qs - def _chain(self, **kwargs): + def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]: update = {} for name in ['subclasses', '_annotated']: if hasattr(self, name): update[name] = getattr(self, name) - chained = super()._chain(**kwargs) + # django-stubs doesn't include this private API. + chained = super()._chain(**kwargs) # type: ignore[misc] chained.__dict__.update(update) return chained - def _clone(self): - qs = super()._clone() + def _clone(self) -> InheritanceQuerySet[ModelT]: + # django-stubs doesn't include this private API. + qs = super()._clone() # type: ignore[misc] for name in ['subclasses', '_annotated']: if hasattr(self, name): setattr(qs, name, getattr(self, name)) return qs - def annotate(self, *args, **kwargs): - qset = super().annotate(*args, **kwargs) + def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs) qset._annotated = [a.default_alias for a in args] + list(kwargs.keys()) return qset - def _get_subclasses_recurse(self, model): + def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]: """ Given a Model class, find all related objects, exploring children recursively, returning a `list` of strings representing the @@ -124,7 +155,7 @@ def _get_subclasses_recurse(self, model): subclasses.append(rel.get_accessor_name()) return subclasses - def _get_ancestors_path(self, model): + def _get_ancestors_path(self, model: type[models.Model]) -> str: """ Serves as an opposite to _get_subclasses_recurse, instead walking from the Model class up the Model's ancestry and constructing the desired @@ -134,7 +165,7 @@ def _get_ancestors_path(self, model): raise ValueError( f"{model!r} is not a subclass of {self.model!r}") - ancestry = [] + ancestry: list[str] = [] # should be a OneToOneField or None parent_link = model._meta.get_ancestor_link(self.model) @@ -147,7 +178,7 @@ def _get_ancestors_path(self, model): return LOOKUP_SEP.join(ancestry) - def _get_sub_obj_recurse(self, obj, s): + def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None: rel, _, s = s.partition(LOOKUP_SEP) try: @@ -160,12 +191,14 @@ def _get_sub_obj_recurse(self, obj, s): else: return node - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: object, **kwargs: object) -> ModelT: return self.select_subclasses().get(*args, **kwargs) -class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): - def instance_of(self, *models): +# Defining the 'model' attribute using a generic type triggers a bug in mypy: +# https://github.com/python/mypy/issues/9031 +class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc] + def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: """ Fetch only objects that are instances of the provided model(s). """ @@ -188,88 +221,187 @@ def instance_of(self, *models): ) for field in model._meta.parents.values() ]) + ')') - return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + return cast( + 'InheritanceQuerySet[ModelT]', + self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + ) -class InheritanceManagerMixin: +class InheritanceManagerMixin(Generic[ModelT]): _queryset_class = InheritanceQuerySet - def get_queryset(self): - return self._queryset_class(self.model) + if TYPE_CHECKING: + from collections.abc import Sequence + + def none(self) -> InheritanceQuerySet[ModelT]: + ... + + def all(self) -> InheritanceQuerySet[ModelT]: + ... + + def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def exclude(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def complex_filter(self, filter_obj: Any) -> InheritanceQuerySet[ModelT]: + ... + + def union(self, *other_qs: Any, all: bool = ...) -> InheritanceQuerySet[ModelT]: + ... + + def intersection(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def difference(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def select_for_update( + self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ... + ) -> InheritanceQuerySet[ModelT]: + ... + + def select_related(self, *fields: Any) -> InheritanceQuerySet[ModelT]: + ... - def select_subclasses(self, *subclasses): + def prefetch_related(self, *lookups: Any) -> InheritanceQuerySet[ModelT]: + ... + + def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def alias(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def order_by(self, *field_names: Any) -> InheritanceQuerySet[ModelT]: + ... + + def distinct(self, *field_names: Any) -> InheritanceQuerySet[ModelT]: + ... + + def extra( + self, + select: dict[str, Any] | None = ..., + where: list[str] | None = ..., + params: list[Any] | None = ..., + tables: list[str] | None = ..., + order_by: Sequence[str] | None = ..., + select_params: Sequence[Any] | None = ..., + ) -> InheritanceQuerySet[Any]: + ... + + def reverse(self) -> InheritanceQuerySet[ModelT]: + ... + + def defer(self, *fields: Any) -> InheritanceQuerySet[ModelT]: + ... + + def only(self, *fields: Any) -> InheritanceQuerySet[ModelT]: + ... + + def using(self, alias: str | None) -> InheritanceQuerySet[ModelT]: + ... + + def get_queryset(self) -> InheritanceQuerySet[ModelT]: + model: type[ModelT] = self.model # type: ignore[attr-defined] + return self._queryset_class(model) + + def select_subclasses( + self, *subclasses: str | type[models.Model] + ) -> InheritanceQuerySet[ModelT]: return self.get_queryset().select_subclasses(*subclasses) - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: object, **kwargs: object) -> ModelT: return self.get_queryset().get_subclass(*args, **kwargs) - def instance_of(self, *models): + def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: return self.get_queryset().instance_of(*models) -class InheritanceManager(InheritanceManagerMixin, models.Manager): +class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]): pass -class QueryManagerMixin: +class QueryManagerMixin(Generic[ModelT]): + + @overload + def __init__(self, *args: models.Q): + ... - def __init__(self, *args, **kwargs): + @overload + def __init__(self, **kwargs: object): + ... + + def __init__(self, *args: models.Q, **kwargs: object): if args: self._q = args[0] else: self._q = models.Q(**kwargs) - self._order_by = None + self._order_by: tuple[Any, ...] | None = None super().__init__() - def order_by(self, *args): + def order_by(self, *args: Any) -> QueryManager[ModelT]: self._order_by = args - return self + return cast('QueryManager[ModelT]', self) - def get_queryset(self): - qs = super().get_queryset().filter(self._q) + def get_queryset(self) -> QuerySet[ModelT]: + qs = super().get_queryset() # type: ignore[misc] + qs = qs.filter(self._q) if self._order_by is not None: return qs.order_by(*self._order_by) return qs -class QueryManager(QueryManagerMixin, models.Manager): +class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc] pass -class SoftDeletableQuerySetMixin: +class SoftDeletableQuerySetMixin(Generic[ModelT]): """ QuerySet for SoftDeletableModel. Instead of removing instance sets its ``is_removed`` field to True. """ - def delete(self): + def delete(self) -> None: """ Soft delete objects from queryset (set their ``is_removed`` field to True) """ - self.update(is_removed=True) + cast(QuerySet[ModelT], self).update(is_removed=True) -class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet): +# Note that our delete() method does not return anything, unlike Django's. +# https://github.com/jazzband/django-model-utils/issues/541 +class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc] pass -class SoftDeletableManagerMixin: +class SoftDeletableManagerMixin(Generic[ModelT]): """ Manager that limits the queryset by default to show only not removed instances of model. """ _queryset_class = SoftDeletableQuerySet - def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs): + _db: str | None + + def __init__( + self, + *args: object, + _emit_deprecation_warnings: bool = False, + **kwargs: object + ): self.emit_deprecation_warnings = _emit_deprecation_warnings super().__init__(*args, **kwargs) - def get_queryset(self): + def get_queryset(self) -> SoftDeletableQuerySet[ModelT]: """ Return queryset limited to not removed entries. """ + model: type[ModelT] = self.model # type: ignore[attr-defined] + if self.emit_deprecation_warnings: warning_message = ( "{0}.objects model manager will include soft-deleted objects in an " @@ -277,23 +409,23 @@ def get_queryset(self): "excluding soft-deleted objects. See " "https://django-model-utils.readthedocs.io/en/stable/models.html" "#softdeletablemodel for more information." - ).format(self.model.__class__.__name__) + ).format(model.__class__.__name__) warnings.warn(warning_message, DeprecationWarning) - kwargs = {'model': self.model, 'using': self._db} - if hasattr(self, '_hints'): - kwargs['hints'] = self._hints + return self._queryset_class( + model=model, + using=self._db, + **({'hints': self._hints} if hasattr(self, '_hints') else {}) + ).filter(is_removed=False) - return self._queryset_class(**kwargs).filter(is_removed=False) - -class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager): +class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]): pass -class JoinQueryset(models.QuerySet): +class JoinQueryset(models.QuerySet[Any]): - def join(self, qs=None): + def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]: ''' Join one queryset together with another using a temporary table. If no queryset is used, it will use the current queryset and join that @@ -308,11 +440,11 @@ def join(self, qs=None): to_field = 'id' if qs: - fk = [ + fks = [ fk for fk in qs.model._meta.fields if getattr(fk, 'related_model', None) == self.model ] - fk = fk[0] if fk else None + fk = fks[0] if fks else None model_set = f'{self.model.__name__.lower()}_set' key = fk or getattr(qs.model, model_set, None) @@ -331,7 +463,7 @@ def join(self, qs=None): else: fk_column = 'id' qs = self.only(fk_column) - new_qs = self.model.objects.all() + new_qs = self.model._default_manager.all() TABLE_NAME = 'temp_stuff' query, params = qs.query.sql_with_params() @@ -369,21 +501,24 @@ class Meta: return new_qs -class JoinManagerMixin: - """ - Manager that adds a method join. This method allows you to join two - querysets together. - """ - _queryset_class = JoinQueryset +if not TYPE_CHECKING: + # Hide deprecated API during type checking, to encourage switch to + # 'JoinQueryset.as_manager()', which is supported by the mypy plugin + # of django-stubs. - def get_queryset(self): - warnings.warn( - "JoinManager and JoinManagerMixin are deprecated. " - "Please use 'JoinQueryset.as_manager()' instead.", - DeprecationWarning - ) - return self._queryset_class(model=self.model, using=self._db) + class JoinManagerMixin: + """ + Manager that adds a method join. This method allows you to join two + querysets together. + """ + def get_queryset(self): + warnings.warn( + "JoinManager and JoinManagerMixin are deprecated. " + "Please use 'JoinQueryset.as_manager()' instead.", + DeprecationWarning + ) + return self._queryset_class(model=self.model, using=self._db) -class JoinManager(JoinManagerMixin, models.Manager): - pass + class JoinManager(JoinManagerMixin): + pass diff --git a/model_utils/models.py b/model_utils/models.py index 4eb0e63d..71d0055c 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any, Literal, TypeVar, overload + from django.core.exceptions import ImproperlyConfigured from django.db import models from django.db.models.functions import Now @@ -12,6 +16,8 @@ ) from model_utils.managers import QueryManager, SoftDeletableManager +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + now = Now() @@ -24,7 +30,7 @@ class TimeStampedModel(models.Model): created = AutoCreatedField(_('created')) modified = AutoLastModifiedField(_('modified')) - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Overriding the save method in order to make sure that modified field is updated even if it is not given as @@ -65,7 +71,7 @@ class StatusModel(models.Model): status = StatusField(_('status')) status_changed = MonitorField(_('status changed'), monitor='status') - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Overriding the save method in order to make sure that status_changed field is updated even if it is not given as @@ -81,7 +87,7 @@ class Meta: abstract = True -def add_status_query_managers(sender, **kwargs): +def add_status_query_managers(sender: type[models.Model], **kwargs: Any) -> None: """ Add a Querymanager for each status item dynamically. @@ -90,6 +96,7 @@ def add_status_query_managers(sender, **kwargs): return default_manager = sender._meta.default_manager + assert default_manager is not None for value, display in getattr(sender, 'STATUS', ()): if _field_exists(sender, value): @@ -103,7 +110,7 @@ def add_status_query_managers(sender, **kwargs): sender._meta.default_manager_name = default_manager.name -def add_timeframed_query_manager(sender, **kwargs): +def add_timeframed_query_manager(sender: type[models.Model], **kwargs: Any) -> None: """ Add a QueryManager for a specific timeframe. @@ -126,7 +133,7 @@ def add_timeframed_query_manager(sender, **kwargs): models.signals.class_prepared.connect(add_timeframed_query_manager) -def _field_exists(model_class, field_name): +def _field_exists(model_class: type[models.Model], field_name: str) -> bool: return field_name in [f.attname for f in model_class._meta.local_fields] @@ -142,11 +149,28 @@ class SoftDeletableModel(models.Model): class Meta: abstract = True - objects = SoftDeletableManager(_emit_deprecation_warnings=True) - available_objects = SoftDeletableManager() + objects: models.Manager[SoftDeletableModel] = SoftDeletableManager(_emit_deprecation_warnings=True) + available_objects: models.Manager[SoftDeletableModel] = SoftDeletableManager() all_objects = models.Manager() - def delete(self, using=None, *args, soft=True, **kwargs): + # Note that soft delete does not return anything, + # which doesn't conform to Django's interface. + # https://github.com/jazzband/django-model-utils/issues/541 + @overload # type: ignore[override] + def delete( + self, using: Any = None, *args: Any, soft: Literal[True] = True, **kwargs: Any + ) -> None: + ... + + @overload + def delete( + self, using: Any = None, *args: Any, soft: Literal[False], **kwargs: Any + ) -> tuple[int, dict[str, int]]: + ... + + def delete( + self, using: Any = None, *args: Any, soft: bool = True, **kwargs: Any + ) -> tuple[int, dict[str, int]] | None: """ Soft delete object (set its ``is_removed`` field to True). Actually delete object if setting ``soft`` to False. @@ -154,6 +178,7 @@ def delete(self, using=None, *args, soft=True, **kwargs): if soft: self.is_removed = True self.save(using=using) + return None else: return super().delete(using, *args, **kwargs) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 118d1acf..61093802 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -1,10 +1,45 @@ +from __future__ import annotations + from copy import deepcopy from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Protocol, + TypeVar, + cast, + overload, +) from django.core.exceptions import FieldError from django.db import models from django.db.models.fields.files import FieldFile +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + from types import TracebackType + + class _AugmentedModel(models.Model): + _instance_initialized: bool + _deferred_fields: set[str] + +T = TypeVar("T") + + +class Descriptor(Protocol[T]): + def __get__(self, instance: object, owner: type[object]) -> T: + ... + + def __set__(self, instance: object, value: T) -> None: + ... + + +class FullDescriptor(Descriptor[T]): + def __delete__(self, instance: object) -> None: + ... + class LightStateFieldFile(FieldFile): """ @@ -16,7 +51,7 @@ class LightStateFieldFile(FieldFile): Django 3.1+ can make the app unusable, as CPU and memory usage gets easily multiplied by magnitudes. """ - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: """ We don't need to deepcopy the instance, so nullify if provided. """ @@ -26,27 +61,35 @@ def __getstate__(self): return state -def lightweight_deepcopy(value): +def lightweight_deepcopy(value: T) -> T: """ Use our lightweight class to avoid copying the instance on a FieldFile deepcopy. """ if isinstance(value, FieldFile): - value = LightStateFieldFile( + value = cast(T, LightStateFieldFile( instance=value.instance, field=value.field, name=value.name, - ) + )) return deepcopy(value) -class DescriptorWrapper: +class DescriptorWrapper(Generic[T]): - def __init__(self, field_name, descriptor, tracker_attname): + def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str): self.field_name = field_name self.descriptor = descriptor self.tracker_attname = tracker_attname - def __get__(self, instance, owner): + @overload + def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]: + ... + + @overload + def __get__(self, instance: models.Model, owner: type[models.Model]) -> T: + ... + + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T: if instance is None: return self was_deferred = self.field_name in instance.get_deferred_fields() @@ -56,7 +99,7 @@ def __get__(self, instance, owner): tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value) return value - def __set__(self, instance, value): + def __set__(self, instance: models.Model, value: T) -> None: initialized = hasattr(instance, '_instance_initialized') was_deferred = self.field_name in instance.get_deferred_fields() @@ -79,23 +122,23 @@ def __set__(self, instance, value): else: instance.__dict__[self.field_name] = value - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> T: return getattr(self.descriptor, attr) @staticmethod - def cls_for_descriptor(descriptor): + def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]: if hasattr(descriptor, '__delete__'): return FullDescriptorWrapper else: return DescriptorWrapper -class FullDescriptorWrapper(DescriptorWrapper): +class FullDescriptorWrapper(DescriptorWrapper[T]): """ Wrapper for descriptors with all three descriptor methods. """ - def __delete__(self, obj): - self.descriptor.__delete__(obj) + def __delete__(self, obj: models.Model) -> None: + cast(FullDescriptor[T], self.descriptor).__delete__(obj) class FieldsContext: @@ -119,7 +162,12 @@ class FieldsContext: """ - def __init__(self, tracker, *fields, state=None): + def __init__( + self, + tracker: FieldInstanceTracker, + *fields: str, + state: dict[str, int] | None = None + ): """ :param tracker: FieldInstanceTracker instance to be reset after context exit @@ -137,7 +185,7 @@ def __init__(self, tracker, *fields, state=None): self.fields = fields self.state = state - def __enter__(self): + def __enter__(self) -> FieldsContext: """ Increments tracked fields occurrences count in shared state. """ @@ -146,7 +194,12 @@ def __enter__(self): self.state[f] += 1 return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: """ Decrements tracked fields occurrences count in shared state. @@ -164,29 +217,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): class FieldInstanceTracker: - def __init__(self, instance, fields, field_map): - self.instance = instance + def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]): + self.instance = cast('_AugmentedModel', instance) self.fields = fields self.field_map = field_map self.context = FieldsContext(self, *self.fields) - def __enter__(self): + def __enter__(self) -> FieldsContext: return self.context.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: return self.context.__exit__(exc_type, exc_val, exc_tb) - def __call__(self, *fields): + def __call__(self, *fields: str) -> FieldsContext: return FieldsContext(self, *fields, state=self.context.state) @property - def deferred_fields(self): + def deferred_fields(self) -> set[str]: return self.instance.get_deferred_fields() - def get_field_value(self, field): + def get_field_value(self, field: str) -> Any: return getattr(self.instance, self.field_map[field]) - def set_saved_fields(self, fields=None): + def set_saved_fields(self, fields: Iterable[str] | None = None) -> None: if not self.instance.pk: self.saved_data = {} elif fields is None: @@ -198,7 +256,7 @@ def set_saved_fields(self, fields=None): for field, field_value in self.saved_data.items(): self.saved_data[field] = lightweight_deepcopy(field_value) - def current(self, fields=None): + def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]: """Returns dict of current values for all tracked fields""" if fields is None: deferred_fields = self.deferred_fields @@ -212,17 +270,19 @@ def current(self, fields=None): return {f: self.get_field_value(f) for f in fields} - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: # deferred fields haven't changed if field in self.deferred_fields and field not in self.instance.__dict__: return False - return self.previous(field) != self.get_field_value(field) + prev: object = self.previous(field) + curr: object = self.get_field_value(field) + return prev != curr else: raise FieldError('field "%s" not tracked' % field) - def previous(self, field): + def previous(self, field: str) -> Any: """Returns currently saved value of given field""" # handle deferred fields that have not yet been loaded from the database @@ -242,7 +302,7 @@ def previous(self, field): return self.saved_data.get(field) - def changed(self): + def changed(self) -> dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" return { field: self.previous(field) @@ -255,13 +315,34 @@ class FieldTracker: tracker_class = FieldInstanceTracker - def __init__(self, fields=None): - self.fields = fields - - def __call__(self, func=None, fields=None): - def decorator(f): + def __init__(self, fields: Iterable[str] | None = None): + # finalize_class() will replace None; pretend it is never None. + self.fields = cast(Iterable[str], fields) + + @overload + def __call__( + self, + func: None = None, + fields: Iterable[str] | None = None + ) -> Callable[[Callable[..., T]], Callable[..., T]]: + ... + + @overload + def __call__( + self, + func: Callable[..., T], + fields: Iterable[str] | None = None + ) -> Callable[..., T]: + ... + + def __call__( + self, + func: Callable[..., T] | None = None, + fields: Iterable[str] | None = None + ) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]: + def decorator(f: Callable[..., T]) -> Callable[..., T]: @wraps(f) - def inner(obj, *args, **kwargs): + def inner(obj: models.Model, *args: object, **kwargs: object) -> T: tracker = getattr(obj, self.attname) field_list = tracker.fields if fields is None else fields with tracker(*field_list): @@ -272,7 +353,7 @@ def inner(obj, *args, **kwargs): return decorator return decorator(func) - def get_field_map(self, cls): + def get_field_map(self, cls: type[models.Model]) -> dict[str, str]: """Returns dict mapping fields names to model attribute names""" field_map = {field: field for field in self.fields} all_fields = {f.name: f.attname for f in cls._meta.fields} @@ -280,17 +361,17 @@ def get_field_map(self, cls): if k in field_map}) return field_map - def contribute_to_class(self, cls, name): + def contribute_to_class(self, cls: type[models.Model], name: str) -> None: self.name = name self.attname = '_%s' % name models.signals.class_prepared.connect(self.finalize_class, sender=cls) - def finalize_class(self, sender, **kwargs): - if self.fields is None: + def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None: + if self.fields is None or TYPE_CHECKING: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) for field_name in self.fields: - descriptor = getattr(sender, field_name) + descriptor: models.Field[Any, Any] = getattr(sender, field_name) wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) setattr(sender, field_name, wrapped_descriptor) @@ -300,34 +381,39 @@ def finalize_class(self, sender, **kwargs): setattr(sender, self.name, self) self.patch_save(sender) - def initialize_tracker(self, sender, instance, **kwargs): + def initialize_tracker( + self, + sender: type[models.Model], + instance: models.Model, + **kwargs: object + ) -> None: if not isinstance(instance, self.model_class): return # Only init instances of given model (including children) tracker = self.tracker_class(instance, self.fields, self.field_map) setattr(instance, self.attname, tracker) tracker.set_saved_fields() - instance._instance_initialized = True + cast('_AugmentedModel', instance)._instance_initialized = True - def patch_init(self, model): + def patch_init(self, model: type[models.Model]) -> None: original = getattr(model, '__init__') @wraps(original) - def inner(instance, *args, **kwargs): + def inner(instance: models.Model, *args: Any, **kwargs: Any) -> None: original(instance, *args, **kwargs) self.initialize_tracker(model, instance) setattr(model, '__init__', inner) - def patch_save(self, model): + def patch_save(self, model: type[models.Model]) -> None: self._patch(model, 'save_base', 'update_fields') self._patch(model, 'refresh_from_db', 'fields') - def _patch(self, model, method, fields_kwarg): + def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None: original = getattr(model, method) @wraps(original) - def inner(instance, *args, **kwargs): - update_fields = kwargs.get(fields_kwarg) + def inner(instance: models.Model, *args: object, **kwargs: Any) -> object: + update_fields: Iterable[str] | None = kwargs.get(fields_kwarg) if update_fields is None: fields = self.fields else: @@ -341,7 +427,15 @@ def inner(instance, *args, **kwargs): setattr(model, method, inner) - def __get__(self, instance, owner): + @overload + def __get__(self, instance: None, owner: type[models.Model]) -> FieldTracker: + ... + + @overload + def __get__(self, instance: models.Model, owner: type[models.Model]) -> FieldInstanceTracker: + ... + + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker | FieldInstanceTracker: if instance is None: return self else: @@ -350,16 +444,18 @@ def __get__(self, instance, owner): class ModelInstanceTracker(FieldInstanceTracker): - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if not self.instance.pk: return True elif field in self.saved_data: - return self.previous(field) != self.get_field_value(field) + prev: object = self.previous(field) + curr: object = self.get_field_value(field) + return prev != curr else: raise FieldError('field "%s" not tracked' % field) - def changed(self): + def changed(self) -> dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" if not self.instance.pk: return {} @@ -371,5 +467,5 @@ def changed(self): class ModelTracker(FieldTracker): tracker_class = ModelInstanceTracker - def get_field_map(self, cls): + def get_field_map(self, cls: type[models.Model]) -> dict[str, str]: return {field: field for field in self.fields} diff --git a/mypy.ini b/mypy.ini index 918bc713..d0264e95 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,6 @@ [mypy] +disallow_incomplete_defs=True +disallow_untyped_defs=True implicit_reexport=False pretty=True show_error_codes=True diff --git a/requirements-mypy.txt b/requirements-mypy.txt index c0422a75..a7b6bb92 100644 --- a/requirements-mypy.txt +++ b/requirements-mypy.txt @@ -1,2 +1,3 @@ -mypy==1.9.0 -django-stubs==4.2.7 +mypy==1.10.0 +django-stubs==5.0.2 +pytest==7.4.3 diff --git a/tests/fields.py b/tests/fields.py index fa302039..d57960cf 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from typing import Any + from django.db import models +from django.db.backends.base.base import BaseDatabaseWrapper -def mutable_from_db(value): +def mutable_from_db(value: object) -> Any: if value == '': return None try: @@ -12,7 +17,7 @@ def mutable_from_db(value): return value -def mutable_to_db(value): +def mutable_to_db(value: object) -> str: if value is None: return '' if isinstance(value, list): @@ -21,12 +26,12 @@ def mutable_to_db(value): class MutableField(models.TextField): - def to_python(self, value): + def to_python(self, value: object) -> Any: return mutable_from_db(value) - def from_db_value(self, value, expression, connection): + def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any: return mutable_from_db(value) - def get_db_prep_save(self, value, connection): + def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str: value = super().get_db_prep_save(value, connection) return mutable_to_db(value) diff --git a/tests/models.py b/tests/models.py index 8b496540..4d345050 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import ClassVar +from typing import Any, ClassVar, TypeVar, overload from django.db import models from django.db.models import Manager +from django.db.models.query import QuerySet from django.db.models.query_utils import DeferredAttribute from django.utils.translation import gettext_lazy as _ @@ -11,7 +12,7 @@ from model_utils.fields import MonitorField, SplitField, StatusField, UUIDField from model_utils.managers import ( InheritanceManager, - JoinManager, + JoinQueryset, QueryManager, SoftDeletableManager, SoftDeletableQuerySet, @@ -26,6 +27,8 @@ from model_utils.tracker import FieldTracker, ModelTracker from tests.fields import MutableField +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + class InheritanceManagerTestRelated(models.Model): pass @@ -43,7 +46,7 @@ class InheritanceManagerTestParent(models.Model): on_delete=models.CASCADE) objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager() - def __str__(self): + def __str__(self) -> str: return "{}({})".format( self.__class__.__name__[len('InheritanceManagerTest'):], self.pk, @@ -128,7 +131,7 @@ class DoubleMonitored(models.Model): class Status(StatusModel): - STATUS = Choices( + STATUS: Choices[str] = Choices( ("active", _("active")), ("deleted", _("deleted")), ("on_hold", _("on hold")), @@ -184,7 +187,8 @@ class Post(models.Model): public: ClassVar[QueryManager[Post]] = QueryManager(published=True) public_confirmed: ClassVar[QueryManager[Post]] = QueryManager( models.Q(published=True) & models.Q(confirmed=True)) - public_reversed = QueryManager(published=True).order_by("-order") + public_reversed: ClassVar[QueryManager[Post]] = QueryManager( + published=True).order_by("-order") class Meta: ordering = ("order",) @@ -203,7 +207,6 @@ class Meta: class AbstractTracked(models.Model): - number: models.IntegerField class Meta: abstract = True @@ -216,7 +219,7 @@ class Tracked(models.Model): tracker = FieldTracker() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ No-op save() to ensure that FieldTracker.patch_save() works. """ super().save(*args, **kwargs) @@ -228,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel): tracker = FieldTracker() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Automatically add "modified" to update_fields.""" update_fields = kwargs.get('update_fields') if update_fields is not None: @@ -263,7 +266,7 @@ class TrackedNonFieldAttr(models.Model): number = models.FloatField() @property - def rounded(self): + def rounded(self) -> int | None: return round(self.number) if self.number is not None else None tracker = FieldTracker(fields=['rounded']) @@ -352,43 +355,52 @@ class SoftDeletable(SoftDeletableModel): all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager() -class CustomSoftDeleteQuerySet(SoftDeletableQuerySet): - def only_read(self): +class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]): + def only_read(self) -> QuerySet[ModelT]: return self.filter(is_read=True) class CustomSoftDelete(SoftDeletableModel): is_read = models.BooleanField(default=False) - available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc] + available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() class StringyDescriptor: """ Descriptor that returns a string version of the underlying integer value. """ - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls=None): + @overload + def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor: + ... + + @overload + def __get__(self, obj: models.Model, cls: type[models.Model]) -> str: + ... + + def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str: if obj is None: return self if self.name in obj.get_deferred_fields(): # This queries the database, and sets the value on the instance. + assert cls is not None fields_map = {f.name: f for f in cls._meta.fields} field = fields_map[self.name] DeferredAttribute(field=field).__get__(obj, cls) return str(obj.__dict__[self.name]) - def __set__(self, obj, value): + def __set__(self, obj: object, value: str) -> None: obj.__dict__[self.name] = int(value) - def __delete__(self, obj): + def __delete__(self, obj: object) -> None: del obj.__dict__[self.name] class CustomDescriptorField(models.IntegerField): - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: super().contribute_to_class(cls, name, *args, **kwargs) setattr(cls, name, StringyDescriptor(name)) @@ -404,7 +416,7 @@ class ModelWithCustomDescriptor(models.Model): class BoxJoinModel(models.Model): name = models.CharField(max_length=32) - objects: ClassVar[JoinManager[BoxJoinModel]] = JoinManager() + objects = JoinQueryset.as_manager() class JoinItemForeignKey(models.Model): @@ -414,7 +426,7 @@ class JoinItemForeignKey(models.Model): null=True, on_delete=models.CASCADE ) - objects: ClassVar[JoinManager[JoinItemForeignKey]] = JoinManager() + objects = JoinQueryset.as_manager() class CustomUUIDModel(UUIDModel): diff --git a/tests/test_choices.py b/tests/test_choices.py index 6d09319b..973b5968 100644 --- a/tests/test_choices.py +++ b/tests/test_choices.py @@ -1,116 +1,129 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +import pytest from django.test import TestCase from model_utils import Choices +T = TypeVar("T") -class ChoicesTests(TestCase): - def setUp(self): - self.STATUS = Choices('DRAFT', 'PUBLISHED') - - def test_getattr(self): - self.assertEqual(self.STATUS.DRAFT, 'DRAFT') - def test_indexing(self): - self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') +class ChoicesTestsMixin(Generic[T]): - def test_iteration(self): - self.assertEqual(tuple(self.STATUS), - (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) + STATUS: Choices[T] - def test_reversed(self): - self.assertEqual(tuple(reversed(self.STATUS)), - (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) + def test_getattr(self) -> None: + assert self.STATUS.DRAFT == 'DRAFT' - def test_len(self): - self.assertEqual(len(self.STATUS), 2) + def test_len(self) -> None: + assert len(self.STATUS) == 2 - def test_repr(self): - self.assertEqual(repr(self.STATUS), "Choices" + repr(( + def test_repr(self) -> None: + assert repr(self.STATUS) == "Choices" + repr(( ('DRAFT', 'DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED', 'PUBLISHED'), - ))) - - def test_wrong_length_tuple(self): - with self.assertRaises(ValueError): - Choices(('a',)) - - def test_contains_value(self): - self.assertTrue('PUBLISHED' in self.STATUS) - self.assertTrue('DRAFT' in self.STATUS) + )) - def test_doesnt_contain_value(self): - self.assertFalse('UNPUBLISHED' in self.STATUS) + def test_wrong_length_tuple(self) -> None: + with pytest.raises(ValueError): + Choices(('a',)) # type: ignore[arg-type] - def test_deepcopy(self): + def test_deepcopy(self) -> None: import copy - self.assertEqual(list(self.STATUS), - list(copy.deepcopy(self.STATUS))) + assert list(self.STATUS) == list(copy.deepcopy(self.STATUS)) + + def test_equality(self) -> None: + assert self.STATUS == Choices('DRAFT', 'PUBLISHED') + + def test_inequality(self) -> None: + assert self.STATUS != ['DRAFT', 'PUBLISHED'] + assert self.STATUS != Choices('DRAFT') + + def test_composability(self) -> None: + assert Choices('DRAFT') + Choices('PUBLISHED') == self.STATUS + assert Choices('DRAFT') + ('PUBLISHED',) == self.STATUS + assert ('DRAFT',) + Choices('PUBLISHED') == self.STATUS + + def test_option_groups(self) -> None: + # Note: The implementation accepts any kind of sequence, but the type system can only + # track per-index types for tuples. + if TYPE_CHECKING: + c = Choices(('group a', ['one', 'two']), ('group b', ('three',))) + else: + c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) + assert list(c) == [ + ('group a', [('one', 'one'), ('two', 'two')]), + ('group b', [('three', 'three')]), + ] + + +class ChoicesTests(TestCase, ChoicesTestsMixin[str]): + def setUp(self) -> None: + self.STATUS = Choices('DRAFT', 'PUBLISHED') - def test_equality(self): - self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) + def test_indexing(self) -> None: + self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') - def test_inequality(self): - self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) - self.assertNotEqual(self.STATUS, Choices('DRAFT')) + def test_iteration(self) -> None: + self.assertEqual(tuple(self.STATUS), + (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) - def test_composability(self): - self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) - self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) - self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) + def test_reversed(self) -> None: + self.assertEqual(tuple(reversed(self.STATUS)), + (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) - def test_option_groups(self): - c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) - self.assertEqual( - list(c), - [ - ('group a', [('one', 'one'), ('two', 'two')]), - ('group b', [('three', 'three')]), - ], - ) + def test_contains_value(self) -> None: + self.assertTrue('PUBLISHED' in self.STATUS) + self.assertTrue('DRAFT' in self.STATUS) + + def test_doesnt_contain_value(self) -> None: + self.assertFalse('UNPUBLISHED' in self.STATUS) -class LabelChoicesTests(ChoicesTests): - def setUp(self): +class LabelChoicesTests(TestCase, ChoicesTestsMixin[str]): + def setUp(self) -> None: self.STATUS = Choices( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), 'DELETED', ) - def test_iteration(self): + def test_iteration(self) -> None: self.assertEqual(tuple(self.STATUS), ( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), ('DELETED', 'DELETED'), )) - def test_reversed(self): + def test_reversed(self) -> None: self.assertEqual(tuple(reversed(self.STATUS)), ( ('DELETED', 'DELETED'), ('PUBLISHED', 'is published'), ('DRAFT', 'is draft'), )) - def test_indexing(self): + def test_indexing(self) -> None: self.assertEqual(self.STATUS['PUBLISHED'], 'is published') - def test_default(self): + def test_default(self) -> None: self.assertEqual(self.STATUS.DELETED, 'DELETED') - def test_provided(self): + def test_provided(self) -> None: self.assertEqual(self.STATUS.DRAFT, 'DRAFT') - def test_len(self): + def test_len(self) -> None: self.assertEqual(len(self.STATUS), 3) - def test_equality(self): + def test_equality(self) -> None: self.assertEqual(self.STATUS, Choices( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), 'DELETED', )) - def test_inequality(self): + def test_inequality(self) -> None: self.assertNotEqual(self.STATUS, [ ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), @@ -118,27 +131,27 @@ def test_inequality(self): ]) self.assertNotEqual(self.STATUS, Choices('DRAFT')) - def test_repr(self): + def test_repr(self) -> None: self.assertEqual(repr(self.STATUS), "Choices" + repr(( ('DRAFT', 'DRAFT', 'is draft'), ('PUBLISHED', 'PUBLISHED', 'is published'), ('DELETED', 'DELETED', 'DELETED'), ))) - def test_contains_value(self): + def test_contains_value(self) -> None: self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS) # This should be True, because both the display value # and the internal representation are both DELETED. self.assertTrue('DELETED' in self.STATUS) - def test_doesnt_contain_value(self): + def test_doesnt_contain_value(self) -> None: self.assertFalse('UNPUBLISHED' in self.STATUS) - def test_doesnt_contain_display_value(self): + def test_doesnt_contain_display_value(self) -> None: self.assertFalse('is draft' in self.STATUS) - def test_composability(self): + def test_composability(self) -> None: self.assertEqual( Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'), self.STATUS @@ -154,11 +167,17 @@ def test_composability(self): self.STATUS ) - def test_option_groups(self): - c = Choices( - ('group a', [(1, 'one'), (2, 'two')]), - ['group b', ((3, 'three'),)] - ) + def test_option_groups(self) -> None: + if TYPE_CHECKING: + c = Choices[int]( + ('group a', [(1, 'one'), (2, 'two')]), + ('group b', ((3, 'three'),)) + ) + else: + c = Choices( + ('group a', [(1, 'one'), (2, 'two')]), + ['group b', ((3, 'three'),)] + ) self.assertEqual( list(c), [ @@ -168,65 +187,65 @@ def test_option_groups(self): ) -class IdentifierChoicesTests(ChoicesTests): - def setUp(self): +class IdentifierChoicesTests(TestCase, ChoicesTestsMixin[int]): + def setUp(self) -> None: self.STATUS = Choices( (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), (2, 'DELETED', 'is deleted')) - def test_iteration(self): + def test_iteration(self) -> None: self.assertEqual(tuple(self.STATUS), ( (0, 'is draft'), (1, 'is published'), (2, 'is deleted'), )) - def test_reversed(self): + def test_reversed(self) -> None: self.assertEqual(tuple(reversed(self.STATUS)), ( (2, 'is deleted'), (1, 'is published'), (0, 'is draft'), )) - def test_indexing(self): + def test_indexing(self) -> None: self.assertEqual(self.STATUS[1], 'is published') - def test_getattr(self): + def test_getattr(self) -> None: self.assertEqual(self.STATUS.DRAFT, 0) - def test_len(self): + def test_len(self) -> None: self.assertEqual(len(self.STATUS), 3) - def test_repr(self): + def test_repr(self) -> None: self.assertEqual(repr(self.STATUS), "Choices" + repr(( (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), (2, 'DELETED', 'is deleted'), ))) - def test_contains_value(self): + def test_contains_value(self) -> None: self.assertTrue(0 in self.STATUS) self.assertTrue(1 in self.STATUS) self.assertTrue(2 in self.STATUS) - def test_doesnt_contain_value(self): + def test_doesnt_contain_value(self) -> None: self.assertFalse(3 in self.STATUS) - def test_doesnt_contain_display_value(self): - self.assertFalse('is draft' in self.STATUS) + def test_doesnt_contain_display_value(self) -> None: + self.assertFalse('is draft' in self.STATUS) # type: ignore[operator] - def test_doesnt_contain_python_attr(self): - self.assertFalse('PUBLISHED' in self.STATUS) + def test_doesnt_contain_python_attr(self) -> None: + self.assertFalse('PUBLISHED' in self.STATUS) # type: ignore[operator] - def test_equality(self): + def test_equality(self) -> None: self.assertEqual(self.STATUS, Choices( (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), (2, 'DELETED', 'is deleted') )) - def test_inequality(self): + def test_inequality(self) -> None: self.assertNotEqual(self.STATUS, [ (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), @@ -234,7 +253,7 @@ def test_inequality(self): ]) self.assertNotEqual(self.STATUS, Choices('DRAFT')) - def test_composability(self): + def test_composability(self) -> None: self.assertEqual( Choices( (0, 'DRAFT', 'is draft'), @@ -265,11 +284,17 @@ def test_composability(self): self.STATUS ) - def test_option_groups(self): - c = Choices( - ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), - ['group b', ((3, 'THREE', 'three'),)] - ) + def test_option_groups(self) -> None: + if TYPE_CHECKING: + c = Choices[int]( + ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), + ('group b', ((3, 'THREE', 'three'),)) + ) + else: + c = Choices( + ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), + ['group b', ((3, 'THREE', 'three'),)] + ) self.assertEqual( list(c), [ @@ -281,26 +306,26 @@ def test_option_groups(self): class SubsetChoicesTest(TestCase): - def setUp(self): - self.choices = Choices( + def setUp(self) -> None: + self.choices = Choices[int]( (0, 'a', 'A'), (1, 'b', 'B'), ) - def test_nonexistent_identifiers_raise(self): + def test_nonexistent_identifiers_raise(self) -> None: with self.assertRaises(ValueError): self.choices.subset('a', 'c') - def test_solo_nonexistent_identifiers_raise(self): + def test_solo_nonexistent_identifiers_raise(self) -> None: with self.assertRaises(ValueError): self.choices.subset('c') - def test_empty_subset_passes(self): + def test_empty_subset_passes(self) -> None: subset = self.choices.subset() self.assertEqual(subset, Choices()) - def test_subset_returns_correct_subset(self): + def test_subset_returns_correct_subset(self) -> None: subset = self.choices.subset('a') self.assertEqual(subset, Choices((0, 'a', 'A'))) diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index f239bd37..81db1ec4 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + from django.core.cache import cache from django.core.exceptions import FieldError from django.db import models @@ -7,7 +9,7 @@ from django.test import TestCase from model_utils import FieldTracker -from model_utils.tracker import DescriptorWrapper +from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker from tests.models import ( InheritedModelTracked, InheritedTracked, @@ -26,12 +28,18 @@ TrackerTimeStamped, ) +if TYPE_CHECKING: + MixinBase = TestCase +else: + MixinBase = object + -class FieldTrackerTestCase(TestCase): +class FieldTrackerMixin(MixinBase): - tracker = None + tracker: FieldInstanceTracker + instance: models.Model - def assertHasChanged(self, *, tracker=None, **kwargs): + def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker for field, value in kwargs.items(): @@ -41,49 +49,57 @@ def assertHasChanged(self, *, tracker=None, **kwargs): else: self.assertEqual(tracker.has_changed(field), value) - def assertPrevious(self, *, tracker=None, **kwargs): + def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker for field, value in kwargs.items(): self.assertEqual(tracker.previous(field), value) - def assertChanged(self, *, tracker=None, **kwargs): + def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker self.assertEqual(tracker.changed(), kwargs) - def assertCurrent(self, *, tracker=None, **kwargs): + def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker self.assertEqual(tracker.current(), kwargs) - def update_instance(self, **kwargs): + def update_instance(self, **kwargs: Any) -> None: for field, value in kwargs.items(): setattr(self.instance, field, value) self.instance.save() -class FieldTrackerCommonTests: +class FieldTrackerCommonMixin(FieldTrackerMixin): + + instance: ( + Tracked | TrackedNotDefault | TrackedMultiple + | ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple + | TrackedAbstract + ) - def test_pre_save_previous(self): + def test_pre_save_previous(self) -> None: self.assertPrevious(name=None, number=None) self.instance.name = 'new age' self.instance.number = 8 self.assertPrevious(name=None, number=None) -class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): +class FieldTrackerTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = Tracked + tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked + instance: Tracked | ModelTracked | TrackedAbstract - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.tracker - def test_descriptor(self): - self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) + def test_descriptor(self) -> None: + tracker = self.tracked_class.tracker + self.assertTrue(isinstance(tracker, FieldTracker)) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged(name=None) self.instance.name = 'new age' self.assertChanged(name=None) @@ -94,7 +110,7 @@ def test_pre_save_changed(self): self.instance.mutable = [1, 2, 3] self.assertChanged(name=None, number=None, mutable=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=False, mutable=False) self.instance.name = 'new age' self.assertHasChanged(name=True, number=False, mutable=False) @@ -103,12 +119,12 @@ def test_pre_save_has_changed(self): self.instance.mutable = [1, 2, 3] self.assertHasChanged(name=True, number=True, mutable=True) - def test_save_with_args(self): + def test_save_with_args(self) -> None: self.instance.number = 1 self.instance.save(False, False, None, None) self.assertChanged() - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=False, mutable=False) self.assertPrevious(name=None, number=None, mutable=None) self.assertCurrent(name='', number=None, id=None, mutable=None) @@ -129,7 +145,7 @@ def test_first_save(self): with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertHasChanged(name=False, number=False, mutable=False) self.instance.name = 'new age' @@ -141,14 +157,14 @@ def test_post_save_has_changed(self): self.instance.name = 'retro' self.assertHasChanged(name=False, number=True, mutable=True) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.instance.name = 'new age' self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) self.instance.mutable[1] = 4 self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertChanged() self.instance.name = 'new age' @@ -162,7 +178,7 @@ def test_post_save_changed(self): self.instance.mutable = [1, 2, 3] self.assertChanged(number=4) - def test_current(self): + def test_current(self) -> None: self.assertCurrent(id=None, name='', number=None, mutable=None) self.instance.name = 'new age' self.assertCurrent(id=None, name='new age', number=None, mutable=None) @@ -175,7 +191,7 @@ def test_current(self): self.instance.save() self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3]) - def test_update_fields(self): + def test_update_fields(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertChanged() self.instance.name = 'new age' @@ -198,7 +214,7 @@ def test_update_fields(self): self.assertEqual(in_db.number, self.instance.number) self.assertEqual(in_db.mutable, self.instance.mutable) - def test_refresh_from_db(self): + def test_refresh_from_db(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.tracked_class.objects.filter(pk=self.instance.pk).update( name='new age', number=8, mutable=[3, 2, 1]) @@ -214,11 +230,12 @@ def test_refresh_from_db(self): self.instance.refresh_from_db() self.assertChanged() - def test_with_deferred(self): + def test_with_deferred(self) -> None: self.instance.name = 'new age' self.instance.number = 1 self.instance.save() item = self.tracked_class.objects.only('name').first() + assert item is not None self.assertTrue(item.get_deferred_fields()) # has_changed() returns False for deferred fields, without un-deferring them. @@ -234,6 +251,7 @@ def test_with_deferred(self): # examining a deferred field un-defers it item = self.tracked_class.objects.only('name').first() + assert item is not None self.assertEqual(item.number, 1) self.assertTrue('number' not in item.get_deferred_fields()) self.assertEqual(item.tracker.previous('number'), 1) @@ -252,6 +270,7 @@ def test_with_deferred(self): if self.tracked_class == Tracked: item = self.tracked_class.objects.only('name').first() + assert item is not None item.number = 2 # previous() fetches correct value from database after deferred field is assigned @@ -268,7 +287,7 @@ def test_with_deferred(self): class FieldTrackerMultipleInstancesTests(TestCase): - def test_with_deferred_fields_access_multiple(self): + def test_with_deferred_fields_access_multiple(self) -> None: Tracked.objects.create(pk=1, name='foo', number=1) Tracked.objects.create(pk=2, name='bar', number=2) @@ -278,16 +297,16 @@ def test_with_deferred_fields_access_multiple(self): instance.name -class FieldTrackedModelCustomTests(FieldTrackerTestCase, - FieldTrackerCommonTests): +class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = TrackedNotDefault + tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault + instance: TrackedNotDefault | ModelTrackedNotDefault - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.name_tracker - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged(name=None) self.instance.name = 'new age' self.assertChanged(name=None) @@ -296,7 +315,7 @@ def test_pre_save_changed(self): self.instance.name = '' self.assertChanged(name=None) - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=None) self.assertPrevious(name=None, number=None) self.assertCurrent(name='') @@ -308,14 +327,14 @@ def test_first_save(self): self.assertCurrent(name='retro') self.assertChanged(name=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=None) self.instance.name = 'new age' self.assertHasChanged(name=True, number=None) self.instance.number = 7 self.assertHasChanged(name=True, number=None) - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertHasChanged(name=False, number=None) self.instance.name = 'new age' @@ -325,12 +344,12 @@ def test_post_save_has_changed(self): self.instance.name = 'retro' self.assertHasChanged(name=False, number=None) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(name='retro', number=4) self.instance.name = 'new age' self.assertPrevious(name='retro', number=None) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertChanged() self.instance.name = 'new age' @@ -340,7 +359,7 @@ def test_post_save_changed(self): self.instance.name = 'retro' self.assertChanged() - def test_current(self): + def test_current(self) -> None: self.assertCurrent(name='') self.instance.name = 'new age' self.assertCurrent(name='new age') @@ -349,7 +368,7 @@ def test_current(self): self.instance.save() self.assertCurrent(name='new age') - def test_update_fields(self): + def test_update_fields(self) -> None: self.update_instance(name='retro', number=4) self.assertChanged() self.instance.name = 'new age' @@ -358,15 +377,16 @@ def test_update_fields(self): self.assertChanged() -class FieldTrackedModelAttributeTests(FieldTrackerTestCase): +class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase): tracked_class = TrackedNonFieldAttr + instance: TrackedNonFieldAttr - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.tracker - def test_previous(self): + def test_previous(self) -> None: self.assertPrevious(rounded=None) self.instance.number = 7.5 self.assertPrevious(rounded=None) @@ -377,7 +397,7 @@ def test_previous(self): self.instance.save() self.assertPrevious(rounded=7) - def test_has_changed(self): + def test_has_changed(self) -> None: self.assertHasChanged(rounded=False) self.instance.number = 7.5 self.assertHasChanged(rounded=True) @@ -388,7 +408,7 @@ def test_has_changed(self): self.instance.number = 7.8 self.assertHasChanged(rounded=False) - def test_changed(self): + def test_changed(self) -> None: self.assertChanged() self.instance.number = 7.5 self.assertPrevious(rounded=None) @@ -401,7 +421,7 @@ def test_changed(self): self.instance.save() self.assertPrevious() - def test_current(self): + def test_current(self) -> None: self.assertCurrent(rounded=None) self.instance.number = 7.5 self.assertCurrent(rounded=8) @@ -409,17 +429,17 @@ def test_current(self): self.assertCurrent(rounded=8) -class FieldTrackedModelMultiTests(FieldTrackerTestCase, - FieldTrackerCommonTests): +class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = TrackedMultiple + tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple + instance: TrackedMultiple | ModelTrackedMultiple - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.trackers = [self.instance.name_tracker, self.instance.number_tracker] - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertChanged(name=None) self.instance.name = 'new age' @@ -435,7 +455,7 @@ def test_pre_save_changed(self): self.instance.number = 8 self.assertChanged(number=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertHasChanged(name=True, number=None) self.instance.name = 'new age' @@ -445,12 +465,12 @@ def test_pre_save_has_changed(self): self.instance.name = 'new age' self.assertHasChanged(name=None, number=False) - def test_pre_save_previous(self): + def test_pre_save_previous(self) -> None: for tracker in self.trackers: self.tracker = tracker super().test_pre_save_previous() - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) @@ -465,14 +485,14 @@ def test_post_save_has_changed(self): self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(name='retro', number=4) self.instance.name = 'new age' self.instance.number = 8 self.assertPrevious(tracker=self.trackers[0], name='retro', number=None) self.assertPrevious(tracker=self.trackers[1], name=None, number=4) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertChanged(tracker=self.trackers[0]) self.assertChanged(tracker=self.trackers[1]) @@ -487,7 +507,7 @@ def test_post_save_changed(self): self.assertChanged(tracker=self.trackers[0]) self.assertChanged(tracker=self.trackers[1]) - def test_current(self): + def test_current(self) -> None: self.assertCurrent(tracker=self.trackers[0], name='') self.assertCurrent(tracker=self.trackers[1], number=None) self.instance.name = 'new age' @@ -501,88 +521,97 @@ def test_current(self): self.assertCurrent(tracker=self.trackers[1], number=8) -class FieldTrackerForeignKeyTests(FieldTrackerTestCase): +class FieldTrackerForeignKeyMixin(FieldTrackerMixin): - fk_class: type[models.Model] = Tracked - tracked_class: type[models.Model] = TrackedFK + fk_class: type[Tracked | ModelTracked] + tracked_class: type[TrackedFK | ModelTrackedFK] + instance: TrackedFK | ModelTrackedFK - def setUp(self): + def setUp(self) -> None: self.old_fk = self.fk_class.objects.create(number=8) - self.instance = self.tracked_class.objects.create(fk=self.old_fk) + self.instance = self.tracked_class.objects.create(fk=self.old_fk) # type: ignore[misc] - def test_default(self): + def test_default(self) -> None: self.tracker = self.instance.tracker self.assertChanged() self.assertPrevious() self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment] self.assertChanged(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) - def test_custom(self): + def test_custom(self) -> None: self.tracker = self.instance.custom_tracker self.assertChanged() self.assertPrevious() self.assertCurrent(fk_id=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment] self.assertChanged(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(fk_id=self.instance.fk_id) - def test_custom_without_id(self): + def test_custom_without_id(self) -> None: with self.assertNumQueries(1): self.tracked_class.objects.get() self.tracker = self.instance.custom_tracker_without_id self.assertChanged() self.assertPrevious() self.assertCurrent(fk=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment] self.assertChanged(fk=self.old_fk.id) self.assertPrevious(fk=self.old_fk.id) self.assertCurrent(fk=self.instance.fk_id) -class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase): +class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): + + fk_class = Tracked + tracked_class = TrackedFK + + +class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase): """Test that using `prefetch_related` on a tracked field does not raise a ValueError.""" fk_class = Tracked tracked_class = TrackedFK + instance: TrackedFK - def setUp(self): + def setUp(self) -> None: model_tracked = self.fk_class.objects.create(name="", number=0) self.instance = self.tracked_class.objects.create(fk=model_tracked) - def test_default(self): + def test_default(self) -> None: self.tracker = self.instance.tracker self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) - def test_custom(self): + def test_custom(self) -> None: self.tracker = self.instance.custom_tracker self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) - def test_custom_without_id(self): + def test_custom_without_id(self) -> None: self.tracker = self.instance.custom_tracker_without_id self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) -class FieldTrackerTimeStampedTests(FieldTrackerTestCase): +class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase): fk_class = Tracked tracked_class = TrackerTimeStamped + instance: TrackerTimeStamped - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class.objects.create(name='old', number=1) self.tracker = self.instance.tracker - def test_set_modified_on_save(self): + def test_set_modified_on_save(self) -> None: old_modified = self.instance.modified self.instance.name = 'new' self.instance.save() self.assertGreater(self.instance.modified, old_modified) self.assertChanged() - def test_set_modified_on_save_update_fields(self): + def test_set_modified_on_save_update_fields(self) -> None: old_modified = self.instance.modified self.instance.name = 'new' self.instance.save(update_fields=('name',)) @@ -594,7 +623,7 @@ class InheritedFieldTrackerTests(FieldTrackerTests): tracked_class = InheritedTracked - def test_child_fields_not_tracked(self): + def test_child_fields_not_tracked(self) -> None: self.name2 = 'test' self.assertEqual(self.tracker.previous('name2'), None) self.assertRaises(FieldError, self.tracker.has_changed, 'name2') @@ -605,17 +634,18 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests): tracked_class = InheritedTrackedFK -class FieldTrackerFileFieldTests(FieldTrackerTestCase): +class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase): tracked_class = TrackedFileField + instance: TrackedFileField - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.tracker self.some_file = 'something.txt' self.another_file = 'another.txt' - def test_saved_data_without_instance(self): + def test_saved_data_without_instance(self) -> None: """ Tests that instance won't get copied by the Field Tracker. @@ -629,27 +659,27 @@ def test_saved_data_without_instance(self): self.assertEqual(self.tracker.saved_data, {}) self.update_instance(some_file=self.some_file) field_file_copy = self.tracker.saved_data.get('some_file') - self.assertIsNotNone(field_file_copy) + assert field_file_copy is not None self.assertEqual(field_file_copy.__getstate__().get('instance'), None) self.assertEqual(self.instance.some_file.instance, self.instance) self.assertIsInstance(self.instance.some_file, FieldFile) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged(some_file=None) self.instance.some_file = self.some_file self.assertChanged(some_file=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(some_file=True) self.instance.some_file = self.some_file self.assertHasChanged(some_file=True) - def test_pre_save_previous(self): + def test_pre_save_previous(self) -> None: self.assertPrevious(some_file=None) self.instance.some_file = self.some_file self.assertPrevious(some_file=None) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(some_file=self.some_file) self.assertChanged() previous_file = self.instance.some_file @@ -667,7 +697,7 @@ def test_post_save_changed(self): some_file=previous_file, ) - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(some_file=self.some_file) self.assertHasChanged(some_file=False) self.instance.some_file = self.another_file @@ -687,7 +717,7 @@ def test_post_save_has_changed(self): some_file=True, ) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(some_file=self.some_file) previous_file = self.instance.some_file self.instance.some_file = self.another_file @@ -707,7 +737,7 @@ def test_post_save_previous(self): some_file=previous_file, ) - def test_current(self): + def test_current(self) -> None: self.assertCurrent(some_file=self.instance.some_file, id=None) self.instance.some_file = self.some_file self.assertCurrent(some_file=self.instance.some_file, id=None) @@ -730,9 +760,10 @@ def test_current(self): class ModelTrackerTests(FieldTrackerTests): - tracked_class: type[models.Model] = ModelTracked + tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked + instance: ModelTracked - def test_cache_compatible(self): + def test_cache_compatible(self) -> None: cache.set('key', self.instance) instance = cache.get('key') instance.number = 1 @@ -742,7 +773,7 @@ def test_cache_compatible(self): instance.number = 2 self.assertHasChanged(number=True) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged() self.instance.name = 'new age' self.assertChanged() @@ -753,7 +784,7 @@ def test_pre_save_changed(self): self.instance.mutable = [1, 2, 3] self.assertChanged() - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=True, mutable=True) self.assertPrevious(name=None, number=None, mutable=None) self.assertCurrent(name='', number=None, id=None, mutable=None) @@ -774,7 +805,7 @@ def test_first_save(self): with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=True) self.instance.name = 'new age' self.assertHasChanged(name=True, number=True) @@ -786,7 +817,7 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests): tracked_class = ModelTrackedNotDefault - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=True) self.assertPrevious(name=None, number=None) self.assertCurrent(name='') @@ -798,14 +829,14 @@ def test_first_save(self): self.assertCurrent(name='retro') self.assertChanged() - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=True) self.instance.name = 'new age' self.assertHasChanged(name=True, number=True) self.instance.number = 7 self.assertHasChanged(name=True, number=True) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged() self.instance.name = 'new age' self.assertChanged() @@ -819,7 +850,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests): tracked_class = ModelTrackedMultiple - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertHasChanged(name=True, number=True) self.instance.name = 'new age' @@ -829,7 +860,7 @@ def test_pre_save_has_changed(self): self.instance.name = 'new age' self.assertHasChanged(name=True, number=True) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertChanged() self.instance.name = 'new age' @@ -846,12 +877,13 @@ def test_pre_save_changed(self): self.assertChanged() -class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): +class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): fk_class = ModelTracked tracked_class = ModelTrackedFK + instance: ModelTrackedFK - def test_custom_without_id(self): + def test_custom_without_id(self) -> None: with self.assertNumQueries(2): self.tracked_class.objects.get() self.tracker = self.instance.custom_tracker_without_id @@ -869,7 +901,7 @@ class InheritedModelTrackerTests(ModelTrackerTests): tracked_class = InheritedModelTracked - def test_child_fields_not_tracked(self): + def test_child_fields_not_tracked(self) -> None: self.name2 = 'test' self.assertEqual(self.tracker.previous('name2'), None) self.assertTrue(self.tracker.has_changed('name2')) @@ -882,19 +914,19 @@ class AbstractModelTrackerTests(ModelTrackerTests): class TrackerContextDecoratorTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.instance = Tracked.objects.create(number=1) self.tracker = self.instance.tracker - def assertChanged(self, *fields): + def assertChanged(self, *fields: str) -> None: for f in fields: self.assertTrue(self.tracker.has_changed(f)) - def assertNotChanged(self, *fields): + def assertNotChanged(self, *fields: str) -> None: for f in fields: self.assertFalse(self.tracker.has_changed(f)) - def test_context_manager(self): + def test_context_manager(self) -> None: with self.tracker: with self.tracker: self.instance.name = 'new' @@ -905,7 +937,7 @@ def test_context_manager(self): self.assertNotChanged('name') - def test_context_manager_fields(self): + def test_context_manager_fields(self) -> None: with self.tracker('number'): with self.tracker('number', 'name'): self.instance.name = 'new' @@ -918,10 +950,10 @@ def test_context_manager_fields(self): self.assertNotChanged('number', 'name') - def test_tracker_decorator(self): + def test_tracker_decorator(self) -> None: @Tracked.tracker - def tracked_method(obj): + def tracked_method(obj: Tracked) -> None: obj.name = 'new' self.assertChanged('name') @@ -929,10 +961,10 @@ def tracked_method(obj): self.assertNotChanged('name') - def test_tracker_decorator_fields(self): + def test_tracker_decorator_fields(self) -> None: @Tracked.tracker(fields=['name']) - def tracked_method(obj): + def tracked_method(obj: Tracked) -> None: obj.name = 'new' obj.number += 1 self.assertChanged('name', 'number') @@ -942,7 +974,7 @@ def tracked_method(obj): self.assertChanged('number') self.assertNotChanged('name') - def test_tracker_context_with_save(self): + def test_tracker_context_with_save(self) -> None: with self.tracker: self.instance.name = 'new' diff --git a/tests/test_fields/test_monitor_field.py b/tests/test_fields/test_monitor_field.py index f0041368..19ed9027 100644 --- a/tests/test_fields/test_monitor_field.py +++ b/tests/test_fields/test_monitor_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import time_machine @@ -8,33 +10,33 @@ class MonitorFieldTests(TestCase): - def setUp(self): + def setUp(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)): self.instance = Monitored(name='Charlie') self.created = self.instance.name_changed - def test_save_no_change(self): + def test_save_no_change(self) -> None: self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed(self): + def test_save_changed(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): self.instance.name = 'Maria' self.instance.save() self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) - def test_double_save(self): + def test_double_save(self) -> None: self.instance.name = 'Jose' self.instance.save() changed = self.instance.name_changed self.instance.save() self.assertEqual(self.instance.name_changed, changed) - def test_no_monitor_arg(self): + def test_no_monitor_arg(self) -> None: with self.assertRaises(TypeError): - MonitorField() + MonitorField() # type: ignore[call-arg] - def test_monitor_default_is_none_when_nullable(self): + def test_monitor_default_is_none_when_nullable(self) -> None: self.assertIsNone(self.instance.name_changed_nullable) expected_datetime = datetime(2022, 1, 18, 12, 0, 0, tzinfo=timezone.utc) @@ -49,33 +51,33 @@ class MonitorWhenFieldTests(TestCase): """ Will record changes only when name is 'Jose' or 'Maria' """ - def setUp(self): + def setUp(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)): self.instance = MonitorWhen(name='Charlie') self.created = self.instance.name_changed - def test_save_no_change(self): + def test_save_no_change(self) -> None: self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed_to_Jose(self): + def test_save_changed_to_Jose(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): self.instance.name = 'Jose' self.instance.save() self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) - def test_save_changed_to_Maria(self): + def test_save_changed_to_Maria(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): self.instance.name = 'Maria' self.instance.save() self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) - def test_save_changed_to_Pedro(self): + def test_save_changed_to_Pedro(self) -> None: self.instance.name = 'Pedro' self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_double_save(self): + def test_double_save(self) -> None: self.instance.name = 'Jose' self.instance.save() changed = self.instance.name_changed @@ -87,20 +89,20 @@ class MonitorWhenEmptyFieldTests(TestCase): """ Monitor should never be updated id when is an empty list. """ - def setUp(self): + def setUp(self) -> None: self.instance = MonitorWhenEmpty(name='Charlie') self.created = self.instance.name_changed - def test_save_no_change(self): + def test_save_no_change(self) -> None: self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed_to_Jose(self): + def test_save_changed_to_Jose(self) -> None: self.instance.name = 'Jose' self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed_to_Maria(self): + def test_save_changed_to_Maria(self) -> None: self.instance.name = 'Maria' self.instance.save() self.assertEqual(self.instance.name_changed, self.created) @@ -108,18 +110,18 @@ def test_save_changed_to_Maria(self): class MonitorDoubleFieldTests(TestCase): - def setUp(self): + def setUp(self) -> None: DoubleMonitored.objects.create(name='Charlie', name2='Charlie2') - def test_recursion_error_with_only(self): + def test_recursion_error_with_only(self) -> None: # Any field passed to only() is generating a recursion error list(DoubleMonitored.objects.only('id')) - def test_recursion_error_with_defer(self): + def test_recursion_error_with_defer(self) -> None: # Only monitored fields passed to defer() are failing list(DoubleMonitored.objects.defer('name')) - def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self): + def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self) -> None: obj = DoubleMonitored.objects.defer('name').get(name='Charlie') with time_machine.travel(datetime(2016, 12, 1, tzinfo=timezone.utc)): obj.name = 'Charlie2' diff --git a/tests/test_fields/test_split_field.py b/tests/test_fields/test_split_field.py index 6028f895..bc94f48a 100644 --- a/tests/test_fields/test_split_field.py +++ b/tests/test_fields/test_split_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import Article, SplitFieldAbstractParent @@ -7,62 +9,62 @@ class SplitFieldTests(TestCase): full_text = 'summary\n\n\n\nmore' excerpt = 'summary\n' - def setUp(self): + def setUp(self) -> None: self.post = Article.objects.create( title='example post', body=self.full_text) - def test_unicode_content(self): + def test_unicode_content(self) -> None: self.assertEqual(str(self.post.body), self.full_text) - def test_excerpt(self): + def test_excerpt(self) -> None: self.assertEqual(self.post.body.excerpt, self.excerpt) - def test_content(self): + def test_content(self) -> None: self.assertEqual(self.post.body.content, self.full_text) - def test_has_more(self): + def test_has_more(self) -> None: self.assertTrue(self.post.body.has_more) - def test_not_has_more(self): + def test_not_has_more(self) -> None: post = Article.objects.create(title='example 2', body='some text\n\nsome more\n') self.assertFalse(post.body.has_more) - def test_load_back(self): + def test_load_back(self) -> None: post = Article.objects.get(pk=self.post.pk) self.assertEqual(post.body.content, self.post.body.content) self.assertEqual(post.body.excerpt, self.post.body.excerpt) - def test_assign_to_body(self): + def test_assign_to_body(self) -> None: new_text = 'different\n\n\n\nother' self.post.body = new_text self.post.save() self.assertEqual(str(self.post.body), new_text) - def test_assign_to_content(self): + def test_assign_to_content(self) -> None: new_text = 'different\n\n\n\nother' self.post.body.content = new_text self.post.save() self.assertEqual(str(self.post.body), new_text) - def test_assign_to_excerpt(self): + def test_assign_to_excerpt(self) -> None: with self.assertRaises(AttributeError): - self.post.body.excerpt = 'this should fail' + self.post.body.excerpt = 'this should fail' # type: ignore[misc] - def test_access_via_class(self): + def test_access_via_class(self) -> None: with self.assertRaises(AttributeError): Article.body - def test_assign_splittext(self): + def test_assign_splittext(self) -> None: a = Article(title='Some Title') a.body = self.post.body self.assertEqual(a.body.excerpt, 'summary\n') - def test_value_to_string(self): + def test_value_to_string(self) -> None: f = self.post._meta.get_field('body') self.assertEqual(f.value_to_string(self.post), self.full_text) - def test_abstract_inheritance(self): + def test_abstract_inheritance(self) -> None: class Child(SplitFieldAbstractParent): pass diff --git a/tests/test_fields/test_status_field.py b/tests/test_fields/test_status_field.py index ab250c6d..fe79e11f 100644 --- a/tests/test_fields/test_status_field.py +++ b/tests/test_fields/test_status_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from model_utils.fields import StatusField @@ -11,22 +13,22 @@ class StatusFieldTests(TestCase): - def test_status_with_default_filled(self): + def test_status_with_default_filled(self) -> None: instance = StatusFieldDefaultFilled() self.assertEqual(instance.status, instance.STATUS.yes) - def test_status_with_default_not_filled(self): + def test_status_with_default_not_filled(self) -> None: instance = StatusFieldDefaultNotFilled() self.assertEqual(instance.status, instance.STATUS.no) - def test_no_check_for_status(self): + def test_no_check_for_status(self) -> None: field = StatusField(no_check_for_status=True) # this model has no STATUS attribute, so checking for it would error field.prepare_class(Article) - def test_get_status_display(self): + def test_get_status_display(self) -> None: instance = StatusFieldDefaultFilled() self.assertEqual(instance.get_status_display(), "Yes") - def test_choices_name(self): + def test_choices_name(self) -> None: StatusFieldChoicesName() diff --git a/tests/test_fields/test_urlsafe_token_field.py b/tests/test_fields/test_urlsafe_token_field.py index 66aeb29d..72bbcda8 100644 --- a/tests/test_fields/test_urlsafe_token_field.py +++ b/tests/test_fields/test_urlsafe_token_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock from django.db.models import NOT_PROVIDED @@ -7,41 +9,41 @@ class UrlsaftTokenFieldTests(TestCase): - def test_editable_default(self): + def test_editable_default(self) -> None: field = UrlsafeTokenField() self.assertFalse(field.editable) - def test_editable(self): + def test_editable(self) -> None: field = UrlsafeTokenField(editable=True) self.assertTrue(field.editable) - def test_max_length_default(self): + def test_max_length_default(self) -> None: field = UrlsafeTokenField() self.assertEqual(field.max_length, 128) - def test_max_length(self): + def test_max_length(self) -> None: field = UrlsafeTokenField(max_length=256) self.assertEqual(field.max_length, 256) - def test_factory_default(self): + def test_factory_default(self) -> None: field = UrlsafeTokenField() self.assertIsNone(field._factory) - def test_factory_not_callable(self): + def test_factory_not_callable(self) -> None: with self.assertRaises(TypeError): - UrlsafeTokenField(factory='INVALID') + UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type] - def test_get_default(self): + def test_get_default(self) -> None: field = UrlsafeTokenField() value = field.get_default() self.assertEqual(len(value), field.max_length) - def test_get_default_with_non_default_max_length(self): + def test_get_default_with_non_default_max_length(self) -> None: field = UrlsafeTokenField(max_length=64) value = field.get_default() self.assertEqual(len(value), 64) - def test_get_default_with_factory(self): + def test_get_default_with_factory(self) -> None: token = 'SAMPLE_TOKEN' factory = Mock(return_value=token) field = UrlsafeTokenField(factory=factory) @@ -50,13 +52,13 @@ def test_get_default_with_factory(self): self.assertEqual(value, token) factory.assert_called_once_with(field.max_length) - def test_no_default_param(self): + def test_no_default_param(self) -> None: field = UrlsafeTokenField(default='DEFAULT') self.assertIs(field.default, NOT_PROVIDED) - def test_deconstruct(self): - def test_factory(): - pass + def test_deconstruct(self) -> None: + def test_factory(max_length: int) -> str: + assert False instance = UrlsafeTokenField(factory=test_factory) name, path, args, kwargs = instance.deconstruct() new_instance = UrlsafeTokenField(*args, **kwargs) diff --git a/tests/test_fields/test_uuid_field.py b/tests/test_fields/test_uuid_field.py index cc354f09..4c77aaad 100644 --- a/tests/test_fields/test_uuid_field.py +++ b/tests/test_fields/test_uuid_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import uuid from django.core.exceptions import ValidationError @@ -8,31 +10,31 @@ class UUIDFieldTests(TestCase): - def test_uuid_version_default(self): + def test_uuid_version_default(self) -> None: instance = UUIDField() self.assertEqual(instance.default, uuid.uuid4) - def test_uuid_version_1(self): + def test_uuid_version_1(self) -> None: instance = UUIDField(version=1) self.assertEqual(instance.default, uuid.uuid1) - def test_uuid_version_2_error(self): + def test_uuid_version_2_error(self) -> None: self.assertRaises(ValidationError, UUIDField, 'version', 2) - def test_uuid_version_3(self): + def test_uuid_version_3(self) -> None: instance = UUIDField(version=3) self.assertEqual(instance.default, uuid.uuid3) - def test_uuid_version_4(self): + def test_uuid_version_4(self) -> None: instance = UUIDField(version=4) self.assertEqual(instance.default, uuid.uuid4) - def test_uuid_version_5(self): + def test_uuid_version_5(self) -> None: instance = UUIDField(version=5) self.assertEqual(instance.default, uuid.uuid5) - def test_uuid_version_bellow_min(self): + def test_uuid_version_bellow_min(self) -> None: self.assertRaises(ValidationError, UUIDField, 'version', 0) - def test_uuid_version_above_max(self): + def test_uuid_version_above_max(self) -> None: self.assertRaises(ValidationError, UUIDField, 'version', 6) diff --git a/tests/test_inheritance_iterable.py b/tests/test_inheritance_iterable.py index f4202e53..e9896c56 100644 --- a/tests/test_inheritance_iterable.py +++ b/tests/test_inheritance_iterable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.db.models import Prefetch from django.test import TestCase @@ -5,7 +7,7 @@ class InheritanceIterableTest(TestCase): - def test_prefetch(self): + def test_prefetch(self) -> None: qs = InheritanceManagerTestChild1.objects.all().prefetch_related( Prefetch( 'normal_field', diff --git a/tests/test_managers/test_inheritance_manager.py b/tests/test_managers/test_inheritance_manager.py index 60987bfe..68e8a743 100644 --- a/tests/test_managers/test_inheritance_manager.py +++ b/tests/test_managers/test_inheritance_manager.py @@ -1,6 +1,11 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from django.db import models from django.test import TestCase +from model_utils.managers import InheritanceManager from tests.models import ( InheritanceManagerTestChild1, InheritanceManagerTestChild2, @@ -14,19 +19,22 @@ TimeFrame, ) +if TYPE_CHECKING: + from django.db.models.fields.related_descriptors import RelatedManager + class InheritanceManagerTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.child1 = InheritanceManagerTestChild1.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create() self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() self.grandchild1_2 = \ InheritanceManagerTestGrandChild1_2.objects.create() - def get_manager(self): + def get_manager(self) -> InheritanceManager[InheritanceManagerTestParent]: return InheritanceManagerTestParent.objects - def test_normal(self): + def test_normal(self) -> None: children = { InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child2.pk), @@ -35,14 +43,14 @@ def test_normal(self): } self.assertEqual(set(self.get_manager().all()), children) - def test_select_all_subclasses(self): + def test_select_all_subclasses(self) -> None: children = {self.child1, self.child2} children.add(self.grandchild1) children.add(self.grandchild1_2) self.assertEqual( set(self.get_manager().select_subclasses()), children) - def test_select_subclasses_invalid_relation(self): + def test_select_subclasses_invalid_relation(self) -> None: """ If an invalid relation string is provided, we can provide the user with a list which is valid, rather than just have the select_related() @@ -52,7 +60,7 @@ def test_select_subclasses_invalid_relation(self): with self.assertRaisesRegex(ValueError, regex): self.get_manager().select_subclasses('user') - def test_select_specific_subclasses(self): + def test_select_specific_subclasses(self) -> None: children = { self.child1, InheritanceManagerTestParent(pk=self.child2.pk), @@ -67,7 +75,7 @@ def test_select_specific_subclasses(self): children, ) - def test_select_specific_grandchildren(self): + def test_select_specific_grandchildren(self) -> None: children = { InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child2.pk), @@ -83,7 +91,7 @@ def test_select_specific_grandchildren(self): children, ) - def test_children_and_grandchildren(self): + def test_children_and_grandchildren(self) -> None: children = { self.child1, InheritanceManagerTestParent(pk=self.child2.pk), @@ -100,24 +108,24 @@ def test_children_and_grandchildren(self): children, ) - def test_get_subclass(self): + def test_get_subclass(self) -> None: self.assertEqual( self.get_manager().get_subclass(pk=self.child1.pk), self.child1) - def test_get_subclass_on_queryset(self): + def test_get_subclass_on_queryset(self) -> None: self.assertEqual( self.get_manager().all().get_subclass(pk=self.child1.pk), self.child1) - def test_prior_select_related(self): + def test_prior_select_related(self) -> None: with self.assertNumQueries(1): obj = self.get_manager().select_related( "inheritancemanagertestchild1").select_subclasses( "inheritancemanagertestchild2").get(pk=self.child1.pk) obj.inheritancemanagertestchild1 - def test_manually_specifying_parent_fk_including_grandchildren(self): + def test_manually_specifying_parent_fk_including_grandchildren(self) -> None: """ given a Model which inherits from another Model, but also declares the OneToOne link manually using `related_name` and `parent_link`, @@ -148,7 +156,7 @@ def test_manually_specifying_parent_fk_including_grandchildren(self): self.assertEqual(set(results.subclasses), set(expected_related_names)) - def test_manually_specifying_parent_fk_single_subclass(self): + def test_manually_specifying_parent_fk_single_subclass(self) -> None: """ Using a string related_name when the relation is manually defined instead of implicit should still work in the same way. @@ -168,11 +176,11 @@ def test_manually_specifying_parent_fk_single_subclass(self): self.assertEqual(set(results.subclasses), set(expected_related_names)) - def test_filter_on_values_queryset(self): + def test_filter_on_values_queryset(self) -> None: queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk) self.assertEqual(list(queryset), [{'id': self.child1.pk}]) - def test_values_list_on_select_subclasses(self): + def test_values_list_on_select_subclasses(self) -> None: """ Using `select_subclasses` in conjunction with `values_list()` raised an exception in `_get_sub_obj_recurse()` because the result of `values_list()` @@ -217,14 +225,14 @@ def test_values_list_on_select_subclasses(self): class InheritanceManagerUsingModelsTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.parent1 = InheritanceManagerTestParent.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create() self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create() - def test_select_subclass_by_child_model(self): + def test_select_subclass_by_child_model(self) -> None: """ Confirm that passing a child model works the same as passing the select_related manually @@ -236,7 +244,7 @@ def test_select_subclass_by_child_model(self): self.assertEqual(objs.subclasses, objsmodels.subclasses) self.assertEqual(list(objs), list(objsmodels)) - def test_select_subclass_by_grandchild_model(self): + def test_select_subclass_by_grandchild_model(self) -> None: """ Confirm that passing a grandchild model works the same as passing the select_related manually @@ -249,7 +257,7 @@ def test_select_subclass_by_grandchild_model(self): self.assertEqual(objs.subclasses, objsmodels.subclasses) self.assertEqual(list(objs), list(objsmodels)) - def test_selecting_all_subclasses_specifically_grandchildren(self): + def test_selecting_all_subclasses_specifically_grandchildren(self) -> None: """ A bare select_subclasses() should achieve the same results as doing select_subclasses and specifying all possible subclasses. @@ -266,7 +274,7 @@ def test_selecting_all_subclasses_specifically_grandchildren(self): self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) self.assertEqual(list(objs), list(objsmodels)) - def test_selecting_all_subclasses_specifically_children(self): + def test_selecting_all_subclasses_specifically_children(self) -> None: """ A bare select_subclasses() should achieve the same results as doing select_subclasses and specifying all possible subclasses. @@ -294,7 +302,7 @@ def test_selecting_all_subclasses_specifically_children(self): self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) self.assertEqual(list(objs), list(objsmodels)) - def test_select_subclass_just_self(self): + def test_select_subclass_just_self(self) -> None: """ Passing in the same model as the manager/queryset is bound against (ie: the root parent) should have no effect on the result set. @@ -310,7 +318,7 @@ def test_select_subclass_just_self(self): InheritanceManagerTestParent(pk=self.grandchild1_2.pk), ]) - def test_select_subclass_invalid_related_model(self): + def test_select_subclass_invalid_related_model(self) -> None: """ Confirming that giving a stupid model doesn't work. """ @@ -319,7 +327,7 @@ def test_select_subclass_invalid_related_model(self): InheritanceManagerTestParent.objects.select_subclasses( TimeFrame).order_by('pk') - def test_mixing_strings_and_classes_with_grandchildren(self): + def test_mixing_strings_and_classes_with_grandchildren(self) -> None: """ Given arguments consisting of both strings and model classes, ensure the right resolutions take place, accounting for the extra @@ -340,7 +348,7 @@ def test_mixing_strings_and_classes_with_grandchildren(self): ] self.assertEqual(list(objs), expecting2) - def test_mixing_strings_and_classes_with_children(self): + def test_mixing_strings_and_classes_with_children(self) -> None: """ Given arguments consisting of both strings and model classes, ensure the right resolutions take place, walking down as far as @@ -362,7 +370,7 @@ def test_mixing_strings_and_classes_with_children(self): ] self.assertEqual(list(objs), expecting2) - def test_duplications(self): + def test_duplications(self) -> None: """ Check that even if the same thing is provided as a string and a model that the right results are retrieved. @@ -379,7 +387,7 @@ def test_duplications(self): InheritanceManagerTestParent(pk=self.grandchild1_2.pk), ]) - def test_child_doesnt_accidentally_get_parent(self): + def test_child_doesnt_accidentally_get_parent(self) -> None: """ Given a Child model which also has an InheritanceManager, none of the returned objects should be Parent objects. @@ -392,7 +400,7 @@ def test_child_doesnt_accidentally_get_parent(self): InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), ], list(objs)) - def test_manually_specifying_parent_fk_only_specific_child(self): + def test_manually_specifying_parent_fk_only_specific_child(self) -> None: """ given a Model which inherits from another Model, but also declares the OneToOne link manually using `related_name` and `parent_link`, @@ -416,7 +424,7 @@ def test_manually_specifying_parent_fk_only_specific_child(self): self.assertEqual(set(results.subclasses), set(expected_related_names)) - def test_extras_descend(self): + def test_extras_descend(self) -> None: """ Ensure that extra(select=) values are copied onto sub-classes. """ @@ -425,25 +433,25 @@ def test_extras_descend(self): ) self.assertTrue(all(result.foo == (result.id + 1) for result in results)) - def test_limit_to_specific_subclass(self): + def test_limit_to_specific_subclass(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3) self.assertEqual([child3], list(results)) - def test_limit_to_specific_subclass_with_custom_db_column(self): + def test_limit_to_specific_subclass_with_custom_db_column(self) -> None: item = InheritanceManagerTestChild3_1.objects.create() results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3_1) self.assertEqual([item], list(results)) - def test_limit_to_specific_grandchild_class(self): + def test_limit_to_specific_grandchild_class(self) -> None: grandchild1 = InheritanceManagerTestGrandChild1.objects.get() results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1) self.assertEqual([grandchild1], list(results)) - def test_limit_to_child_fetches_grandchildren_as_child_class(self): + def test_limit_to_child_fetches_grandchildren_as_child_class(self) -> None: # Not sure if this is the desired behaviour...? children = InheritanceManagerTestChild1.objects.all() @@ -451,7 +459,7 @@ def test_limit_to_child_fetches_grandchildren_as_child_class(self): self.assertEqual(set(children), set(results)) - def test_can_fetch_limited_class_grandchildren(self): + def test_can_fetch_limited_class_grandchildren(self) -> None: # Not sure if this is the desired behaviour...? children = InheritanceManagerTestChild1.objects.select_subclasses() @@ -459,7 +467,7 @@ def test_can_fetch_limited_class_grandchildren(self): self.assertEqual(set(children), set(results)) - def test_selecting_multiple_instance_classes(self): + def test_selecting_multiple_instance_classes(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() children1 = InheritanceManagerTestChild1.objects.all() @@ -467,7 +475,7 @@ def test_selecting_multiple_instance_classes(self): self.assertEqual(set([child3] + list(children1)), set(results)) - def test_selecting_multiple_instance_classes_including_grandchildren(self): + def test_selecting_multiple_instance_classes_including_grandchildren(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() grandchild1 = InheritanceManagerTestGrandChild1.objects.get() @@ -475,7 +483,7 @@ def test_selecting_multiple_instance_classes_including_grandchildren(self): self.assertEqual({child3, grandchild1}, set(results)) - def test_select_subclasses_interaction_with_instance_of(self): + def test_select_subclasses_interaction_with_instance_of(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3) @@ -484,7 +492,7 @@ def test_select_subclasses_interaction_with_instance_of(self): class InheritanceManagerRelatedTests(InheritanceManagerTests): - def setUp(self): + def setUp(self) -> None: self.related = InheritanceManagerTestRelated.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create( related=self.related) @@ -493,16 +501,16 @@ def setUp(self): self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related) self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related) - def get_manager(self): + def get_manager(self) -> RelatedManager[InheritanceManagerTestParent]: # type: ignore[override] return self.related.imtests - def test_get_method_with_select_subclasses(self): + def test_get_method_with_select_subclasses(self) -> None: self.assertEqual( InheritanceManagerTestParent.objects.select_subclasses().get( id=self.child1.id), self.child1) - def test_get_method_with_select_subclasses_check_for_useless_join(self): + def test_get_method_with_select_subclasses_check_for_useless_join(self) -> None: child4 = InheritanceManagerTestChild4.objects.create(related=self.related, other_onetoone=self.child1) self.assertEqual( str(InheritanceManagerTestChild4.objects.select_subclasses().filter( @@ -510,26 +518,26 @@ def test_get_method_with_select_subclasses_check_for_useless_join(self): str(InheritanceManagerTestChild4.objects.select_subclasses().select_related(None).filter( id=child4.id).query)) - def test_annotate_with_select_subclasses(self): + def test_annotate_with_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( models.Count('id')) self.assertEqual(qs.get(id=self.child1.id).id__count, 1) - def test_annotate_with_named_arguments_with_select_subclasses(self): + def test_annotate_with_named_arguments_with_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( test_count=models.Count('id')) self.assertEqual(qs.get(id=self.child1.id).test_count, 1) - def test_annotate_before_select_subclasses(self): + def test_annotate_before_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.annotate( models.Count('id')).select_subclasses() self.assertEqual(qs.get(id=self.child1.id).id__count, 1) - def test_annotate_with_named_arguments_before_select_subclasses(self): + def test_annotate_with_named_arguments_before_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.annotate( test_count=models.Count('id')).select_subclasses() self.assertEqual(qs.get(id=self.child1.id).test_count, 1) - def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self): + def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses() self.assertEqual(qs.subclasses, qs._clone().subclasses) diff --git a/tests/test_managers/test_join_manager.py b/tests/test_managers/test_join_manager.py index 57c1c344..44bdcfc8 100644 --- a/tests/test_managers/test_join_manager.py +++ b/tests/test_managers/test_join_manager.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import BoxJoinModel, JoinItemForeignKey class JoinManagerTest(TestCase): - def setUp(self): + def setUp(self) -> None: for i in range(20): BoxJoinModel.objects.create(name=f'name_{i}') @@ -13,24 +15,24 @@ def setUp(self): ) JoinItemForeignKey.objects.create(weight=20) - def test_self_join(self): + def test_self_join(self) -> None: a_slice = BoxJoinModel.objects.all()[0:10] with self.assertNumQueries(1): result = a_slice.join() self.assertEqual(result.count(), 10) - def test_self_join_with_where_statement(self): + def test_self_join_with_where_statement(self) -> None: qs = BoxJoinModel.objects.filter(name='name_1') result = qs.join() self.assertEqual(result.count(), 1) - def test_join_with_other_qs(self): + def test_join_with_other_qs(self) -> None: item_qs = JoinItemForeignKey.objects.filter(weight=10) boxes = BoxJoinModel.objects.all().join(qs=item_qs) self.assertEqual(boxes.count(), 1) self.assertEqual(boxes[0].name, 'name_1') - def test_reverse_join(self): + def test_reverse_join(self) -> None: box_qs = BoxJoinModel.objects.filter(name='name_1') items = JoinItemForeignKey.objects.all().join(box_qs) self.assertEqual(items.count(), 1) diff --git a/tests/test_managers/test_query_manager.py b/tests/test_managers/test_query_manager.py index 2339dbe0..03ec814f 100644 --- a/tests/test_managers/test_query_manager.py +++ b/tests/test_managers/test_query_manager.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import Post class QueryManagerTests(TestCase): - def setUp(self): + def setUp(self) -> None: data = ((True, True, 0), (True, False, 4), (False, False, 2), @@ -14,14 +16,14 @@ def setUp(self): for p, c, o in data: Post.objects.create(published=p, confirmed=c, order=o) - def test_passing_kwargs(self): + def test_passing_kwargs(self) -> None: qs = Post.public.all() self.assertEqual([p.order for p in qs], [0, 1, 4, 5]) - def test_passing_Q(self): + def test_passing_Q(self) -> None: qs = Post.public_confirmed.all() self.assertEqual([p.order for p in qs], [0, 1]) - def test_ordering(self): + def test_ordering(self) -> None: qs = Post.public_reversed.all() self.assertEqual([p.order for p in qs], [5, 4, 1, 0]) diff --git a/tests/test_managers/test_softdelete_manager.py b/tests/test_managers/test_softdelete_manager.py index aec53576..01fffc42 100644 --- a/tests/test_managers/test_softdelete_manager.py +++ b/tests/test_managers/test_softdelete_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import CustomSoftDelete @@ -5,21 +7,21 @@ class CustomSoftDeleteManagerTests(TestCase): - def test_custom_manager_empty(self): + def test_custom_manager_empty(self) -> None: qs = CustomSoftDelete.available_objects.only_read() self.assertEqual(qs.count(), 0) - def test_custom_qs_empty(self): + def test_custom_qs_empty(self) -> None: qs = CustomSoftDelete.available_objects.all().only_read() self.assertEqual(qs.count(), 0) - def test_is_read(self): + def test_is_read(self) -> None: for is_read in [True, False, True, False]: CustomSoftDelete.available_objects.create(is_read=is_read) qs = CustomSoftDelete.available_objects.only_read() self.assertEqual(qs.count(), 2) - def test_is_read_removed(self): + def test_is_read_removed(self) -> None: for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]: CustomSoftDelete.available_objects.create(is_read=is_read, is_removed=is_removed) qs = CustomSoftDelete.available_objects.only_read() diff --git a/tests/test_managers/test_status_manager.py b/tests/test_managers/test_status_manager.py index 106881a2..a4b69b2f 100644 --- a/tests/test_managers/test_status_manager.py +++ b/tests/test_managers/test_status_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.core.exceptions import ImproperlyConfigured from django.db import models from django.test import TestCase @@ -8,10 +10,10 @@ class StatusManagerAddedTests(TestCase): - def test_manager_available(self): + def test_manager_available(self) -> None: self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager)) - def test_conflict_error(self): + def test_conflict_error(self) -> None: with self.assertRaises(ImproperlyConfigured): class ErrorModel(StatusModel): STATUS = ( diff --git a/tests/test_miscellaneous.py b/tests/test_miscellaneous.py index 684d824c..0fbfecc1 100644 --- a/tests/test_miscellaneous.py +++ b/tests/test_miscellaneous.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.core.management import call_command from django.test import TestCase @@ -5,23 +7,23 @@ class MigrationsTests(TestCase): - def test_makemigrations(self): + def test_makemigrations(self) -> None: call_command('makemigrations', dry_run=True) class GetExcerptTests(TestCase): - def test_split(self): + def test_split(self) -> None: e = get_excerpt("some content\n\n\n\nsome more") self.assertEqual(e, 'some content\n') - def test_auto_split(self): + def test_auto_split(self) -> None: e = get_excerpt("para one\n\npara two\n\npara three") self.assertEqual(e, 'para one\n\npara two') - def test_middle_of_para(self): + def test_middle_of_para(self) -> None: e = get_excerpt("some text\n\nmore text") self.assertEqual(e, 'some text') - def test_middle_of_line(self): + def test_middle_of_line(self) -> None: e = get_excerpt("some text more text") self.assertEqual(e, "some text more text") diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py index 7a839fab..f51e5eb7 100644 --- a/tests/test_models/test_deferred_fields.py +++ b/tests/test_models/test_deferred_fields.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import ModelWithCustomDescriptor class CustomDescriptorTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.instance = ModelWithCustomDescriptor.objects.create( custom_field='1', tracked_custom_field='1', @@ -12,7 +14,7 @@ def setUp(self): tracked_regular_field=1, ) - def test_custom_descriptor_works(self): + def test_custom_descriptor_works(self) -> None: instance = self.instance self.assertEqual(instance.custom_field, '1') self.assertEqual(instance.__dict__['custom_field'], 1) @@ -25,7 +27,7 @@ def test_custom_descriptor_works(self): self.assertEqual(instance.custom_field, '2') self.assertEqual(instance.__dict__['custom_field'], 2) - def test_deferred(self): + def test_deferred(self) -> None: instance = ModelWithCustomDescriptor.objects.only('id').get( pk=self.instance.pk) self.assertIn('custom_field', instance.get_deferred_fields()) diff --git a/tests/test_models/test_softdeletable_model.py b/tests/test_models/test_softdeletable_model.py index c2ffd54f..1f58f435 100644 --- a/tests/test_models/test_softdeletable_model.py +++ b/tests/test_models/test_softdeletable_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from django.utils.connection import ConnectionDoesNotExist @@ -5,7 +7,7 @@ class SoftDeletableModelTests(TestCase): - def test_can_only_see_not_removed_entries(self): + def test_can_only_see_not_removed_entries(self) -> None: SoftDeletable.available_objects.create(name='a', is_removed=True) SoftDeletable.available_objects.create(name='b', is_removed=False) @@ -14,7 +16,7 @@ def test_can_only_see_not_removed_entries(self): self.assertEqual(queryset.count(), 1) self.assertEqual(queryset[0].name, 'b') - def test_instance_cannot_be_fully_deleted(self): + def test_instance_cannot_be_fully_deleted(self) -> None: instance = SoftDeletable.available_objects.create(name='a') instance.delete() @@ -22,7 +24,7 @@ def test_instance_cannot_be_fully_deleted(self): self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 1) - def test_instance_cannot_be_fully_deleted_via_queryset(self): + def test_instance_cannot_be_fully_deleted_via_queryset(self) -> None: SoftDeletable.available_objects.create(name='a') SoftDeletable.available_objects.all().delete() @@ -30,12 +32,12 @@ def test_instance_cannot_be_fully_deleted_via_queryset(self): self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 1) - def test_delete_instance_no_connection(self): + def test_delete_instance_no_connection(self) -> None: obj = SoftDeletable.available_objects.create(name='a') self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other') - def test_instance_purge(self): + def test_instance_purge(self) -> None: instance = SoftDeletable.available_objects.create(name='a') instance.delete(soft=False) @@ -43,11 +45,11 @@ def test_instance_purge(self): self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 0) - def test_instance_purge_no_connection(self): + def test_instance_purge_no_connection(self) -> None: instance = SoftDeletable.available_objects.create(name='a') self.assertRaises(ConnectionDoesNotExist, instance.delete, using='other', soft=False) - def test_deprecation_warning(self): + def test_deprecation_warning(self) -> None: self.assertWarns(DeprecationWarning, SoftDeletable.objects.all) diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py index 1e36a58f..9c9a4a42 100644 --- a/tests/test_models/test_status_model.py +++ b/tests/test_models/test_status_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import time_machine @@ -7,12 +9,14 @@ class StatusModelTests(TestCase): - def setUp(self): + model: type[Status] | type[StatusPlainTuple] + + def setUp(self) -> None: self.model = Status self.on_hold = Status.STATUS.on_hold self.active = Status.STATUS.active - def test_created(self): + def test_created(self) -> None: with time_machine.travel(datetime(2016, 1, 1)): c1 = self.model.objects.create() self.assertTrue(c1.status_changed, datetime(2016, 1, 1)) @@ -21,7 +25,7 @@ def test_created(self): self.assertEqual(self.model.active.count(), 2) self.assertEqual(self.model.deleted.count(), 0) - def test_modification(self): + def test_modification(self) -> None: t1 = self.model.objects.create() date_created = t1.status_changed t1.status = self.on_hold @@ -37,7 +41,7 @@ def test_modification(self): t1.save() self.assertTrue(t1.status_changed > date_active_again) - def test_save_with_update_fields_overrides_status_changed_provided(self): + def test_save_with_update_fields_overrides_status_changed_provided(self) -> None: ''' Tests if the save method updated status_changed field accordingly when update_fields is used as an argument @@ -52,7 +56,7 @@ def test_save_with_update_fields_overrides_status_changed_provided(self): self.assertEqual(t1.status_changed, datetime(2020, 1, 2, tzinfo=timezone.utc)) - def test_save_with_update_fields_overrides_status_changed_not_provided(self): + def test_save_with_update_fields_overrides_status_changed_not_provided(self) -> None: ''' Tests if the save method updated status_changed field accordingly when update_fields is used as an argument @@ -69,7 +73,7 @@ def test_save_with_update_fields_overrides_status_changed_not_provided(self): class StatusModelPlainTupleTests(StatusModelTests): - def setUp(self): + def setUp(self) -> None: self.model = StatusPlainTuple self.on_hold = StatusPlainTuple.STATUS[2][0] self.active = StatusPlainTuple.STATUS[0][0] @@ -77,7 +81,7 @@ def setUp(self): class StatusModelDefaultManagerTests(TestCase): - def test_default_manager_is_not_status_model_generated_ones(self): + def test_default_manager_is_not_status_model_generated_ones(self) -> None: # Regression test for GH-251 # The logic behind order for managers seems to have changed in Django 1.10 # and affects default manager. diff --git a/tests/test_models/test_timeframed_model.py b/tests/test_models/test_timeframed_model.py index 3038828b..246b3992 100644 --- a/tests/test_models/test_timeframed_model.py +++ b/tests/test_models/test_timeframed_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timedelta from django.core.exceptions import ImproperlyConfigured @@ -10,36 +12,36 @@ class TimeFramedModelTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.now = datetime.now() - def test_not_yet_begun(self): + def test_not_yet_begun(self) -> None: TimeFrame.objects.create(start=self.now + timedelta(days=2)) self.assertEqual(TimeFrame.timeframed.count(), 0) - def test_finished(self): + def test_finished(self) -> None: TimeFrame.objects.create(end=self.now - timedelta(days=1)) self.assertEqual(TimeFrame.timeframed.count(), 0) - def test_no_end(self): + def test_no_end(self) -> None: TimeFrame.objects.create(start=self.now - timedelta(days=10)) self.assertEqual(TimeFrame.timeframed.count(), 1) - def test_no_start(self): + def test_no_start(self) -> None: TimeFrame.objects.create(end=self.now + timedelta(days=2)) self.assertEqual(TimeFrame.timeframed.count(), 1) - def test_within_range(self): + def test_within_range(self) -> None: TimeFrame.objects.create(start=self.now - timedelta(days=1), end=self.now + timedelta(days=1)) self.assertEqual(TimeFrame.timeframed.count(), 1) class TimeFrameManagerAddedTests(TestCase): - def test_manager_available(self): + def test_manager_available(self) -> None: self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager)) - def test_conflict_error(self): + def test_conflict_error(self) -> None: with self.assertRaises(ImproperlyConfigured): class ErrorModel(TimeFramedModel): timeframed = models.BooleanField() diff --git a/tests/test_models/test_timestamped_model.py b/tests/test_models/test_timestamped_model.py index 1087da10..66979417 100644 --- a/tests/test_models/test_timestamped_model.py +++ b/tests/test_models/test_timestamped_model.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from collections.abc import Iterable from datetime import datetime, timedelta, timezone import time_machine @@ -7,19 +10,19 @@ class TimeStampedModelTests(TestCase): - def test_created(self): + def test_created(self) -> None: with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)): t1 = TimeStamp.objects.create() self.assertEqual(t1.created, datetime(2016, 1, 1, tzinfo=timezone.utc)) - def test_created_sets_modified(self): + def test_created_sets_modified(self) -> None: ''' Ensure that on creation that modified is set exactly equal to created. ''' t1 = TimeStamp.objects.create() self.assertEqual(t1.created, t1.modified) - def test_modified(self): + def test_modified(self) -> None: with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)): t1 = TimeStamp.objects.create() @@ -28,7 +31,7 @@ def test_modified(self): self.assertEqual(t1.modified, datetime(2016, 1, 2, tzinfo=timezone.utc)) - def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self): + def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self) -> None: """ Setting the created date when first creating an object should be permissible. @@ -38,7 +41,7 @@ def test_overriding_created_via_object_creation_also_uses_creation_date_for_modi self.assertEqual(t1.created, different_date) self.assertEqual(t1.modified, different_date) - def test_overriding_modified_via_object_creation(self): + def test_overriding_modified_via_object_creation(self) -> None: """ Setting the modified date explicitly should be possible when first creating an object, but not thereafter. @@ -48,7 +51,7 @@ def test_overriding_modified_via_object_creation(self): self.assertEqual(t1.modified, different_date) self.assertNotEqual(t1.created, different_date) - def test_overriding_created_after_object_created(self): + def test_overriding_created_after_object_created(self) -> None: """ The created date may be changed post-create """ @@ -58,7 +61,7 @@ def test_overriding_created_after_object_created(self): t1.save() self.assertEqual(t1.created, different_date) - def test_overriding_modified_after_object_created(self): + def test_overriding_modified_after_object_created(self) -> None: """ The modified date should always be updated when the object is saved, regardless of attempts to change it. @@ -69,7 +72,7 @@ def test_overriding_modified_after_object_created(self): t1.save() self.assertNotEqual(t1.modified, different_date) - def test_overrides_using_save(self): + def test_overrides_using_save(self) -> None: """ The first time an object is saved, allow modification of both created and modified fields. @@ -90,7 +93,7 @@ def test_overrides_using_save(self): self.assertNotEqual(t1.modified, different_date2) self.assertNotEqual(t1.modified, different_date) - def test_save_with_update_fields_overrides_modified_provided_within_a(self): + def test_save_with_update_fields_overrides_modified_provided_within_a(self) -> None: """ Tests if the save method updated modified field accordingly when update_fields is used as an argument @@ -111,8 +114,8 @@ def test_save_with_update_fields_overrides_modified_provided_within_a(self): t1.save(update_fields=update_fields) self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) - def test_save_is_skipped_for_empty_update_fields_iterable(self): - tests = ( + def test_save_is_skipped_for_empty_update_fields_iterable(self) -> None: + tests: Iterable[Iterable[str]] = ( [], # list (), # tuple set(), # set @@ -131,7 +134,7 @@ def test_save_is_skipped_for_empty_update_fields_iterable(self): self.assertEqual(t1.test_field, 0) self.assertEqual(t1.modified, datetime(2020, 1, 1, tzinfo=timezone.utc)) - def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self): + def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self) -> None: with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)): t1 = TimeStamp.objects.create() @@ -140,7 +143,7 @@ def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(s self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) - def test_model_inherit_timestampmodel_and_statusmodel(self): + def test_model_inherit_timestampmodel_and_statusmodel(self) -> None: with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)): t1 = TimeStampWithStatusModel.objects.create() diff --git a/tests/test_models/test_uuid_model.py b/tests/test_models/test_uuid_model.py index 3c7663d7..d9d71d9d 100644 --- a/tests/test_models/test_uuid_model.py +++ b/tests/test_models/test_uuid_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel @@ -5,13 +7,13 @@ class UUIDFieldTests(TestCase): - def test_uuid_model_with_uuid_field_as_primary_key(self): + def test_uuid_model_with_uuid_field_as_primary_key(self) -> None: instance = CustomUUIDModel() instance.save() self.assertEqual(instance.id.__class__.__name__, 'UUID') self.assertEqual(instance.id, instance.pk) - def test_uuid_model_with_uuid_field_as_not_primary_key(self): + def test_uuid_model_with_uuid_field_as_not_primary_key(self) -> None: instance = CustomNotPrimaryUUIDModel() instance.save() self.assertEqual(instance.uuid.__class__.__name__, 'UUID')