diff --git a/model_utils/managers.py b/model_utils/managers.py index 1e54a005..0402c6eb 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -7,6 +7,7 @@ from django.db import connection, models from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related import OneToOneField, OneToOneRel +from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) @@ -16,45 +17,52 @@ from typing import Any from django.db.models.query import BaseIterable - from django.db.models.query import ModelIterable as ModelIterableGeneric - from django.db.models.query import QuerySet as QuerySetGeneric from django.db.models.sql.query import Query - ModelIterable = ModelIterableGeneric[ModelT] - QuerySet = QuerySetGeneric[ModelT] -else: - from django.db.models.query import ModelIterable, QuerySet - - -class InheritanceIterable(ModelIterable): - def __iter__(self) -> Iterator[ModelT]: - queryset = self.queryset - iter: Iterable[ModelT] = ModelIterable(queryset) - if hasattr(queryset, 'subclasses'): - extras = tuple(queryset.query.extra.keys()) - # sort the subclass names longest first, - # so with 'a' and 'a__b' it goes as deep as possible - subclasses = sorted(queryset.subclasses, key=len, reverse=True) - for obj in iter: - sub_obj = None - for s in subclasses: - assert hasattr(queryset, '_get_sub_obj_recurse') - sub_obj = queryset._get_sub_obj_recurse(obj, s) - if sub_obj: - break - if not sub_obj: - sub_obj = obj - - if hasattr(queryset, '_annotated'): - for k in queryset._annotated: - setattr(sub_obj, k, getattr(obj, k)) - - for k in extras: + +def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]: + iter: Iterable[ModelT] = ModelIterable(queryset) + if hasattr(queryset, 'subclasses'): + extras = tuple(queryset.query.extra.keys()) + # sort the subclass names longest first, + # so with 'a' and 'a__b' it goes as deep as possible + subclasses = sorted(queryset.subclasses, key=len, reverse=True) + for obj in iter: + sub_obj = None + for s in subclasses: + assert hasattr(queryset, '_get_sub_obj_recurse') + sub_obj = queryset._get_sub_obj_recurse(obj, s) + if sub_obj: + break + if not sub_obj: + sub_obj = obj + + if hasattr(queryset, '_annotated'): + for k in queryset._annotated: setattr(sub_obj, k, getattr(obj, k)) - yield sub_obj - else: - yield from iter + for k in extras: + setattr(sub_obj, k, getattr(obj, k)) + + yield sub_obj + else: + yield from iter + + +if TYPE_CHECKING: + class InheritanceIterable(ModelIterable[ModelT]): + queryset: QuerySet[ModelT] + + def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any): + ... + + def __iter__(self) -> Iterator[ModelT]: + ... + +else: + class InheritanceIterable(ModelIterable): + def __iter__(self): + return _iter_inheritance_queryset(self.queryset) class InheritanceQuerySetMixin(Generic[ModelT]):