Skip to content

Commit

Permalink
Fix type generics in InheritanceIterable
Browse files Browse the repository at this point in the history
  • Loading branch information
mthuurne committed Apr 18, 2024
1 parent d895457 commit cd2fcde
Showing 1 changed file with 43 additions and 35 deletions.
78 changes: 43 additions & 35 deletions model_utils/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from django.db import connection, models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.sql.datastructures import Join

ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
Expand All @@ -16,45 +17,52 @@
from typing import Any

from django.db.models.query import BaseIterable
from django.db.models.query import ModelIterable as ModelIterableGeneric
from django.db.models.query import QuerySet as QuerySetGeneric
from django.db.models.sql.query import Query

ModelIterable = ModelIterableGeneric[ModelT]
QuerySet = QuerySetGeneric[ModelT]
else:
from django.db.models.query import ModelIterable, QuerySet


class InheritanceIterable(ModelIterable):
def __iter__(self) -> Iterator[ModelT]:
queryset = self.queryset
iter: Iterable[ModelT] = ModelIterable(queryset)
if hasattr(queryset, 'subclasses'):
extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in subclasses:
assert hasattr(queryset, '_get_sub_obj_recurse')
sub_obj = queryset._get_sub_obj_recurse(obj, s)
if sub_obj:
break
if not sub_obj:
sub_obj = obj

if hasattr(queryset, '_annotated'):
for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))

for k in extras:

def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
iter: Iterable[ModelT] = ModelIterable(queryset)
if hasattr(queryset, 'subclasses'):
extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in subclasses:
assert hasattr(queryset, '_get_sub_obj_recurse')
sub_obj = queryset._get_sub_obj_recurse(obj, s)
if sub_obj:
break
if not sub_obj:
sub_obj = obj

if hasattr(queryset, '_annotated'):
for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))

yield sub_obj
else:
yield from iter
for k in extras:
setattr(sub_obj, k, getattr(obj, k))

yield sub_obj
else:
yield from iter


if TYPE_CHECKING:
class InheritanceIterable(ModelIterable[ModelT]):
queryset: QuerySet[ModelT]

def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
...

def __iter__(self) -> Iterator[ModelT]:
...

else:
class InheritanceIterable(ModelIterable):
def __iter__(self):
return _iter_inheritance_queryset(self.queryset)


class InheritanceQuerySetMixin(Generic[ModelT]):
Expand Down

0 comments on commit cd2fcde

Please sign in to comment.