Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support returning the correct values for the different QuerySet methods when using .values() and .values_list(). #33

Merged
merged 11 commits into from
Mar 10, 2019
Merged
4 changes: 2 additions & 2 deletions django-stubs/db/models/manager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from django.db.models.query import QuerySet

_T = TypeVar("_T", bound=Model, covariant=True)

class BaseManager(QuerySet[_T]):
class BaseManager(QuerySet[_T, _T]):
creation_counter: int = ...
auto_created: bool = ...
use_in_migrations: bool = ...
Expand All @@ -21,7 +21,7 @@ class BaseManager(QuerySet[_T]):
def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ...
def get_queryset(self) -> QuerySet[_T]: ...
def get_queryset(self) -> QuerySet[_T, _T]: ...

class Manager(BaseManager[_T]): ...

Expand Down
104 changes: 58 additions & 46 deletions django-stubs/db/models/query.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import (
Any,
Dict,
Expand All @@ -13,6 +14,9 @@ from typing import (
TypeVar,
Union,
overload,
Generic,
NamedTuple,
Collection,
)

from django.db.models.base import Model
Expand Down Expand Up @@ -46,7 +50,7 @@ class FlatValuesListIterable(BaseIterable):

_T = TypeVar("_T", bound=models.Model, covariant=True)

class QuerySet(Iterable[_T], Sized):
class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
query: Query
def __init__(
self,
Expand All @@ -58,32 +62,33 @@ class QuerySet(Iterable[_T], Sized):
@classmethod
def as_manager(cls) -> Manager[Any]: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __iter__(self) -> Iterator[_Row]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _Row: ...
@overload
def __getitem__(self, s: slice) -> QuerySet[_T, _Row]: ...
def __bool__(self) -> bool: ...
def __class_getitem__(cls, item: Type[_T]):
pass
def __getstate__(self) -> Dict[str, Any]: ...
@overload
def __getitem__(self, k: int) -> _T: ...
@overload
def __getitem__(self, k: str) -> Any: ...
@overload
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
def __and__(self, other: QuerySet) -> QuerySet: ...
def __or__(self, other: QuerySet) -> QuerySet: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
# __and__ and __or__ ignore the other QuerySet's _Row type parameter because they use the same row type as the self QuerySet.
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self, other: QuerySet[_T, Any]) -> QuerySet[_T, _Row]: ...
def __or__(self, other: QuerySet[_T, Any]) -> QuerySet[_T, _Row]: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, **kwargs: Any) -> _T: ...
def bulk_create(self, objs: Iterable[Model], batch_size: Optional[int] = ...) -> List[_T]: ...
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def first(self) -> Optional[_T]: ...
def last(self) -> Optional[_T]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def first(self) -> Optional[_Row]: ...
def last(self) -> Optional[_Row]: ...
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
Expand All @@ -93,31 +98,38 @@ class QuerySet(Iterable[_T], Sized):
def raw(
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
) -> RawQuerySet: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet: ...
def values_list(self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...) -> QuerySet: ...
# @overload
# def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ...
# @overload
# def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ...
# @overload
# def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ...
def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[True]
) -> QuerySet[_T, NamedTuple]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[True], named: Literal[False] = ...
) -> QuerySet[_T, Any]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[False] = ...
) -> QuerySet[_T, Tuple]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
) -> QuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T, _Row]: ...
def all(self) -> QuerySet[_T, _Row]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T, _Row]: ...
def count(self) -> int: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T, _Row]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T, _Row]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T, _Row]: ...
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ...
def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
def extra(
self,
select: Optional[Dict[str, Any]] = ...,
Expand All @@ -126,11 +138,11 @@ class QuerySet(Iterable[_T], Sized):
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[_T]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
) -> QuerySet[_T, _Row]: ...
def reverse(self) -> QuerySet[_T, _Row]: ...
def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def only(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T, _Row]: ...
@property
def ordered(self) -> bool: ...
@property
Expand Down Expand Up @@ -159,7 +171,7 @@ class RawQuerySet(Iterable[_T], Sized):
@overload
def __getitem__(self, k: str) -> Any: ...
@overload
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
def __getitem__(self, k: slice) -> RawQuerySet[_T]: ...
@property
def columns(self) -> List[str]: ...
@property
Expand Down
4 changes: 2 additions & 2 deletions django-stubs/shortcuts.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def redirect(

_T = TypeVar("_T", bound=Model)

def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> _T: ...
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> List[_T]: ...
def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T, _T]], *args: Any, **kwargs: Any) -> _T: ...
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T, _T]], *args: Any, **kwargs: Any) -> List[_T]: ...
def resolve_url(to: Union[Callable, Model, str], *args: Any, **kwargs: Any) -> str: ...
32 changes: 31 additions & 1 deletion mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import os
from typing import Callable, Dict, Optional, Union, cast

Expand All @@ -6,7 +8,7 @@
from mypy.options import Options
from mypy.plugin import (
AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin,
)
AnalyzeTypeContext)
from mypy.types import (
AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeType, UnionType,
)
Expand Down Expand Up @@ -80,6 +82,18 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret


def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type:
if not ctx.type.args:
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)])
args = ctx.type.args
if len(args) == 1:
args = [args[0], args[0]]

analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
return ctx.api.named_type(fullname, analyzed_args)


def return_user_model_hook(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
Expand Down Expand Up @@ -266,6 +280,14 @@ def _get_current_form_bases(self) -> Dict[str, int]:
else:
return {}

def _get_current_queryset_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.QUERYSET_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('queryset_bases', {helpers.QUERYSET_CLASS_FULLNAME: 1}))
else:
return {}

def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.contrib.auth.get_user_model':
Expand Down Expand Up @@ -344,6 +366,14 @@ def get_attribute_hook(self, fullname: str

return extract_and_return_primary_key_of_bound_related_field_parameter

def get_type_analyze_hook(self, fullname: str
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
queryset_bases = self._get_current_queryset_bases()
if fullname in queryset_bases:
return partial(set_first_generic_param_as_default_for_second, fullname)

return None


def plugin(version):
return DjangoPlugin
25 changes: 5 additions & 20 deletions scripts/typecheck_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,11 @@
'Argument "is_dst" to "localize" of "BaseTzInfo" has incompatible type "None"; expected "bool"'
],
'aggregation': [
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',
'"as_sql" undefined in superclass',
'Incompatible types in assignment (expression has type "FlatValuesListIterable", '
+ 'variable has type "ValuesListIterable")',
'Incompatible type for "contact" of "Book" (got "Optional[Author]", expected "Union[Author, Combinable]")',
'Incompatible type for "publisher" of "Book" (got "Optional[Publisher]", '
+ 'expected "Union[Publisher, Combinable]")'
],
'aggregation_regress': [
'Incompatible types in assignment (expression has type "List[str]", variable has type "QuerySet[Author]")',
'Incompatible types in assignment (expression has type "FlatValuesListIterable", '
+ 'variable has type "QuerySet[Any]")',
'Too few arguments for "count" of "Sequence"'
],
'apps': [
'Incompatible types in assignment (expression has type "str", target has type "type")',
'"Callable[[bool, bool], List[Type[Model]]]" has no attribute "cache_clear"'
Expand Down Expand Up @@ -159,9 +150,6 @@
'db_typecasts': [
'"object" has no attribute "__iter__"; maybe "__str__" or "__dir__"? (not iterable)'
],
'expressions': [
'Argument 1 to "Subquery" has incompatible type "Sequence[Dict[str, Any]]"; expected "QuerySet[Any]"'
],
'from_db_value': [
'has no attribute "vendor"'
],
Expand Down Expand Up @@ -199,9 +187,9 @@
],
'get_object_or_404': [
'Argument 1 to "get_object_or_404" has incompatible type "str"; '
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>]]"',
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>, <nothing>]]"',
'Argument 1 to "get_list_or_404" has incompatible type "List[Type[Article]]"; '
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>]]"',
+ 'expected "Union[Type[<nothing>], QuerySet[<nothing>, <nothing>]]"',
'CustomClass'
],
'get_or_create': [
Expand All @@ -227,10 +215,6 @@
'many_to_one': [
'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")'
],
'model_inheritance_regress': [
'Incompatible types in assignment (expression has type "List[Supplier]", '
+ 'variable has type "QuerySet[Supplier]")'
],
'model_meta': [
'"object" has no attribute "items"',
'"Field" has no attribute "many_to_many"'
Expand Down Expand Up @@ -305,7 +289,8 @@
],
'queries': [
'Incompatible types in assignment (expression has type "None", variable has type "str")',
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"'
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"',
'No overload variant of "values_list" of "QuerySet" matches argument types "str", "bool", "bool"',
],
'requests': [
'Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "QueryDict")'
Expand All @@ -314,7 +299,7 @@
'Argument 1 to "TextIOWrapper" has incompatible type "HttpResponse"; expected "IO[bytes]"'
],
'prefetch_related': [
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room]")',
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room, Room]")',
'"None" has no attribute "__iter__"',
'has no attribute "read_by"'
],
Expand Down
Loading