Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow @api_view wrapped functions as well. #13

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions hybridrouter/hybridrouter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Optional, Type, Union, overload
from typing import Callable, Optional, Type, Union, overload

from django.urls import include, path, re_path
from django.urls.exceptions import NoReverseMatch
Expand Down Expand Up @@ -63,18 +63,25 @@ def register(
) -> None:
...

@overload
def register(
self, prefix: str, viewset: Type[Callable], basename: Optional[str] = None
) -> None:
...

def register(
self,
prefix: str,
viewset: Union[Type[APIView], Type[ViewSetMixin]],
viewset: Union[Type[APIView], Type[ViewSetMixin], Type[Callable]],
basename: Optional[str] = None,
) -> None:
"""
Registers an APIView or ViewSet with the specified prefix.
Registers an APIView, ViewSet, or @api_view-decorated function with the specified prefix.

Args:
prefix (str): URL prefix for the view or viewset.
viewset (Type[APIView] or Type[ViewSetMixin]): The APIView or ViewSet class.
viewset (Type[APIView] or Type[ViewSetMixin] or Type[Callable]):
A class (APIView or ViewSet) or function (@api_view-decorated function).
basename (str, optional): The base name for the view or viewset. Defaults to None.
"""
if basename is None:
Expand Down Expand Up @@ -148,9 +155,13 @@ def _build_urls(self, node, prefix, urls):
viewset_urls = self._get_viewset_urls(node.view, prefix, node.basename)
urls.extend(viewset_urls)
else:
# Add the basic view with a unique name
name = f"{node.basename}"
urls.append(path(f"{prefix}", node.view.as_view(), name=name))
# Only APIView has as_view, so try that first
try:
urls.append(path(f"{prefix}", node.view.as_view(), name=name))
except AttributeError:
# That didn't work, so it must be an @api_view-decorated function.
urls.append(path(f"{prefix}", node.view, name=name))
# If this node is a nested router, include it
elif node.is_nested_router:
urls.append(
Expand Down
29 changes: 28 additions & 1 deletion hybridrouter/tests/test_hybrid_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .conftest import recevoir_test_url_resolver
from .models import Item
from .views import ItemView
from .views import ItemView, item_view
from .viewsets import ItemViewSet, SlugItemViewSet


Expand All @@ -26,6 +26,7 @@ def create_urlconf(router):
def test_register_views_and_viewsets(hybrid_router, db):
# Enregistrer des vues simples
hybrid_router.register("items-view", ItemView, basename="item-view")
hybrid_router.register("apiitems-view", item_view, basename="apiitem-view")

# Enregistrer des ViewSets
hybrid_router.register("items-set", ItemViewSet, basename="item-set")
Expand All @@ -40,10 +41,12 @@ def test_register_views_and_viewsets(hybrid_router, db):

# Vérifier que les URL sont correctement générées
view_url = reverse("item-view")
api_view_url = reverse("apiitem-view")
list_url = reverse("item-set-list")
detail_url = reverse("item-set-detail", kwargs={"pk": 1})

assert view_url == "/items-view/"
assert api_view_url == "/apiitems-view/"
assert list_url == "/items-set/"
assert detail_url == "/items-set/1/"

Expand All @@ -52,6 +55,9 @@ def test_register_views_and_viewsets(hybrid_router, db):
response = client.get(view_url)
assert response.status_code == status.HTTP_200_OK

response = client.get(api_view_url)
assert response.status_code == status.HTTP_200_OK

Item.objects.create(id=1, name="Test Item", description="Item for testing.")

response = client.get(list_url)
Expand All @@ -61,6 +67,27 @@ def test_register_views_and_viewsets(hybrid_router, db):
assert response.status_code == status.HTTP_200_OK


@override_settings()
def test_register_only_api_views(hybrid_router, db):
# Enregistrer uniquement des vues simples
hybrid_router.register("simple-view", item_view, basename="simple-view")

urlconf = create_urlconf(hybrid_router)

with override_settings(ROOT_URLCONF=urlconf):
resolver = get_resolver(urlconf)
recevoir_test_url_resolver(resolver.url_patterns)

# Vérifier que l'URL est correctement générée
view_url = reverse("simple-view")
assert view_url == "/simple-view/"

# Vérifier que la vue fonctionne correctement
client = APIClient()
response = client.get(view_url)
assert response.status_code == status.HTTP_200_OK


@override_settings()
def test_register_only_views(hybrid_router, db):
# Enregistrer uniquement des vues simples
Expand Down
8 changes: 8 additions & 0 deletions hybridrouter/tests/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from rest_framework.decorators import api_view
from rest_framework.views import APIView
from rest_framework.response import Response
from .models import Item
Expand All @@ -9,3 +10,10 @@ def get(self, request):
items = Item.objects.all()
serializer = ItemSerializer(items, many=True)
return Response(serializer.data)


@api_view(["GET"])
def item_view(request):
items = Item.objects.all()
serializer = ItemSerializer(items, many=True)
return Response(serializer.data)
Loading