diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index c3ff8b74c0b..a09f69aeba2 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -7,6 +7,7 @@ import re import warnings from collections import OrderedDict +from weakref import WeakKeyDictionary from django.db import models from django.utils.encoding import force_text, smart_text @@ -129,6 +130,31 @@ class ViewInspector(object): Provide subclass for per-view schema generation """ + def __init__(self): + self.instance_schemas = WeakKeyDictionary() + + def __get__(self, instance, owner): + """ + Enables `ViewInspector` as a Python _Descriptor_. + This is how `view.schema` knows about `view`. + `__get__` is called when the descriptor is accessed on the owner. + (That will be when view.schema is called in our case.) + `owner` is always the owner class. (An APIView, or subclass for us.) + `instance` is the view instance or `None` if accessed from the class, + rather than an instance. + See: https://docs.python.org/3/howto/descriptor.html for info on + descriptor usage. + """ + if instance in self.instance_schemas: + return self.instance_schemas[instance] + + self.view = instance + return self + + def __set__(self, instance, other): + self.instance_schemas[instance] = other + other.view = instance + @property def view(self): """View property.""" @@ -171,6 +197,7 @@ def __init__(self, manual_fields=None): * `manual_fields`: list of `coreapi.Field` instances that will be added to auto-generated fields, overwriting on `Field.name` """ + super(AutoSchema, self).__init__() if manual_fields is None: manual_fields = [] self._manual_fields = manual_fields @@ -437,6 +464,7 @@ def __init__(self, fields, description='', encoding=None): * `fields`: list of `coreapi.Field` instances. * `descripton`: String description for view. Optional. """ + super(ManualSchema, self).__init__() assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" self._fields = fields self._description = description @@ -456,9 +484,15 @@ def get_link(self, path, method, base_url): ) -class DefaultSchema(object): +class DefaultSchema(ViewInspector): """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" def __get__(self, instance, owner): + result = super(DefaultSchema, self).__get__(instance, owner) + if not isinstance(result, DefaultSchema): + return result + inspector_class = api_settings.DEFAULT_SCHEMA_CLASS assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" - return inspector_class() + inspector = inspector_class() + inspector.view = instance + return inspector diff --git a/rest_framework/views.py b/rest_framework/views.py index e25f2882934..1f51517db32 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,8 +3,6 @@ """ from __future__ import unicode_literals -from copy import deepcopy - from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import connection, models, transaction @@ -110,14 +108,6 @@ class APIView(View): schema = DefaultSchema() - def __init__(self, **kwargs): - super(APIView, self).__init__(**kwargs) - if self.schema is not None: - # copy class-level schema to prevent instances using the same object - if 'schema' not in self.__dict__: - self.schema = deepcopy(self.schema) - self.schema.view = self - @classmethod def as_view(cls, **initkwargs): """