diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index aaefe3db8f..25ce9df0d7 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -243,6 +243,14 @@ You may disable schema generation for a view by setting `schema` to `None`: ... schema = None # Will not appear in schema +This also applies to extra actions for `ViewSet`s: + + class CustomViewSet(viewsets.ModelViewSet): + + @action(detail=True, schema=None) + def extra_action(self, request, pk=None): + ... + --- **Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 629f92b0da..8794c9967c 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -218,6 +218,10 @@ def should_include_endpoint(self, path, callback): if callback.cls.schema is None: return False + if 'schema' in callback.initkwargs: + if callback.initkwargs['schema'] is None: + return False + if path.endswith('.{format}') or path.endswith('.{format}/'): return False # Ignore .json style URLs. @@ -365,9 +369,7 @@ def create_view(self, callback, method, request=None): """ Given a callback, return an actual view instance. """ - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) + view = callback.cls(**getattr(callback, 'initkwargs', {})) view.args = () view.kwargs = {} view.format_kwarg = None diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 89a1fc93a5..b90f60e084 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -7,6 +7,7 @@ import re import warnings from collections import OrderedDict +from weakref import WeakKeyDictionary from django.db import models from django.utils.encoding import force_text, smart_text @@ -128,6 +129,10 @@ class ViewInspector(object): Provide subclass for per-view schema generation """ + + def __init__(self): + self.instance_schemas = WeakKeyDictionary() + def __get__(self, instance, owner): """ Enables `ViewInspector` as a Python _Descriptor_. @@ -144,9 +149,17 @@ def __get__(self, instance, owner): See: https://docs.python.org/3/howto/descriptor.html for info on descriptor usage. """ + if instance in self.instance_schemas: + return self.instance_schemas[instance] + self.view = instance return self + def __set__(self, instance, other): + self.instance_schemas[instance] = other + if other is not None: + other.view = instance + @property def view(self): """View property.""" @@ -189,6 +202,7 @@ def __init__(self, manual_fields=None): * `manual_fields`: list of `coreapi.Field` instances that will be added to auto-generated fields, overwriting on `Field.name` """ + super(AutoSchema, self).__init__() if manual_fields is None: manual_fields = [] self._manual_fields = manual_fields @@ -455,6 +469,7 @@ def __init__(self, fields, description='', encoding=None): * `fields`: list of `coreapi.Field` instances. * `descripton`: String description for view. Optional. """ + super(ManualSchema, self).__init__() assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" self._fields = fields self._description = description @@ -474,9 +489,13 @@ def get_link(self, path, method, base_url): ) -class DefaultSchema(object): +class DefaultSchema(ViewInspector): """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" def __get__(self, instance, owner): + result = super(DefaultSchema, self).__get__(instance, owner) + if not isinstance(result, DefaultSchema): + return result + inspector_class = api_settings.DEFAULT_SCHEMA_CLASS assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" inspector = inspector_class() diff --git a/tests/test_schemas.py b/tests/test_schemas.py index e4a7c8646f..c2a429ac30 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -105,6 +105,10 @@ def custom_list_action_multiple_methods_delete(self, request): """Deletion description.""" raise NotImplementedError + @action(detail=False, schema=None) + def excluded_action(self, request): + pass + def get_serializer(self, *args, **kwargs): assert self.request assert self.action @@ -735,6 +739,45 @@ class CustomView(APIView): assert len(fields) == 2 assert "my_extra_field" in [f.name for f in fields] + @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') + def test_viewset_action_with_schema(self): + class CustomViewSet(GenericViewSet): + @action(detail=True, schema=AutoSchema(manual_fields=[ + coreapi.Field( + "my_extra_field", + required=True, + location="path", + schema=coreschema.String() + ), + ])) + def extra_action(self, pk, **kwargs): + pass + + router = SimpleRouter() + router.register(r'detail', CustomViewSet, base_name='detail') + + generator = SchemaGenerator() + view = generator.create_view(router.urls[0].callback, 'GET') + link = view.schema.get_link('/a/url/{id}/', 'GET', '') + fields = link.fields + + assert len(fields) == 2 + assert "my_extra_field" in [f.name for f in fields] + + @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') + def test_viewset_action_with_null_schema(self): + class CustomViewSet(GenericViewSet): + @action(detail=True, schema=None) + def extra_action(self, pk, **kwargs): + pass + + router = SimpleRouter() + router.register(r'detail', CustomViewSet, base_name='detail') + + generator = SchemaGenerator() + view = generator.create_view(router.urls[0].callback, 'GET') + assert view.schema is None + @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') def test_view_with_manual_schema(self):