Skip to content

Commit

Permalink
Add method mapping to ViewSet actions
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan P Kilby committed Jun 25, 2018
1 parent 9b64818 commit f323989
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 21 deletions.
63 changes: 62 additions & 1 deletion rest_framework/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def action(methods=None, detail=None, name=None, url_path=None, url_name=None, *
)

def decorator(func):
func.bind_to_methods = methods
func.mapping = MethodMapper(func, methods)

func.detail = detail
func.name = name if name else pretty_name(func.__name__)
func.url_path = url_path if url_path else func.__name__
Expand All @@ -156,10 +157,70 @@ def decorator(func):
'name': func.name,
'description': func.__doc__ or None
})

return func
return decorator


class MethodMapper(dict):
"""
Enables mapping HTTP methods to different ViewSet methods for a single,
logical action.
Example usage:
class MyViewSet(ViewSet):
@action(detail=False)
def example(self, request, **kwargs):
...
@example.mapping.post
def create_example(self, request, **kwargs):
...
"""

def __init__(self, action, methods):
self.action = action
for method in methods:
self[method] = self.action.__name__

def _map(self, method, func):
assert method not in self, (
"Method '%s' has already been mapped to '.%s'." % (method, self[method]))
assert func.__name__ != self.action.__name__, (
"Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")

self[method] = func.__name__

return func

def get(self, func):
return self._map('get', func)

def post(self, func):
return self._map('post', func)

def put(self, func):
return self._map('put', func)

def patch(self, func):
return self._map('patch', func)

def delete(self, func):
return self._map('delete', func)

def head(self, func):
return self._map('head', func)

def options(self, func):
return self._map('options', func)

def trace(self, func):
return self._map('trace', func)


def detail_route(methods=None, **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for detail requests.
Expand Down
3 changes: 1 addition & 2 deletions rest_framework/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ def _get_dynamic_route(self, route, action):

return Route(
url=route.url.replace('{url_path}', url_path),
mapping={http_method: action.__name__
for http_method in action.bind_to_methods},
mapping=action.mapping,
name=route.name.replace('{url_name}', action.url_name),
detail=route.detail,
initkwargs=initkwargs,
Expand Down
2 changes: 1 addition & 1 deletion rest_framework/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def _is_extra_action(attr):
return hasattr(attr, 'bind_to_methods')
return hasattr(attr, 'mapping')


class ViewSetMixin(object):
Expand Down
64 changes: 59 additions & 5 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_defaults(self):
def test_action(request):
"""Description"""

assert test_action.bind_to_methods == ['get']
assert test_action.mapping == {'get': 'test_action'}
assert test_action.detail is True
assert test_action.name == 'Test action'
assert test_action.url_path == 'test_action'
Expand All @@ -191,15 +191,69 @@ def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo:
@action()
def test_action(request):
pass
raise NotImplementedError

assert str(excinfo.value) == "@action() missing required argument: 'detail'"

def test_method_mapping_http_methods(self):
# All HTTP methods should be mappable
@action(detail=False, methods=[])
def test_action():
raise NotImplementedError

for name in APIView.http_method_names:
def method():
raise NotImplementedError

# Python 2.x compatibility - cast __name__ to str
method.__name__ = str(name)
getattr(test_action.mapping, name)(method)

# ensure the mapping returns the correct method name
for name in APIView.http_method_names:
assert test_action.mapping[name] == name

def test_method_mapping(self):
@action(detail=False)
def test_action(request):
raise NotImplementedError

@test_action.mapping.post
def test_action_post(request):
raise NotImplementedError

# The secondary handler methods should not have the action attributes
for name in ['mapping', 'detail', 'name', 'url_path', 'url_name', 'kwargs']:
assert hasattr(test_action, name) and not hasattr(test_action_post, name)

def test_method_mapping_already_mapped(self):
@action(detail=True)
def test_action(request):
raise NotImplementedError

msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.get
def test_action_get(request):
raise NotImplementedError

def test_method_mapping_overwrite(self):
@action(detail=True)
def test_action():
raise NotImplementedError

msg = ("Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.post
def test_action():
raise NotImplementedError

def test_detail_route_deprecation(self):
with pytest.warns(PendingDeprecationWarning) as record:
@detail_route()
def view(request):
pass
raise NotImplementedError

assert len(record) == 1
assert str(record[0].message) == (
Expand All @@ -212,7 +266,7 @@ def test_list_route_deprecation(self):
with pytest.warns(PendingDeprecationWarning) as record:
@list_route()
def view(request):
pass
raise NotImplementedError

assert len(record) == 1
assert str(record[0].message) == (
Expand All @@ -226,7 +280,7 @@ def test_route_url_name_from_path(self):
with pytest.warns(PendingDeprecationWarning):
@list_route(url_path='foo_bar')
def view(request):
pass
raise NotImplementedError

assert view.url_path == 'foo_bar'
assert view.url_name == 'foo-bar'
34 changes: 32 additions & 2 deletions tests/test_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.test import TestCase, override_settings
from django.urls import resolve
from django.urls import resolve, reverse

from rest_framework import permissions, serializers, viewsets
from rest_framework.compat import get_regex_pattern
Expand Down Expand Up @@ -107,8 +107,23 @@ def action1(self, request, *args, **kwargs):
def action2(self, request, *args, **kwargs):
return Response({'method': 'action2'})

@action(methods=['post'], detail=True)
def action3(self, request, pk, *args, **kwargs):
return Response({'post': pk})

@action3.mapping.delete
def action3_delete(self, request, pk, *args, **kwargs):
return Response({'delete': pk})


class TestSimpleRouter(URLPatternsTestCase, TestCase):
router = SimpleRouter()
router.register('basics', BasicViewSet, base_name='basic')

urlpatterns = [
url(r'^api/', include(router.urls)),
]

class TestSimpleRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()

Expand All @@ -127,6 +142,21 @@ def test_action_routes(self):
'delete': 'action2',
}

assert routes[2].url == '^{prefix}/{lookup}/action3{trailing_slash}$'
assert routes[2].mapping == {
'post': 'action3',
'delete': 'action3_delete',
}

def test_multiple_action_handlers(self):
# Standard action
response = self.client.post(reverse('basic-action3', args=[1]))
assert response.data == {'post': '1'}

# Additional handler registered with MethodMapper
response = self.client.delete(reverse('basic-action3', args=[1]))
assert response.data == {'delete': '1'}


class TestRootView(URLPatternsTestCase, TestCase):
urlpatterns = [
Expand Down
35 changes: 25 additions & 10 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,29 +75,35 @@ def custom_action(self, request, pk):
"""
A description of custom action.
"""
return super(ExampleSerializer, self).retrieve(self, request)
raise NotImplementedError

@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField)
def custom_action_with_dict_field(self, request, pk):
"""
A custom action using a dict field in the serializer.
"""
return super(ExampleSerializer, self).retrieve(self, request)
raise NotImplementedError

@action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields)
def custom_action_with_list_fields(self, request, pk):
"""
A custom action using both list field and list serializer in the serializer.
"""
return super(ExampleSerializer, self).retrieve(self, request)
raise NotImplementedError

@action(detail=False)
def custom_list_action(self, request):
return super(ExampleViewSet, self).list(self, request)
raise NotImplementedError

@action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer)
def custom_list_action_multiple_methods(self, request):
return super(ExampleViewSet, self).list(self, request)
"""Custom description."""
raise NotImplementedError

@custom_list_action_multiple_methods.mapping.delete
def custom_list_action_multiple_methods_delete(self, request):
"""Deletion description."""
raise NotImplementedError

def get_serializer(self, *args, **kwargs):
assert self.request
Expand Down Expand Up @@ -147,7 +153,8 @@ def test_anonymous_request(self):
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='get'
action='get',
description='Custom description.',
)
},
'read': coreapi.Link(
Expand Down Expand Up @@ -238,12 +245,19 @@ def test_authenticated_request(self):
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='get'
action='get',
description='Custom description.',
),
'create': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='post'
)
action='post',
description='Custom description.',
),
'delete': coreapi.Link(
url='/example/custom_list_action_multiple_methods/',
action='delete',
description='Deletion description.',
),
},
'update': coreapi.Link(
url='/example/{id}/',
Expand Down Expand Up @@ -526,7 +540,8 @@ def test_schema_for_regular_views(self):
'custom_list_action_multiple_methods': {
'read': coreapi.Link(
url='/example1/custom_list_action_multiple_methods/',
action='get'
action='get',
description='Custom description.',
)
},
'read': coreapi.Link(
Expand Down

0 comments on commit f323989

Please sign in to comment.