From b76f9f4d9d4f8b4ce8f0f9f0938c81faaafc25c6 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Thu, 1 Aug 2024 00:00:00 +0200 Subject: [PATCH] Improve `get_context_object_name` on `SingleObjectMixin` and `MultipleObjectMixin` (#2298) --- django-stubs/views/generic/detail.pyi | 7 +++-- django-stubs/views/generic/list.pyi | 14 +++++++-- tests/assert_type/views/generic.py | 30 +++++++++++++++++++ tests/typecheck/views/generic/test_detail.yml | 4 +-- 4 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 tests/assert_type/views/generic.py diff --git a/django-stubs/views/generic/detail.pyi b/django-stubs/views/generic/detail.pyi index a833c72cc..b2f04dc47 100644 --- a/django-stubs/views/generic/detail.pyi +++ b/django-stubs/views/generic/detail.pyi @@ -1,4 +1,4 @@ -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, overload from django.db import models from django.http import HttpRequest, HttpResponse @@ -17,7 +17,10 @@ class SingleObjectMixin(Generic[_M], ContextMixin): def get_object(self, queryset: models.query.QuerySet[_M] | None = ...) -> _M: ... def get_queryset(self) -> models.query.QuerySet[_M]: ... def get_slug_field(self) -> str: ... - def get_context_object_name(self, obj: _M) -> str | None: ... + @overload + def get_context_object_name(self, obj: _M) -> str: ... + @overload + def get_context_object_name(self, obj: Any) -> str | None: ... class BaseDetailView(SingleObjectMixin[_M], View): object: _M diff --git a/django-stubs/views/generic/list.pyi b/django-stubs/views/generic/list.pyi index 01f68b27c..785a81aa8 100644 --- a/django-stubs/views/generic/list.pyi +++ b/django-stubs/views/generic/list.pyi @@ -1,12 +1,17 @@ from collections.abc import Sequence -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar, overload, type_check_only from django.core.paginator import Page, Paginator, _SupportsPagination from django.db.models import Model, QuerySet from django.http import HttpRequest, HttpResponse from django.views.generic.base import ContextMixin, TemplateResponseMixin, View -_M = TypeVar("_M", bound=Model, covariant=True) +_M = TypeVar("_M", bound=Model) + +@type_check_only +class _HasModel(Protocol): + @property + def model(self) -> type[Model]: ... class MultipleObjectMixin(Generic[_M], ContextMixin): allow_empty: bool @@ -34,7 +39,10 @@ class MultipleObjectMixin(Generic[_M], ContextMixin): ) -> Paginator: ... def get_paginate_orphans(self) -> int: ... def get_allow_empty(self) -> bool: ... - def get_context_object_name(self, object_list: _SupportsPagination[_M]) -> str | None: ... + @overload + def get_context_object_name(self, object_list: _HasModel) -> str: ... + @overload + def get_context_object_name(self, object_list: Any) -> str | None: ... def get_context_data( self, *, object_list: _SupportsPagination[_M] | None = ..., **kwargs: Any ) -> dict[str, Any]: ... diff --git a/tests/assert_type/views/generic.py b/tests/assert_type/views/generic.py new file mode 100644 index 000000000..6147f02d7 --- /dev/null +++ b/tests/assert_type/views/generic.py @@ -0,0 +1,30 @@ +from typing import Optional, Type + +from django.db import models +from django.views.generic.detail import SingleObjectMixin +from django.views.generic.list import ListView +from typing_extensions import assert_type + + +class MyModel(models.Model): ... + + +class MyDetailView(SingleObjectMixin[MyModel]): ... + + +detail_view = MyDetailView() +assert_type(detail_view.model, Type[MyModel]) +assert_type(detail_view.queryset, Optional[models.QuerySet[MyModel, MyModel]]) +assert_type(detail_view.get_context_object_name(MyModel()), str) +assert_type(detail_view.get_context_object_name(1), Optional[str]) + + +class MyListView(ListView[MyModel]): ... + + +list_view = MyListView() +assert_type(list_view.model, Optional[Type[MyModel]]) +assert_type(list_view.queryset, Optional[models.QuerySet[MyModel, MyModel]]) +assert_type(list_view.get_context_object_name(models.QuerySet[MyModel]()), str) +assert_type(list_view.get_context_object_name(MyModel()), Optional[str]) +assert_type(list_view.get_context_object_name(1), Optional[str]) diff --git a/tests/typecheck/views/generic/test_detail.yml b/tests/typecheck/views/generic/test_detail.yml index 8e1713db4..6fa29f5c3 100644 --- a/tests/typecheck/views/generic/test_detail.yml +++ b/tests/typecheck/views/generic/test_detail.yml @@ -37,8 +37,8 @@ def get_queryset(self) -> QuerySet[MyModel]: self.get_object(super().get_queryset()) return super().get_queryset() - custom_settings: | - INSTALLED_APPS = ('myapp',) + installed_apps: + - myapp files: - path: myapp/__init__.py - path: myapp/models.py