Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QuerySet.annotate improvements #398

Merged
merged 26 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
808ecd4
QuerySet.annotate returns self-type. Attribute access falls back to Any.
syastrov Jun 12, 2020
d4278f7
Fix .annotate so it reuses existing annotated types. Fixes error in t…
syastrov Jun 13, 2020
922ed03
Fix self-typecheck error
syastrov Jun 13, 2020
c31a0a6
Fix flake8
syastrov Jun 14, 2020
dad53c2
Fix case of .values/.values_list before .annotate.
syastrov Jun 14, 2020
63fcc8a
Extra ignores for Django 2.2 tests (false positives due to tests assu…
syastrov Jun 14, 2020
2a319a8
More tests + more precise typing in case annotate called before value…
syastrov Jun 14, 2020
33372f6
Test and fix annotate in combination with values/values_list with no …
syastrov Aug 1, 2020
db604d6
Remove line that does nothing :)
syastrov Aug 1, 2020
ebf9500
Formatting fixes
syastrov May 10, 2021
6118634
Merge branch 'master' into queryset-annotate
syastrov Jun 27, 2021
a0d8e38
Address code review
syastrov Jun 27, 2021
741a2d3
Fix quoting in tests after mypy changed things
syastrov Jun 27, 2021
968d90b
Use Final
syastrov Jun 28, 2021
7cc814d
Use typing_extensions.Final
syastrov Jun 28, 2021
a06ed52
Merge branch 'master' into queryset-annotate
syastrov Jul 1, 2021
eddce50
Fixes after ValuesQuerySet -> _ValuesQuerySet refactor. Still not pas…
syastrov Jul 1, 2021
4d4dcf5
Fix inheritance of _ValuesQuerySet and remove unneeded type ignores.
syastrov Jul 5, 2021
53ad2db
Merge branch 'master' into queryset-annotate
syastrov Jul 6, 2021
b8debd2
Make it possible to annotate user code with "annotated models", using…
syastrov Jul 6, 2021
751ae7f
Add docs
syastrov Jul 6, 2021
7108a0c
Make QuerySet[_T] an external alias to _QuerySet[_T, _T].
syastrov Jul 7, 2021
51f0448
Support passing TypedDicts to WithAnnotations
syastrov Jul 8, 2021
72c1dfb
Add an example of an error to README regarding WithAnnotations + Type…
syastrov Jul 17, 2021
a6a47da
Fix runtime behavior of ValuesQuerySet alias (you can't extend Any, f…
syastrov Jul 21, 2021
d3ea945
Fix issue when using from_queryset in some cases when having an argum…
syastrov Jul 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,48 @@ def use_my_model():
return foo.xyz # Gives an error
```

### How do I annotate cases where I called QuerySet.annotate?

Django-stubs provides a special type, `django_stubs_ext.WithAnnotations[Model]`, which indicates that the `Model` has
been annotated, meaning it allows getting/setting extra attributes on the model instance.

Optionally, you can provide a `TypedDict` of these attributes,
e.g. `WithAnnotations[MyModel, MyTypedDict]`, to specify which annotated attributes are present.

Currently, the mypy plugin can recognize that specific names were passed to `QuerySet.annotate` and
include them in the type, but does not record the types of these attributes.

The knowledge of the specific annotated fields is not yet used in creating more specific types for `QuerySet`'s
`values`, `values_list`, or `filter` methods, however knowledge that the model was annotated _is_ used to create a
broader type result type for `values`/`values_list`, and to allow `filter`ing on any field.

```python
from typing import TypedDict
from django_stubs_ext import WithAnnotations
from django.db import models
from django.db.models.expressions import Value

class MyModel(models.Model):
username = models.CharField(max_length=100)


def func(m: WithAnnotations[MyModel]) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it be possible for us to later add something like WithAnnotations[Model, TypedDict] and not break things for users? Double checking 🙂

return m.asdf # OK, since the model is annotated as allowing any attribute

func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.get(id=1)) # Error, since this model will not allow access to any attribute


class MyTypedDict(TypedDict):
foo: str

def func2(m: WithAnnotations[MyModel, MyTypedDict]) -> str:
print(m.bar) # Error, since field "bar" is not in MyModel or MyTypedDict.
return m.foo # OK, since we said field "foo" was allowed
syastrov marked this conversation as resolved.
Show resolved Hide resolved

func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.annotate(bar=Value("")).get(id=1)) # Error
```

## Related projects

Expand Down
6 changes: 3 additions & 3 deletions django-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, NamedTuple
from typing import Any, Protocol

from .utils.version import get_version as get_version

Expand All @@ -7,7 +7,7 @@ __version__: str

def setup(set_prefix: bool = ...) -> None: ...

# Used by mypy_django_plugin when returning a QuerySet row that is a NamedTuple where the field names are unknown
class _NamedTupleAnyAttr(NamedTuple):
# Used internally by mypy_django_plugin.
class _AnyAttrAllowed(Protocol):
def __getattr__(self, item: str) -> Any: ...
def __setattr__(self, item: str, value: Any) -> None: ...
98 changes: 95 additions & 3 deletions django-stubs/db/models/manager.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
import datetime
from typing import (
Any,
Dict,
Generic,
Iterable,
Iterator,
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

from django.db.models import Combinable
from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.db.models.query import QuerySet, RawQuerySet

from django_stubs_ext import ValuesQuerySet

_T = TypeVar("_T", bound=Model, covariant=True)
_M = TypeVar("_M", bound="BaseManager")

class BaseManager(QuerySet[_T]):
class BaseManager(Generic[_T]):
creation_counter: int = ...
auto_created: bool = ...
use_in_migrations: bool = ...
Expand All @@ -24,6 +42,80 @@ class BaseManager(QuerySet[_T]):
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self: _M, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> _M: ...
def get_queryset(self) -> QuerySet[_T]: ...
# NOTE: The following methods are in common with QuerySet, but note that the use of QuerySet as a return type
# rather than a self-type (_QS), since Manager's QuerySet-like methods return QuerySets and not Managers.
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create(
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
) -> List[_T]: ...
def bulk_update(self, objs: Iterable[_T], fields: Sequence[str], batch_size: Optional[int] = ...) -> None: ...
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def first(self) -> Optional[_T]: ...
def last(self) -> Optional[_T]: ...
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
def exists(self) -> bool: ...
def explain(self, *, format: Optional[Any] = ..., **options: Any) -> str: ...
def raw(
self,
raw_query: str,
params: Any = ...,
translations: Optional[Dict[str, str]] = ...,
using: Optional[str] = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> ValuesQuerySet[_T, Dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
) -> ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
def count(self) -> int: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ...
) -> QuerySet[_T]: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def alias(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
# extra() return type won't be supported any time soon
def extra(
self,
select: Optional[Dict[str, Any]] = ...,
where: Optional[List[str]] = ...,
params: Optional[List[Any]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[Any]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
@property
def ordered(self) -> bool: ...

class Manager(BaseManager[_T]): ...

Expand Down
67 changes: 24 additions & 43 deletions django-stubs/db/models/query.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ from django.db.models.query_utils import Q as Q # noqa: F401
from django.db.models.sql.query import Query, RawQuery

_T = TypeVar("_T", bound=models.Model, covariant=True)
_QS = TypeVar("_QS", bound="QuerySet")
_Row = TypeVar("_Row", covariant=True)
_QS = TypeVar("_QS", bound="_QuerySet")

class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
class _QuerySet(Generic[_T, _Row], Collection[_Row], Reversible[_Row], Sized):
model: Type[_T]
query: Query
def __init__(
Expand All @@ -47,11 +48,13 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
def __class_getitem__(cls: Type[_QS], item: Type[_T]) -> Type[_QS]: ...
def __getstate__(self) -> Dict[str, Any]: ...
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self: _QS, other: QuerySet[_T]) -> _QS: ...
def __or__(self: _QS, other: QuerySet[_T]) -> _QS: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def __and__(self: _QS, other: _QuerySet[_T, _Row]) -> _QS: ...
def __or__(self: _QS, other: _QuerySet[_T, _Row]) -> _QS: ...
# IMPORTANT: When updating any of the following methods' signatures, please ALSO modify
# the corresponding method in BaseManager.
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create(
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
Expand All @@ -61,10 +64,10 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def first(self) -> Optional[_T]: ...
def last(self) -> Optional[_T]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def first(self) -> Optional[_Row]: ...
def last(self) -> Optional[_Row]: ...
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
Expand All @@ -78,15 +81,15 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
using: Optional[str] = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> _ValuesQuerySet[_T, Dict[str, Any]]: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> _QuerySet[_T, Dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
) -> _ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> _ValuesQuerySet[_T, datetime.date]: ...
) -> _QuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> _QuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> _ValuesQuerySet[_T, datetime.datetime]: ...
) -> _QuerySet[_T, datetime.datetime]: ...
def none(self: _QS) -> _QS: ...
def all(self: _QS) -> _QS: ...
def filter(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
Expand All @@ -101,8 +104,7 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
) -> _QS: ...
def select_related(self: _QS, *fields: Any) -> _QS: ...
def prefetch_related(self: _QS, *lookups: Any) -> _QS: ...
# TODO: return type
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any]: ...
def annotate(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
def alias(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
def order_by(self: _QS, *field_names: Any) -> _QS: ...
def distinct(self: _QS, *field_names: Any) -> _QS: ...
Expand All @@ -115,7 +117,7 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[Any]: ...
) -> _QuerySet[Any, Any]: ...
def reverse(self: _QS) -> _QS: ...
def defer(self: _QS, *fields: Any) -> _QS: ...
def only(self: _QS, *fields: Any) -> _QS: ...
Expand All @@ -125,36 +127,13 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
@property
def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
def __iter__(self) -> Iterator[_T]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _T: ...
@overload
def __getitem__(self: _QS, s: slice) -> _QS: ...
def __reversed__(self) -> Iterator[_T]: ...

_Row = TypeVar("_Row", covariant=True)

class _ValuesQuerySet(Generic[_T, _Row], Collection[_Row], Reversible[_Row], QuerySet[_T], Sized): # type: ignore
def __len__(self) -> int: ...
def __contains__(self, x: object) -> bool: ...
def __iter__(self) -> Iterator[_Row]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _Row: ...
@overload
def __getitem__(self: _QS, s: slice) -> _QS: ... # type: ignore
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def first(self) -> Optional[_Row]: ...
def last(self) -> Optional[_Row]: ...
def distinct(self, *field_names: Any) -> _ValuesQuerySet[_T, _Row]: ...
def order_by(self, *field_names: Any) -> _ValuesQuerySet[_T, _Row]: ...
def all(self) -> _ValuesQuerySet[_T, _Row]: ...
def annotate(self, *args: Any, **kwargs: Any) -> _ValuesQuerySet[_T, Any]: ...
def filter(self, *args: Any, **kwargs: Any) -> _ValuesQuerySet[_T, _Row]: ...
def exclude(self, *args: Any, **kwargs: Any) -> _ValuesQuerySet[_T, _Row]: ...
def __getitem__(self: _QS, s: slice) -> _QS: ...
def __reversed__(self) -> Iterator[_Row]: ...

class RawQuerySet(Iterable[_T], Sized):
query: RawQuery
Expand Down Expand Up @@ -188,6 +167,8 @@ class RawQuerySet(Iterable[_T], Sized):
def resolve_model_init_order(self) -> Tuple[List[str], List[int], List[Tuple[str, int]]]: ...
def using(self, alias: Optional[str]) -> RawQuerySet[_T]: ...

QuerySet = _QuerySet[_T, _T]

class Prefetch(object):
def __init__(self, lookup: str, queryset: Optional[QuerySet] = ..., to_attr: Optional[str] = ...) -> None: ...
def __getstate__(self) -> Dict[str, Any]: ...
Expand Down
2 changes: 1 addition & 1 deletion django-stubs/views/generic/list.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.views.generic.base import ContextMixin, TemplateResponseMixin, View

T = TypeVar("T", bound=Model)
T = TypeVar("T", bound=Model, covariant=True)

class MultipleObjectMixin(Generic[T], ContextMixin):
allow_empty: bool = ...
Expand Down
4 changes: 3 additions & 1 deletion django_stubs_ext/django_stubs_ext/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .aliases import ValuesQuerySet as ValuesQuerySet
from .annotations import Annotations as Annotations
from .annotations import WithAnnotations as WithAnnotations
from .patch import monkeypatch as monkeypatch

__all__ = ["monkeypatch", "ValuesQuerySet"]
__all__ = ["monkeypatch", "ValuesQuerySet", "WithAnnotations", "Annotations"]
8 changes: 5 additions & 3 deletions django_stubs_ext/django_stubs_ext/aliases.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing

if typing.TYPE_CHECKING:
from django.db.models.query import _T, _Row, _ValuesQuerySet
from django.db.models.query import _T, _QuerySet, _Row

ValuesQuerySet = _ValuesQuerySet[_T, _Row]
ValuesQuerySet = _QuerySet[_T, _Row]
else:
ValuesQuerySet = typing.Any
from django.db.models.query import QuerySet

ValuesQuerySet = QuerySet
22 changes: 22 additions & 0 deletions django_stubs_ext/django_stubs_ext/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Generic, Mapping, TypeVar

from django.db.models.base import Model
from typing_extensions import Annotated

# Really, we would like to use TypedDict as a bound, but it's not possible
_Annotations = TypeVar("_Annotations", covariant=True, bound=Mapping[str, Any])


class Annotations(Generic[_Annotations]):
"""Use as `Annotations[MyTypedDict]`"""

pass


_T = TypeVar("_T", bound=Model)

WithAnnotations = Annotated[_T, Annotations[_Annotations]]
"""Alias to make it easy to annotate the model `_T` as having annotations `_Annotations` (a `TypedDict` or `Any` if not provided).

Use as `WithAnnotations[MyModel]` or `WithAnnotations[MyModel, MyTypedDict]`.
"""
8 changes: 8 additions & 0 deletions django_stubs_ext/tests/test_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Any

from django_stubs_ext import ValuesQuerySet


def test_extends_values_queryset() -> None:
class MyQS(ValuesQuerySet[Any, Any]):
pass
Loading