From a3a1e84dbc1c5a4934f6988210d4d9cedcd75f0b Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Sun, 10 Sep 2023 11:20:43 +0200 Subject: [PATCH] Make managers django.contrib generic This is similar to the base model and managers and to the user models and manager. It also reflects the runtime behavior as a model's manager is always associated with the model class itself (represented by `Self`), not the parent class where it is defined. Also did some trivial fixes in the those files. --- django-stubs/contrib/admin/models.pyi | 36 +++++++++++-------- django-stubs/contrib/auth/base_user.pyi | 15 ++++---- django-stubs/contrib/auth/models.pyi | 26 ++++++++------ .../contrib/sessions/base_session.pyi | 19 ++++++---- django-stubs/contrib/sessions/models.pyi | 15 ++++++-- django-stubs/contrib/sites/managers.pyi | 6 ++-- django-stubs/contrib/sites/models.pyi | 15 ++++---- django-stubs/db/models/manager.pyi | 4 +-- 8 files changed, 82 insertions(+), 54 deletions(-) diff --git a/django-stubs/contrib/admin/models.pyi b/django-stubs/contrib/admin/models.pyi index c658498a0..4605bfe62 100644 --- a/django-stubs/contrib/admin/models.pyi +++ b/django-stubs/contrib/admin/models.pyi @@ -1,15 +1,19 @@ -from typing import Any, ClassVar +import datetime as dt +from typing import Any, ClassVar, TypeVar +from typing_extensions import Self from uuid import UUID +from django.contrib.contenttypes.models import ContentType from django.db import models -from django.db.models.base import Model ADDITION: int CHANGE: int DELETION: int -ACTION_FLAG_CHOICES: Any +ACTION_FLAG_CHOICES: list[tuple[int, str]] -class LogEntryManager(models.Manager[LogEntry]): +_LogEntryT = TypeVar("_LogEntryT", bound=LogEntry) + +class LogEntryManager(models.Manager[_LogEntryT]): def log_action( self, user_id: int, @@ -17,21 +21,23 @@ class LogEntryManager(models.Manager[LogEntry]): object_id: int | str | UUID, object_repr: str, action_flag: int, - change_message: Any = ..., - ) -> LogEntry: ... + change_message: str | list[Any] = ..., + ) -> _LogEntryT: ... class LogEntry(models.Model): - action_time: models.DateTimeField[Any] = ... - user: models.ForeignKey[Any] = ... - content_type: models.ForeignKey[Any] = ... - object_id: models.TextField[Any] = ... - object_repr: models.CharField[Any] = ... - action_flag: models.PositiveSmallIntegerField[Any] = ... - change_message: models.TextField[Any] = ... - objects: ClassVar[LogEntryManager] = ... + objects: ClassVar[LogEntryManager[Self]] # type: ignore[assignment] + + action_time: models.DateTimeField[dt.datetime] + user: models.ForeignKey[Any] + content_type: models.ForeignKey[ContentType | None] + object_id: models.TextField[str | None] + object_repr: models.CharField[str] + action_flag: models.PositiveSmallIntegerField[int] + change_message: models.TextField[str] + def is_addition(self) -> bool: ... def is_change(self) -> bool: ... def is_deletion(self) -> bool: ... def get_change_message(self) -> str: ... - def get_edited_object(self) -> Model: ... + def get_edited_object(self) -> models.Model: ... def get_admin_url(self) -> str | None: ... diff --git a/django-stubs/contrib/auth/base_user.pyi b/django-stubs/contrib/auth/base_user.pyi index 6ce0d466d..b1fdb132c 100644 --- a/django-stubs/contrib/auth/base_user.pyi +++ b/django-stubs/contrib/auth/base_user.pyi @@ -1,19 +1,20 @@ -from typing import Any, TypeVar, overload +from typing import Any, TypeVar from typing_extensions import Literal from django.db import models from django.db.models.base import Model from django.db.models.fields import BooleanField -_T = TypeVar("_T", bound=Model) +_T = TypeVar("_T") +_ModelT = TypeVar("_ModelT", bound=Model) -class BaseUserManager(models.Manager[_T]): +class BaseUserManager(models.Manager[_ModelT]): @classmethod def normalize_email(cls, email: str | None) -> str: ... def make_random_password( self, length: int = ..., allowed_chars: str = ... ) -> str: ... - def get_by_natural_key(self, username: str | None) -> _T: ... + def get_by_natural_key(self, username: str | None) -> _ModelT: ... class AbstractBaseUser(models.Model): REQUIRED_FIELDS: list[str] = ... @@ -35,8 +36,4 @@ class AbstractBaseUser(models.Model): @classmethod def get_email_field_name(cls) -> str: ... @classmethod - @overload - def normalize_username(cls, username: str) -> str: ... - @classmethod - @overload - def normalize_username(cls, username: Any) -> Any: ... + def normalize_username(cls, username: _T) -> _T: ... diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index df41c72d3..c733b05bc 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -13,35 +13,39 @@ from django.db.models.manager import EmptyManager _AnyUser = Model | AnonymousUser +_T = TypeVar("_T", bound=Model) + def update_last_login( sender: type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any ) -> None: ... -class PermissionManager(models.Manager[Permission]): +_PermissionT = TypeVar("_PermissionT", bound=Permission) + +class PermissionManager(models.Manager[_PermissionT]): def get_by_natural_key( self, codename: str, app_label: str, model: str - ) -> Permission: ... + ) -> _PermissionT: ... class Permission(models.Model): content_type_id: int - objects: ClassVar[PermissionManager] + objects: ClassVar[PermissionManager[Self]] # type: ignore[assignment] name = models.CharField(max_length=255) content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) codename = models.CharField(max_length=100) def natural_key(self) -> tuple[str, str, str]: ... -class GroupManager(models.Manager[Group]): - def get_by_natural_key(self, name: str) -> Group: ... +_GroupT = TypeVar("_GroupT", bound=Group) + +class GroupManager(models.Manager[_GroupT]): + def get_by_natural_key(self, name: str) -> _GroupT: ... class Group(models.Model): - objects: ClassVar[GroupManager] + objects: ClassVar[GroupManager[Self]] # type: ignore[assignment] name = models.CharField(max_length=150) permissions = models.ManyToManyField[Permission, Any](Permission) - def natural_key(self) -> Any: ... - -_T = TypeVar("_T", bound=Model) + def natural_key(self) -> tuple[str, ...]: ... class UserManager(BaseUserManager[_T]): def create_user( @@ -114,9 +118,9 @@ class AnonymousUser: def set_password(self, raw_password: str) -> Any: ... def check_password(self, raw_password: str) -> Any: ... @property - def groups(self) -> EmptyManager: ... + def groups(self) -> EmptyManager[Group]: ... @property - def user_permissions(self) -> EmptyManager: ... + def user_permissions(self) -> EmptyManager[Permission]: ... def get_user_permissions(self, obj: _AnyUser | None = ...) -> set[str]: ... def get_group_permissions(self, obj: _AnyUser | None = ...) -> set[Any]: ... def get_all_permissions(self, obj: _AnyUser | None = ...) -> set[str]: ... diff --git a/django-stubs/contrib/sessions/base_session.pyi b/django-stubs/contrib/sessions/base_session.pyi index 438e0d2b9..06ec80e0f 100644 --- a/django-stubs/contrib/sessions/base_session.pyi +++ b/django-stubs/contrib/sessions/base_session.pyi @@ -1,20 +1,25 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import ClassVar, TypeVar +from typing_extensions import Self from django.contrib.sessions.backends.base import SessionBase from django.db import models -class BaseSessionManager(models.Manager[Any]): +_SessionT = TypeVar("_SessionT", bound=AbstractBaseSession) + +class BaseSessionManager(models.Manager[_SessionT]): def encode(self, session_dict: dict[str, int]) -> str: ... def save( self, session_key: str, session_dict: dict[str, int], expire_date: datetime - ) -> AbstractBaseSession: ... + ) -> _SessionT: ... class AbstractBaseSession(models.Model): - expire_date: datetime - session_data: str - session_key: str - objects: ClassVar[BaseSessionManager] = ... + objects: ClassVar[BaseSessionManager[Self]] # type: ignore[assignment] + + session_key: models.CharField[str] + session_data: models.TextField[str] + expire_date: models.DateTimeField[datetime] + @classmethod def get_session_store_class(cls) -> type[SessionBase] | None: ... def get_decoded(self) -> dict[str, int]: ... diff --git a/django-stubs/contrib/sessions/models.pyi b/django-stubs/contrib/sessions/models.pyi index 08b30ef61..da876345b 100644 --- a/django-stubs/contrib/sessions/models.pyi +++ b/django-stubs/contrib/sessions/models.pyi @@ -1,4 +1,15 @@ +from typing import ClassVar, TypeVar +from typing_extensions import Self + +from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.base_session import AbstractBaseSession, BaseSessionManager -class SessionManager(BaseSessionManager): ... -class Session(AbstractBaseSession): ... +_T = TypeVar("_T", bound=AbstractBaseSession) + +class SessionManager(BaseSessionManager[_T]): ... + +class Session(AbstractBaseSession): + objects: ClassVar[SessionManager[Self]] # type: ignore[assignment] + + @classmethod + def get_session_store_class(cls) -> type[SessionStore]: ... diff --git a/django-stubs/contrib/sites/managers.pyi b/django-stubs/contrib/sites/managers.pyi index e7bdfdbf6..bdf0e8492 100644 --- a/django-stubs/contrib/sites/managers.pyi +++ b/django-stubs/contrib/sites/managers.pyi @@ -1,6 +1,8 @@ -from typing import Any +from typing import TypeVar from django.db import models -class CurrentSiteManager(models.Manager[Any]): +_T = TypeVar("_T", bound=models.Model) + +class CurrentSiteManager(models.Manager[_T]): def __init__(self, field_name: str | None = ...) -> None: ... diff --git a/django-stubs/contrib/sites/models.pyi b/django-stubs/contrib/sites/models.pyi index e4181d66f..60f5d0a58 100644 --- a/django-stubs/contrib/sites/models.pyi +++ b/django-stubs/contrib/sites/models.pyi @@ -1,17 +1,20 @@ -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar +from typing_extensions import Self from django.db import models from django.http.request import HttpRequest -SITE_CACHE: Any +SITE_CACHE: dict[Any, Site] -class SiteManager(models.Manager[Site]): - def get_current(self, request: HttpRequest | None = ...) -> Site: ... +_SiteT = TypeVar("_SiteT", bound=Site) + +class SiteManager(models.Manager[_SiteT]): + def get_current(self, request: HttpRequest | None = ...) -> _SiteT: ... def clear_cache(self) -> None: ... - def get_by_natural_key(self, domain: str) -> Site: ... + def get_by_natural_key(self, domain: str) -> _SiteT: ... class Site(models.Model): - objects: ClassVar[SiteManager] + objects: ClassVar[SiteManager[Self]] # type: ignore[assignment] domain = models.CharField(max_length=100) name = models.CharField(max_length=50) diff --git a/django-stubs/db/models/manager.pyi b/django-stubs/db/models/manager.pyi index c8fedcbee..a8267ff1d 100644 --- a/django-stubs/db/models/manager.pyi +++ b/django-stubs/db/models/manager.pyi @@ -136,5 +136,5 @@ class ManagerDescriptor: self, instance: Model | None, cls: type[Model] = ... ) -> Manager[Any]: ... -class EmptyManager(Manager[Any]): - def __init__(self, model: type[Model]) -> None: ... +class EmptyManager(Manager[_T]): + def __init__(self, model: type[_T]) -> None: ...