diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 22894a9782..2b83e0671b 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -603,6 +603,31 @@ Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fiel Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view. +### get_manual_fields(self, path, method) + +Return a list of `coreapi.Field()` instances to be added to or replace generated fields. Defaults to (optional) `manual_fields` passed to `AutoSchema` constructor. + +May be overridden to customise manual fields by `path` or `method`. For example, a per-method adjustment may look like this: + +```python +def get_manual_fields(self, path, method): + """Example adding per-method fields.""" + + extra_fields = [] + if method=='GET': + extra_fields = # ... list of extra fields for GET ... + if method=='POST': + extra_fields = # ... list of extra fields for POST ... + + manual_fields = super().get_manual_fields() + return manual_fields + extra_fields +``` + +### update_fields(fields, update_with) + +Utility `staticmethod`. Encapsulates logic to add or replace fields from a list +by `Field.name`. May be overridden to adjust replacement criteria. + ## ManualSchema diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 44d8c7a120..2f2cdf1a15 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,25 @@ You can determine your currently installed version using `pip freeze`: ## 3.7.x series +### 3.7.4 + +**Date**: UNRELEASED + +* Extract method for `manual_fields` processing [#5633][gh5633] + + Allows for easier customisation of `manual_fields` processing, for example + to provide per-method manual fields. `AutoSchema` adds `get_manual_fields`, + as the intended override point, and a utility method `update_fields`, to + handle by-name field replacement from a list, which, in general, you are not + expected to override. + + Note: `AutoSchema.__init__` now ensures `manual_fields` is a list. + Previously may have been stored internally as `None`. + + +[gh5633]: https://github.com/encode/django-rest-framework/issues/5633 + + ### 3.7.3 diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 47f5b9e13e..008d7c0910 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -172,7 +172,8 @@ def __init__(self, manual_fields=None): * `manual_fields`: list of `coreapi.Field` instances that will be added to auto-generated fields, overwriting on `Field.name` """ - + if manual_fields is None: + manual_fields = [] self._manual_fields = manual_fields def get_link(self, path, method, base_url): @@ -181,11 +182,8 @@ def get_link(self, path, method, base_url): fields += self.get_pagination_fields(path, method) fields += self.get_filter_fields(path, method) - if self._manual_fields is not None: - by_name = {f.name: f for f in fields} - for f in self._manual_fields: - by_name[f.name] = f - fields = list(by_name.values()) + manual_fields = self.get_manual_fields(path, method) + fields = self.update_fields(fields, manual_fields) if fields and any([field.location in ('form', 'body') for field in fields]): encoding = self.get_encoding(path, method) @@ -379,6 +377,31 @@ def get_filter_fields(self, path, method): fields += filter_backend().get_schema_fields(self.view) return fields + def get_manual_fields(self, path, method): + return self._manual_fields + + @staticmethod + def update_fields(fields, update_with): + """ + Update list of coreapi.Field instances, overwriting on `Field.name`. + + Utility function to handle replacing coreapi.Field fields + from a list by name. Used to handle `manual_fields`. + + Parameters: + + * `fields`: list of `coreapi.Field` instances to update + * `update_with: list of `coreapi.Field` instances to add or replace. + """ + if not update_with: + return fields + + by_name = OrderedDict((f.name, f) for f in fields) + for f in update_with: + by_name[f.name] = f + fields = list(by_name.values()) + return fields + def get_encoding(self, path, method): """ Return the 'encoding' parameter to use for a given endpoint. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 56692d4f59..ba561a9597 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -516,7 +516,7 @@ def test_4605_regression(self): assert prefix == '/' -class TestDescriptor(TestCase): +class TestAutoSchema(TestCase): def test_apiview_schema_descriptor(self): view = APIView() @@ -528,7 +528,43 @@ def test_get_link_requires_instance(self): with pytest.raises(AssertionError): descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? - def test_manual_fields(self): + def test_update_fields(self): + """ + That updating fields by-name helper is correct + + Recall: `update_fields(fields, update_with)` + """ + schema = AutoSchema() + fields = [] + + # Adds a field... + fields = schema.update_fields(fields, [ + coreapi.Field( + "my_field", + required=True, + location="path", + schema=coreschema.String() + ), + ]) + + assert len(fields) == 1 + assert fields[0].name == "my_field" + + # Replaces a field... + fields = schema.update_fields(fields, [ + coreapi.Field( + "my_field", + required=False, + location="path", + schema=coreschema.String() + ), + ]) + + assert len(fields) == 1 + assert fields[0].required is False + + def test_get_manual_fields(self): + """That get_manual_fields is applied during get_link""" class CustomView(APIView): schema = AutoSchema(manual_fields=[