Skip to content

Commit

Permalink
Add ViewInspector setter to store instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan P Kilby committed May 17, 2018
1 parent e93fa36 commit 1b489fd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
38 changes: 36 additions & 2 deletions rest_framework/schemas/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import warnings
from collections import OrderedDict
from weakref import WeakKeyDictionary

from django.db import models
from django.utils.encoding import force_text, smart_text
Expand Down Expand Up @@ -129,6 +130,31 @@ class ViewInspector(object):
Provide subclass for per-view schema generation
"""

def __init__(self):
self.instance_schemas = WeakKeyDictionary()

def __get__(self, instance, owner):
"""
Enables `ViewInspector` as a Python _Descriptor_.
This is how `view.schema` knows about `view`.
`__get__` is called when the descriptor is accessed on the owner.
(That will be when view.schema is called in our case.)
`owner` is always the owner class. (An APIView, or subclass for us.)
`instance` is the view instance or `None` if accessed from the class,
rather than an instance.
See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage.
"""
if instance in self.instance_schemas:
return self.instance_schemas[instance]

self.view = instance
return self

def __set__(self, instance, other):
self.instance_schemas[instance] = other
other.view = instance

@property
def view(self):
"""View property."""
Expand Down Expand Up @@ -171,6 +197,7 @@ def __init__(self, manual_fields=None):
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super(AutoSchema, self).__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
Expand Down Expand Up @@ -437,6 +464,7 @@ def __init__(self, fields, description='', encoding=None):
* `fields`: list of `coreapi.Field` instances.
* `descripton`: String description for view. Optional.
"""
super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
Expand All @@ -456,9 +484,15 @@ def get_link(self, path, method, base_url):
)


class DefaultSchema(object):
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner):
result = super(DefaultSchema, self).__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result

inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
return inspector_class()
inspector = inspector_class()
inspector.view = instance
return inspector
10 changes: 0 additions & 10 deletions rest_framework/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
"""
from __future__ import unicode_literals

from copy import deepcopy

from django.conf import settings
from django.core.exceptions import PermissionDenied
from django.db import connection, models, transaction
Expand Down Expand Up @@ -110,14 +108,6 @@ class APIView(View):

schema = DefaultSchema()

def __init__(self, **kwargs):
super(APIView, self).__init__(**kwargs)
if self.schema is not None:
# copy class-level schema to prevent instances using the same object
if 'schema' not in self.__dict__:
self.schema = deepcopy(self.schema)
self.schema.view = self

@classmethod
def as_view(cls, **initkwargs):
"""
Expand Down

0 comments on commit 1b489fd

Please sign in to comment.