Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix schemas for extra actions #5992

Merged
merged 4 commits into from
Jul 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/api-guide/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ You may disable schema generation for a view by setting `schema` to `None`:
...
schema = None # Will not appear in schema

This also applies to extra actions for `ViewSet`s:

class CustomViewSet(viewsets.ModelViewSet):

@action(detail=True, schema=None)
def extra_action(self, request, pk=None):
...

---

**Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and
Expand Down
8 changes: 5 additions & 3 deletions rest_framework/schemas/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def should_include_endpoint(self, path, callback):
if callback.cls.schema is None:
return False

if 'schema' in callback.initkwargs:
if callback.initkwargs['schema'] is None:
return False

if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs.

Expand Down Expand Up @@ -365,9 +369,7 @@ def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
Expand Down
21 changes: 20 additions & 1 deletion rest_framework/schemas/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import warnings
from collections import OrderedDict
from weakref import WeakKeyDictionary

from django.db import models
from django.utils.encoding import force_text, smart_text
Expand Down Expand Up @@ -128,6 +129,10 @@ class ViewInspector(object):

Provide subclass for per-view schema generation
"""

def __init__(self):
self.instance_schemas = WeakKeyDictionary()

def __get__(self, instance, owner):
"""
Enables `ViewInspector` as a Python _Descriptor_.
Expand All @@ -144,9 +149,17 @@ def __get__(self, instance, owner):
See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage.
"""
if instance in self.instance_schemas:
return self.instance_schemas[instance]

self.view = instance
return self

def __set__(self, instance, other):
self.instance_schemas[instance] = other
if other is not None:
other.view = instance

@property
def view(self):
"""View property."""
Expand Down Expand Up @@ -189,6 +202,7 @@ def __init__(self, manual_fields=None):
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super(AutoSchema, self).__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
Expand Down Expand Up @@ -455,6 +469,7 @@ def __init__(self, fields, description='', encoding=None):
* `fields`: list of `coreapi.Field` instances.
* `descripton`: String description for view. Optional.
"""
super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
Expand All @@ -474,9 +489,13 @@ def get_link(self, path, method, base_url):
)


class DefaultSchema(object):
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner):
result = super(DefaultSchema, self).__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result

inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
inspector = inspector_class()
Expand Down
43 changes: 43 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def custom_list_action(self, request):
def custom_list_action_multiple_methods(self, request):
return super(ExampleViewSet, self).list(self, request)

@action(detail=False, schema=None)
def excluded_action(self, request):
pass

def get_serializer(self, *args, **kwargs):
assert self.request
assert self.action
Expand Down Expand Up @@ -720,6 +724,45 @@ class CustomView(APIView):
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]

@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_viewset_action_with_schema(self):
class CustomViewSet(GenericViewSet):
@action(detail=True, schema=AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
]))
def extra_action(self, pk, **kwargs):
pass

router = SimpleRouter()
router.register(r'detail', CustomViewSet, base_name='detail')

generator = SchemaGenerator()
view = generator.create_view(router.urls[0].callback, 'GET')
link = view.schema.get_link('/a/url/{id}/', 'GET', '')
fields = link.fields

assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]

@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_viewset_action_with_null_schema(self):
class CustomViewSet(GenericViewSet):
@action(detail=True, schema=None)
def extra_action(self, pk, **kwargs):
pass

router = SimpleRouter()
router.register(r'detail', CustomViewSet, base_name='detail')

generator = SchemaGenerator()
view = generator.create_view(router.urls[0].callback, 'GET')
assert view.schema is None

@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_view_with_manual_schema(self):

Expand Down