Skip to content

Commit

Permalink
Add better support for ManyToManyField's through model
Browse files Browse the repository at this point in the history
  • Loading branch information
flaeppe committed Sep 25, 2023
1 parent a539bbe commit 6b8e4c4
Show file tree
Hide file tree
Showing 13 changed files with 678 additions and 64 deletions.
18 changes: 9 additions & 9 deletions django-stubs/db/models/fields/related.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Literal, TypeVar, overload
from typing import Any, Generic, Literal, TypeVar, overload
from uuid import UUID

from django.core import validators # due to weird mypy.stubtest error
Expand All @@ -11,14 +11,14 @@ from django.db.models.fields.related_descriptors import ForwardManyToOneDescript
from django.db.models.fields.related_descriptors import ( # noqa: F401
ForwardOneToOneDescriptor as ForwardOneToOneDescriptor,
)
from django.db.models.fields.related_descriptors import ManyRelatedManager
from django.db.models.fields.related_descriptors import ManyToManyDescriptor as ManyToManyDescriptor
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor as ReverseManyToOneDescriptor
from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor as ReverseOneToOneDescriptor
from django.db.models.fields.reverse_related import ForeignObjectRel as ForeignObjectRel # noqa: F401
from django.db.models.fields.reverse_related import ManyToManyRel as ManyToManyRel
from django.db.models.fields.reverse_related import ManyToOneRel as ManyToOneRel
from django.db.models.fields.reverse_related import OneToOneRel as OneToOneRel
from django.db.models.manager import RelatedManager
from django.db.models.query_utils import FilteredRelation, PathInfo, Q
from django.utils.functional import _StrOrPromise
from typing_extensions import Self
Expand All @@ -27,6 +27,7 @@ RECURSIVE_RELATIONSHIP_CONSTANT: Literal["self"]

def resolve_relation(scope_model: type[Model], relation: str | type[Model]) -> str | type[Model]: ...

_M = TypeVar("_M", bound=Model)
# __set__ value type
_ST = TypeVar("_ST")
# __get__ return type
Expand Down Expand Up @@ -204,10 +205,9 @@ class OneToOneField(ForeignKey[_ST, _GT]):
@overload
def __get__(self, instance: Any, owner: Any) -> Self: ...

class ManyToManyField(RelatedField[_ST, _GT]):
_pyi_private_set_type: Sequence[Any]
_pyi_private_get_type: RelatedManager[Any]
_To = TypeVar("_To", bound=Model)

class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]):
description: str
has_null_arg: bool
swappable: bool
Expand All @@ -221,12 +221,12 @@ class ManyToManyField(RelatedField[_ST, _GT]):
rel_class: type[ManyToManyRel]
def __init__(
self,
to: type[Model] | str,
to: type[_To] | str,
related_name: str | None = ...,
related_query_name: str | None = ...,
limit_choices_to: _AllLimitChoicesTo | None = ...,
symmetrical: bool | None = ...,
through: str | type[Model] | None = ...,
through: type[_M] | str | None = ...,
through_fields: tuple[str, str] | None = ...,
db_constraint: bool = ...,
db_table: str | None = ...,
Expand Down Expand Up @@ -255,10 +255,10 @@ class ManyToManyField(RelatedField[_ST, _GT]):
) -> None: ...
# class access
@overload
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[Self]: ...
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[_M]: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner: Any) -> _GT: ...
def __get__(self, instance: Model, owner: Any) -> ManyRelatedManager[_To]: ...
# non-Model instances
@overload
def __get__(self, instance: Any, owner: Any) -> Self: ...
Expand Down
70 changes: 53 additions & 17 deletions django-stubs/db/models/fields/related_descriptors.pyi
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar, overload
from collections.abc import Callable, Iterable
from typing import Any, Generic, NoReturn, TypeVar, overload

from django.core.exceptions import ObjectDoesNotExist
from django.db.models.base import Model
from django.db.models.fields import Field
from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.related import ForeignKey, ManyToManyField, RelatedField
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from django.db.models.manager import RelatedManager
from django.db.models.manager import BaseManager, RelatedManager
from django.db.models.query import QuerySet
from django.db.models.query_utils import DeferredAttribute
from typing_extensions import Self

_T = TypeVar("_T")
_M = TypeVar("_M", bound=Model)
_F = TypeVar("_F", bound=Field)
_From = TypeVar("_From", bound=Model)
_To = TypeVar("_To", bound=Model)
Expand Down Expand Up @@ -65,28 +66,63 @@ class ReverseOneToOneDescriptor(Generic[_From, _To]):
def __reduce__(self) -> tuple[Callable[..., Any], tuple[type[_To], str]]: ...

class ReverseManyToOneDescriptor:
"""
In the example::
class Child(Model):
parent = ForeignKey(Parent, related_name='children')
``Parent.children`` is a ``ReverseManyToOneDescriptor`` instance.
"""

rel: ManyToOneRel
field: ForeignKey
def __init__(self, rel: ManyToOneRel) -> None: ...
@property
def related_manager_cls(self) -> type[RelatedManager]: ...
def __get__(self, instance: Model | None, cls: type[Model] | None = ...) -> ReverseManyToOneDescriptor: ...
def __set__(self, instance: Model, value: list[Model]) -> Any: ...
def related_manager_cls(self) -> type[RelatedManager[Any]]: ...
@overload
def __get__(self, instance: None, cls: Any = ...) -> Self: ...
@overload
def __get__(self, instance: Model, cls: Any = ...) -> type[RelatedManager[Any]]: ...
def __set__(self, instance: Any, value: Any) -> NoReturn: ...

def create_reverse_many_to_one_manager(
superclass: type[BaseManager[_M]], rel: ManyToOneRel
) -> type[RelatedManager[_M]]: ...

def create_reverse_many_to_one_manager(superclass: type, rel: Any) -> type[RelatedManager]: ...
class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_M]):
"""
In the example::
class Pizza(Model):
toppings = ManyToManyField(Topping, related_name='pizzas')
``Pizza.toppings`` and ``Topping.pizzas`` are ``ManyToManyDescriptor``
instances.
"""

class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_F]):
field: _F # type: ignore[assignment]
# 'field' here is 'rel.field'
rel: ManyToManyRel # type: ignore[assignment]
field: ManyToManyField[Any, _M] # type: ignore[assignment]
reverse: bool
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
@property
def through(self) -> type[Model]: ...
def through(self) -> type[_M]: ...
@property
def related_manager_cls(self) -> type[Any]: ... # ManyRelatedManager
def related_manager_cls(self) -> type[ManyRelatedManager[Any]]: ... # type: ignore[override]

# fake
class _ForwardManyToManyManager(Generic[_T]):
def all(self) -> QuerySet: ...
class ManyRelatedManager(BaseManager[_M], Generic[_M]):
related_val: tuple[int, ...]
def add(self, *objs: _M | int, bulk: bool = ...) -> None: ...
async def aadd(self, *objs: _M | int, bulk: bool = ...) -> None: ...
def remove(self, *objs: _M | int, bulk: bool = ...) -> None: ...
async def aremove(self, *objs: _M | int, bulk: bool = ...) -> None: ...
def set(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ...
async def aset(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ...
def clear(self) -> None: ...
async def aclear(self) -> None: ...
def __call__(self, *, manager: str) -> ManyRelatedManager[_M]: ...

def create_forward_many_to_many_manager(superclass: type, rel: Any, reverse: Any) -> _ForwardManyToManyManager: ...
def create_forward_many_to_many_manager(
superclass: type[BaseManager[_M]], rel: ManyToManyRel, reverse: bool
) -> type[ManyRelatedManager[_M]]: ...
4 changes: 2 additions & 2 deletions django-stubs/db/models/fields/reverse_related.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ class OneToOneRel(ManyToOneRel):
) -> None: ...

class ManyToManyRel(ForeignObjectRel):
field: ManyToManyField # type: ignore[assignment]
field: ManyToManyField[Any, Any] # type: ignore[assignment]
through: type[Model] | None
through_fields: tuple[str, str] | None
db_constraint: bool
def __init__(
self,
field: ManyToManyField,
field: ManyToManyField[Any, Any],
to: type[Model] | str,
related_name: str | None = ...,
related_query_name: str | None = ...,
Expand Down
24 changes: 23 additions & 1 deletion mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
from collections import defaultdict
from contextlib import contextmanager
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, Literal, Optional, Sequence, Set, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)

from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db import models
Expand Down Expand Up @@ -270,6 +284,14 @@ def all_registered_model_classes(self) -> Set[Type[models.Model]]:
def all_registered_model_class_fullnames(self) -> Set[str]:
return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes}

@cached_property
def model_class_fullnames_by_label(self) -> Mapping[str, str]:
return {
klass._meta.label: helpers.get_class_fullname(klass)
for klass in self.all_registered_model_classes
if klass is not models.Model
}

def get_field_nullability(self, field: Union["Field[Any, Any]", ForeignObjectRel], method: Optional[str]) -> bool:
if method in ("values", "values_list"):
return field.null
Expand Down
1 change: 0 additions & 1 deletion mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
FOREIGN_OBJECT_FULLNAME,
FOREIGN_KEY_FULLNAME,
ONETOONE_FIELD_FULLNAME,
MANYTOMANY_FIELD_FULLNAME,
)
)

Expand Down
65 changes: 64 additions & 1 deletion mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from mypy.nodes import (
GDEF,
MDEF,
AssignmentStmt,
Block,
ClassDef,
Context,
Expression,
MemberExpr,
MypyFile,
Expand All @@ -33,7 +35,8 @@
SemanticAnalyzerPluginInterface,
)
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
from mypy.semanal_shared import parse_bool
from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
from mypy.types import Type as MypyType
from typing_extensions import TypedDict

Expand All @@ -45,12 +48,14 @@


class DjangoTypeMetadata(TypedDict, total=False):
is_abstract_model: bool
from_queryset_manager: str
reverse_managers: Dict[str, str]
baseform_bases: Dict[str, int]
manager_bases: Dict[str, int]
model_bases: Dict[str, int]
queryset_bases: Dict[str, int]
m2m_throughs: Dict[str, str]


def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata:
Expand Down Expand Up @@ -385,3 +390,61 @@ def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) ->
if sym is not None and isinstance(sym.node, TypeInfo):
bases = get_django_metadata_bases(sym.node, "manager_bases")
bases[fullname] = 1


def is_abstract_model(model: TypeInfo) -> bool:
if model.metaclass_type is None or model.metaclass_type.type.fullname != fullnames.MODEL_METACLASS_FULLNAME:
return False

metadata = get_django_metadata(model)
if metadata.get("is_abstract_model") is not None:
return metadata["is_abstract_model"]

meta = model.names.get("Meta")
# Check if 'abstract' is declared in this model's 'class Meta' as
# 'abstract = True' won't be inherited from a parent model.
if meta is not None and isinstance(meta.node, TypeInfo) and "abstract" in meta.node.names:
for stmt in meta.node.defn.defs.body:
if (
# abstract =
isinstance(stmt, AssignmentStmt)
and len(stmt.lvalues) == 1
and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name == "abstract"
):
# abstract = True (builtins.bool)
rhs_is_true = parse_bool(stmt.rvalue) is True
# abstract: Literal[True]
is_literal_true = isinstance(stmt.type, LiteralType) and stmt.type.value is True
metadata["is_abstract_model"] = rhs_is_true or is_literal_true
return metadata["is_abstract_model"]

metadata["is_abstract_model"] = False
return False


def resolve_lazy_reference(
reference: str, *, api: Union[TypeChecker, SemanticAnalyzer], django_context: "DjangoContext", ctx: Context
) -> Optional[TypeInfo]:
"""
Attempts to resolve a lazy reference(e.g. "<app_label>.<object_name>") to a
'TypeInfo' instance.
"""
if "." not in reference:
# <object_name> -- needs prefix of <app_label>. We can't implicitly solve
# what app label this should be, yet.
return None

# Reference conforms to the structure of a lazy reference: '<app_label>.<object_name>'
fullname = django_context.model_class_fullnames_by_label.get(reference)
if fullname is not None:
model_info = lookup_fully_qualified_typeinfo(api, fullname)
if model_info is not None:
return model_info
elif isinstance(api, SemanticAnalyzer) and not api.final_iteration:
# Getting this far, where Django matched the reference but we still can't
# find it, we want to defer
api.defer()
else:
api.fail("Could not match lazy reference with any model", ctx)
return None
5 changes: 5 additions & 0 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import manytomany

if TYPE_CHECKING:
from django.contrib.contenttypes.fields import GenericForeignKey
Expand Down Expand Up @@ -213,6 +214,10 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan

assert isinstance(outer_model_info, TypeInfo)

if default_return_type.type.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME):
return manytomany.fill_model_args_for_many_to_many_field(
ctx=ctx, model_info=outer_model_info, default_return_type=default_return_type, django_context=django_context
)
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
return fill_descriptor_types_for_related_field(ctx, django_context)

Expand Down
Loading

0 comments on commit 6b8e4c4

Please sign in to comment.