Skip to content

Commit

Permalink
QuerySet.annotate improvements (#398)
Browse files Browse the repository at this point in the history
* QuerySet.annotate returns self-type. Attribute access falls back to Any.

- QuerySets that have an annotated model do not report errors during .filter() when called with invalid fields.
- QuerySets that have an annotated model return ordinary dict rather than TypedDict for .values()
- QuerySets that have an annotated model return Any rather than typed Tuple for .values_list()

* Fix .annotate so it reuses existing annotated types. Fixes error in typechecking Django testsuite.

* Fix self-typecheck error

* Fix flake8

* Fix case of .values/.values_list before .annotate.

* Extra ignores for Django 2.2 tests (false positives due to tests assuming QuerySet.first() won't return None)

Fix mypy self-check.

* More tests + more precise typing in case annotate called before values_list.

Cleanup tests.

* Test and fix annotate in combination with values/values_list with no params.

* Remove line that does nothing :)

* Formatting fixes

* Address code review

* Fix quoting in tests after mypy changed things

* Use Final

* Use typing_extensions.Final

* Fixes after ValuesQuerySet -> _ValuesQuerySet refactor. Still not passing tests yet.

* Fix inheritance of _ValuesQuerySet and remove unneeded type ignores.

This allows the test
"annotate_values_or_values_list_before_or_after_annotate_broadens_type"
to pass.

* Make it possible to annotate user code with "annotated models", using PEP 583 Annotated type.

* Add docs

* Make QuerySet[_T] an external alias to _QuerySet[_T, _T].

This currently has the drawback that error messages display the internal type _QuerySet, with both type arguments.

See also discussion on #661 and #608.

Fixes #635: QuerySet methods on Managers (like .all()) now return QuerySets rather than Managers.

Address code review by @sobolevn.

* Support passing TypedDicts to WithAnnotations

* Add an example of an error to README regarding WithAnnotations + TypedDict.

* Fix runtime behavior of ValuesQuerySet alias (you can't extend Any, for example).

Fix some edge case with from_queryset after QuerySet changed to be an
alias to _QuerySet. Can't make a minimal test case as this only occurred
on a large internal codebase.

* Fix issue when using from_queryset in some cases when having an argument with a type annotation on the QuerySet.

The mypy docstring on anal_type says not to call defer() after it.
  • Loading branch information
syastrov authored Jul 23, 2021
1 parent c69e720 commit cfd69c0
Show file tree
Hide file tree
Showing 25 changed files with 860 additions and 123 deletions.
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,48 @@ def use_my_model():
return foo.xyz # Gives an error
```

### How do I annotate cases where I called QuerySet.annotate?

Django-stubs provides a special type, `django_stubs_ext.WithAnnotations[Model]`, which indicates that the `Model` has
been annotated, meaning it allows getting/setting extra attributes on the model instance.

Optionally, you can provide a `TypedDict` of these attributes,
e.g. `WithAnnotations[MyModel, MyTypedDict]`, to specify which annotated attributes are present.

Currently, the mypy plugin can recognize that specific names were passed to `QuerySet.annotate` and
include them in the type, but does not record the types of these attributes.

The knowledge of the specific annotated fields is not yet used in creating more specific types for `QuerySet`'s
`values`, `values_list`, or `filter` methods, however knowledge that the model was annotated _is_ used to create a
broader type result type for `values`/`values_list`, and to allow `filter`ing on any field.

```python
from typing import TypedDict
from django_stubs_ext import WithAnnotations
from django.db import models
from django.db.models.expressions import Value

class MyModel(models.Model):
username = models.CharField(max_length=100)


def func(m: WithAnnotations[MyModel]) -> str:
return m.asdf # OK, since the model is annotated as allowing any attribute

func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.get(id=1)) # Error, since this model will not allow access to any attribute


class MyTypedDict(TypedDict):
foo: str

def func2(m: WithAnnotations[MyModel, MyTypedDict]) -> str:
print(m.bar) # Error, since field "bar" is not in MyModel or MyTypedDict.
return m.foo # OK, since we said field "foo" was allowed

func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.annotate(bar=Value("")).get(id=1)) # Error
```

## Related projects

Expand Down
6 changes: 3 additions & 3 deletions django-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, NamedTuple
from typing import Any, Protocol

from .utils.version import get_version as get_version

Expand All @@ -7,7 +7,7 @@ __version__: str

def setup(set_prefix: bool = ...) -> None: ...

# Used by mypy_django_plugin when returning a QuerySet row that is a NamedTuple where the field names are unknown
class _NamedTupleAnyAttr(NamedTuple):
# Used internally by mypy_django_plugin.
class _AnyAttrAllowed(Protocol):
def __getattr__(self, item: str) -> Any: ...
def __setattr__(self, item: str, value: Any) -> None: ...
98 changes: 95 additions & 3 deletions django-stubs/db/models/manager.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
import datetime
from typing import (
Any,
Dict,
Generic,
Iterable,
Iterator,
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

from django.db.models import Combinable
from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.db.models.query import QuerySet, RawQuerySet

from django_stubs_ext import ValuesQuerySet

_T = TypeVar("_T", bound=Model, covariant=True)
_M = TypeVar("_M", bound="BaseManager")

class BaseManager(QuerySet[_T]):
class BaseManager(Generic[_T]):
creation_counter: int = ...
auto_created: bool = ...
use_in_migrations: bool = ...
Expand All @@ -24,6 +42,80 @@ class BaseManager(QuerySet[_T]):
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self: _M, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> _M: ...
def get_queryset(self) -> QuerySet[_T]: ...
# NOTE: The following methods are in common with QuerySet, but note that the use of QuerySet as a return type
# rather than a self-type (_QS), since Manager's QuerySet-like methods return QuerySets and not Managers.
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create(
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
) -> List[_T]: ...
def bulk_update(self, objs: Iterable[_T], fields: Sequence[str], batch_size: Optional[int] = ...) -> None: ...
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def first(self) -> Optional[_T]: ...
def last(self) -> Optional[_T]: ...
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
def exists(self) -> bool: ...
def explain(self, *, format: Optional[Any] = ..., **options: Any) -> str: ...
def raw(
self,
raw_query: str,
params: Any = ...,
translations: Optional[Dict[str, str]] = ...,
using: Optional[str] = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> ValuesQuerySet[_T, Dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
) -> ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
def count(self) -> int: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ...
) -> QuerySet[_T]: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def alias(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
# extra() return type won't be supported any time soon
def extra(
self,
select: Optional[Dict[str, Any]] = ...,
where: Optional[List[str]] = ...,
params: Optional[List[Any]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[Any]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
@property
def ordered(self) -> bool: ...

class Manager(BaseManager[_T]): ...

Expand Down
67 changes: 24 additions & 43 deletions django-stubs/db/models/query.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ from django.db.models.query_utils import Q as Q # noqa: F401
from django.db.models.sql.query import Query, RawQuery

_T = TypeVar("_T", bound=models.Model, covariant=True)
_QS = TypeVar("_QS", bound="QuerySet")
_Row = TypeVar("_Row", covariant=True)
_QS = TypeVar("_QS", bound="_QuerySet")

class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
class _QuerySet(Generic[_T, _Row], Collection[_Row], Reversible[_Row], Sized):
model: Type[_T]
query: Query
def __init__(
Expand All @@ -47,11 +48,13 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
def __class_getitem__(cls: Type[_QS], item: Type[_T]) -> Type[_QS]: ...
def __getstate__(self) -> Dict[str, Any]: ...
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self: _QS, other: QuerySet[_T]) -> _QS: ...
def __or__(self: _QS, other: QuerySet[_T]) -> _QS: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def __and__(self: _QS, other: _QuerySet[_T, _Row]) -> _QS: ...
def __or__(self: _QS, other: _QuerySet[_T, _Row]) -> _QS: ...
# IMPORTANT: When updating any of the following methods' signatures, please ALSO modify
# the corresponding method in BaseManager.
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create(
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
Expand All @@ -61,10 +64,10 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def first(self) -> Optional[_T]: ...
def last(self) -> Optional[_T]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def first(self) -> Optional[_Row]: ...
def last(self) -> Optional[_Row]: ...
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
Expand All @@ -78,15 +81,15 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
using: Optional[str] = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> _ValuesQuerySet[_T, Dict[str, Any]]: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> _QuerySet[_T, Dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
) -> _ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> _ValuesQuerySet[_T, datetime.date]: ...
) -> _QuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> _QuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> _ValuesQuerySet[_T, datetime.datetime]: ...
) -> _QuerySet[_T, datetime.datetime]: ...
def none(self: _QS) -> _QS: ...
def all(self: _QS) -> _QS: ...
def filter(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
Expand All @@ -101,8 +104,7 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
) -> _QS: ...
def select_related(self: _QS, *fields: Any) -> _QS: ...
def prefetch_related(self: _QS, *lookups: Any) -> _QS: ...
# TODO: return type
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any]: ...
def annotate(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
def alias(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
def order_by(self: _QS, *field_names: Any) -> _QS: ...
def distinct(self: _QS, *field_names: Any) -> _QS: ...
Expand All @@ -115,7 +117,7 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[Any]: ...
) -> _QuerySet[Any, Any]: ...
def reverse(self: _QS) -> _QS: ...
def defer(self: _QS, *fields: Any) -> _QS: ...
def only(self: _QS, *fields: Any) -> _QS: ...
Expand All @@ -125,36 +127,13 @@ class QuerySet(Generic[_T], Collection[_T], Reversible[_T], Sized):
@property
def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
def __iter__(self) -> Iterator[_T]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _T: ...
@overload
def __getitem__(self: _QS, s: slice) -> _QS: ...
def __reversed__(self) -> Iterator[_T]: ...

_Row = TypeVar("_Row", covariant=True)

class _ValuesQuerySet(Generic[_T, _Row], Collection[_Row], Reversible[_Row], QuerySet[_T], Sized): # type: ignore
def __len__(self) -> int: ...
def __contains__(self, x: object) -> bool: ...
def __iter__(self) -> Iterator[_Row]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _Row: ...
@overload
def __getitem__(self: _QS, s: slice) -> _QS: ... # type: ignore
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def first(self) -> Optional[_Row]: ...
def last(self) -> Optional[_Row]: ...
def distinct(self, *field_names: Any) -> _ValuesQuerySet[_T, _Row]: ...
def order_by(self, *field_names: Any) -> _ValuesQuerySet[_T, _Row]: ...
def all(self) -> _ValuesQuerySet[_T, _Row]: ...
def annotate(self, *args: Any, **kwargs: Any) -> _ValuesQuerySet[_T, Any]: ...
def filter(self, *args: Any, **kwargs: Any) -> _ValuesQuerySet[_T, _Row]: ...
def exclude(self, *args: Any, **kwargs: Any) -> _ValuesQuerySet[_T, _Row]: ...
def __getitem__(self: _QS, s: slice) -> _QS: ...
def __reversed__(self) -> Iterator[_Row]: ...

class RawQuerySet(Iterable[_T], Sized):
query: RawQuery
Expand Down Expand Up @@ -188,6 +167,8 @@ class RawQuerySet(Iterable[_T], Sized):
def resolve_model_init_order(self) -> Tuple[List[str], List[int], List[Tuple[str, int]]]: ...
def using(self, alias: Optional[str]) -> RawQuerySet[_T]: ...

QuerySet = _QuerySet[_T, _T]

class Prefetch(object):
def __init__(self, lookup: str, queryset: Optional[QuerySet] = ..., to_attr: Optional[str] = ...) -> None: ...
def __getstate__(self) -> Dict[str, Any]: ...
Expand Down
2 changes: 1 addition & 1 deletion django-stubs/views/generic/list.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.views.generic.base import ContextMixin, TemplateResponseMixin, View

T = TypeVar("T", bound=Model)
T = TypeVar("T", bound=Model, covariant=True)

class MultipleObjectMixin(Generic[T], ContextMixin):
allow_empty: bool = ...
Expand Down
4 changes: 3 additions & 1 deletion django_stubs_ext/django_stubs_ext/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .aliases import ValuesQuerySet as ValuesQuerySet
from .annotations import Annotations as Annotations
from .annotations import WithAnnotations as WithAnnotations
from .patch import monkeypatch as monkeypatch

__all__ = ["monkeypatch", "ValuesQuerySet"]
__all__ = ["monkeypatch", "ValuesQuerySet", "WithAnnotations", "Annotations"]
8 changes: 5 additions & 3 deletions django_stubs_ext/django_stubs_ext/aliases.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing

if typing.TYPE_CHECKING:
from django.db.models.query import _T, _Row, _ValuesQuerySet
from django.db.models.query import _T, _QuerySet, _Row

ValuesQuerySet = _ValuesQuerySet[_T, _Row]
ValuesQuerySet = _QuerySet[_T, _Row]
else:
ValuesQuerySet = typing.Any
from django.db.models.query import QuerySet

ValuesQuerySet = QuerySet
22 changes: 22 additions & 0 deletions django_stubs_ext/django_stubs_ext/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Generic, Mapping, TypeVar

from django.db.models.base import Model
from typing_extensions import Annotated

# Really, we would like to use TypedDict as a bound, but it's not possible
_Annotations = TypeVar("_Annotations", covariant=True, bound=Mapping[str, Any])


class Annotations(Generic[_Annotations]):
"""Use as `Annotations[MyTypedDict]`"""

pass


_T = TypeVar("_T", bound=Model)

WithAnnotations = Annotated[_T, Annotations[_Annotations]]
"""Alias to make it easy to annotate the model `_T` as having annotations `_Annotations` (a `TypedDict` or `Any` if not provided).
Use as `WithAnnotations[MyModel]` or `WithAnnotations[MyModel, MyTypedDict]`.
"""
8 changes: 8 additions & 0 deletions django_stubs_ext/tests/test_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Any

from django_stubs_ext import ValuesQuerySet


def test_extends_values_queryset() -> None:
class MyQS(ValuesQuerySet[Any, Any]):
pass
Loading

0 comments on commit cfd69c0

Please sign in to comment.