Skip to content

Commit

Permalink
Extract method for manual_fields processing (#5633)
Browse files Browse the repository at this point in the history
* Extract method for `manual_fields` processing

Allows reuse of logic to replace Field instances in a field list by `Field.name`.

Adds a utility function for the logic plus a wrapper method on `AutoSchema`.

Closes #5632

* Manual fields suggestions (#2)

* Use OrderedDict in inspectors

* Move empty check to 'update_fields()'

* Make 'update_fields()' an AutoSchema staticmethod

* Add 'AutoSchema.get_manual_fields()'

* Conform '.get_manual_fields()' to other methods

* Add test for update_fields

* Make sure `manual_fields` is a list.

(As documented to be)

* Add docs for new AutoSchema methods.

* `get_manual_fields`
* `update_fields`

* Add release notes for PR.
  • Loading branch information
carltongibson authored Dec 4, 2017
1 parent daba5e9 commit a0cdba6
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 8 deletions.
25 changes: 25 additions & 0 deletions docs/api-guide/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,31 @@ Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fiel

Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view.

### get_manual_fields(self, path, method)

Return a list of `coreapi.Field()` instances to be added to or replace generated fields. Defaults to (optional) `manual_fields` passed to `AutoSchema` constructor.

May be overridden to customise manual fields by `path` or `method`. For example, a per-method adjustment may look like this:

```python
def get_manual_fields(self, path, method):
"""Example adding per-method fields."""

extra_fields = []
if method=='GET':
extra_fields = # ... list of extra fields for GET ...
if method=='POST':
extra_fields = # ... list of extra fields for POST ...

manual_fields = super().get_manual_fields()
return manual_fields + extra_fields
```

### update_fields(fields, update_with)

Utility `staticmethod`. Encapsulates logic to add or replace fields from a list
by `Field.name`. May be overridden to adjust replacement criteria.


## ManualSchema

Expand Down
19 changes: 19 additions & 0 deletions docs/topics/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ You can determine your currently installed version using `pip freeze`:

## 3.7.x series

### 3.7.4

**Date**: UNRELEASED

* Extract method for `manual_fields` processing [#5633][gh5633]

Allows for easier customisation of `manual_fields` processing, for example
to provide per-method manual fields. `AutoSchema` adds `get_manual_fields`,
as the intended override point, and a utility method `update_fields`, to
handle by-name field replacement from a list, which, in general, you are not
expected to override.

Note: `AutoSchema.__init__` now ensures `manual_fields` is a list.
Previously may have been stored internally as `None`.


[gh5633]: https://github.com/encode/django-rest-framework/issues/5633



### 3.7.3

Expand Down
35 changes: 29 additions & 6 deletions rest_framework/schemas/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def __init__(self, manual_fields=None):
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""

if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields

def get_link(self, path, method, base_url):
Expand All @@ -181,11 +182,8 @@ def get_link(self, path, method, base_url):
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)

if self._manual_fields is not None:
by_name = {f.name: f for f in fields}
for f in self._manual_fields:
by_name[f.name] = f
fields = list(by_name.values())
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)

if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
Expand Down Expand Up @@ -379,6 +377,31 @@ def get_filter_fields(self, path, method):
fields += filter_backend().get_schema_fields(self.view)
return fields

def get_manual_fields(self, path, method):
return self._manual_fields

@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields

by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
return fields

def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
Expand Down
40 changes: 38 additions & 2 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def test_4605_regression(self):
assert prefix == '/'


class TestDescriptor(TestCase):
class TestAutoSchema(TestCase):

def test_apiview_schema_descriptor(self):
view = APIView()
Expand All @@ -528,7 +528,43 @@ def test_get_link_requires_instance(self):
with pytest.raises(AssertionError):
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?

def test_manual_fields(self):
def test_update_fields(self):
"""
That updating fields by-name helper is correct
Recall: `update_fields(fields, update_with)`
"""
schema = AutoSchema()
fields = []

# Adds a field...
fields = schema.update_fields(fields, [
coreapi.Field(
"my_field",
required=True,
location="path",
schema=coreschema.String()
),
])

assert len(fields) == 1
assert fields[0].name == "my_field"

# Replaces a field...
fields = schema.update_fields(fields, [
coreapi.Field(
"my_field",
required=False,
location="path",
schema=coreschema.String()
),
])

assert len(fields) == 1
assert fields[0].required is False

def test_get_manual_fields(self):
"""That get_manual_fields is applied during get_link"""

class CustomView(APIView):
schema = AutoSchema(manual_fields=[
Expand Down

0 comments on commit a0cdba6

Please sign in to comment.