diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 7980670ed39..ccd79f244ae 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -227,11 +227,28 @@ from rest_framework.schemas.openapi import AutoSchema class ExampleView(APIView): """APIView subclass with custom schema introspection.""" - schema = AutoSchema(operation_name="Custom") + schema = AutoSchema(operation_id_base="Custom") ``` The previous example will generate the following operationid: "ListCustoms", "RetrieveCustom", "UpdateCustom", "PartialUpdateCustom", "DestroyCustom". +You need to provide the singular form of he operation name. For the list operation, a "s" will be append at the end of the name. + +If you need more configuration over the `operationId` field, you can override the `get_operation_id_base` and `get_operation_id` methods from the `AutoSchema` class. + +```python +class CustomSchema(AutoSchema): + def get_operation_id_base(self, action): + pass + + def get_operation_id(self, path, method): + pass + +class CustomView(APIView): + """APIView subclass with custom schema introspection.""" + schema = CustomSchema() +``` + [openapi]: https://github.com/OAI/OpenAPI-Specification [openapi-specification-extensions]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#specification-extensions [openapi-operation]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#operationObject diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index fdabd444322..4e4ede07df9 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -80,17 +80,17 @@ class AutoSchema(ViewInspector): 'delete': 'Destroy', } - def __init__(self, operation_name=None): + def __init__(self, operation_id_base=None): """ - :param operation_name: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name. + :param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name. """ super().__init__() - self.operation_name = operation_name + self.operation_id_base = operation_id_base def get_operation(self, path, method): operation = {} - operation['operationId'] = self._get_operation_id(path, method) + operation['operationId'] = self.get_operation_id(path, method) operation['description'] = self.get_description(path, method) parameters = [] @@ -106,22 +106,14 @@ def get_operation(self, path, method): return operation - def _get_operation_id(self, path, method): + def get_operation_id_base(self, action): """ - Compute an operation ID from the model, serializer or view name. + Compute the base part for operation ID from the model, serializer or view name. """ - method_name = getattr(self.view, 'action', method.lower()) - if is_list_view(path, method, self.view): - action = 'list' - elif method_name not in self.method_mapping: - action = method_name - else: - action = self.method_mapping[method.lower()] - model = getattr(getattr(self.view, 'queryset', None), 'model', None) - if self.operation_name is not None: - name = self.operation_name + if self.operation_id_base is not None: + name = self.operation_id_base # Try to deduce the ID from the view's model elif model is not None: @@ -149,6 +141,22 @@ def _get_operation_id(self, path, method): if action == 'list' and not name.endswith('s'): # listThings instead of listThing name += 's' + return name + + def get_operation_id(self, path, method): + """ + Compute an operation ID from the view type and get_operation_id_base method. + """ + method_name = getattr(self.view, 'action', method.lower()) + if is_list_view(path, method, self.view): + action = 'list' + elif method_name not in self.method_mapping: + action = method_name + else: + action = self.method_mapping[method.lower()] + + name = self.get_operation_id_base(action) + return action + name def _get_path_parameters(self, path, method): diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index fd3e47a42fd..19277f95ea5 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -521,9 +521,24 @@ def test_operation_id_generation(self): inspector = AutoSchema() inspector.view = view - operationId = inspector._get_operation_id(path, method) + operationId = inspector.get_operation_id(path, method) assert operationId == 'listExamples' + def test_operation_id_custom_operation_id_base(self): + path = '/' + method = 'GET' + + view = create_view( + views.ExampleGenericAPIView, + method, + create_request(path), + ) + inspector = AutoSchema(operation_id_base="Ulysse") + inspector.view = view + + operationId = inspector.get_operation_id(path, method) + assert operationId == 'listUlysses' + def test_operation_id_custom_name(self): path = '/' method = 'GET' @@ -533,12 +548,48 @@ def test_operation_id_custom_name(self): method, create_request(path), ) - inspector = AutoSchema(operation_name="Ulysse") + inspector = AutoSchema(operation_id_base='Ulysse') inspector.view = view - operationId = inspector._get_operation_id(path, method) + operationId = inspector.get_operation_id(path, method) assert operationId == 'listUlysses' + def test_operation_id_override_get(self): + class CustomSchema(AutoSchema): + def get_operation_id(self, path, method): + return 'myCustomOperationId' + + path = '/' + method = 'GET' + view = create_view( + views.ExampleGenericAPIView, + method, + create_request(path), + ) + inspector = CustomSchema() + inspector.view = view + + operationId = inspector.get_operation_id(path, method) + assert operationId == 'myCustomOperationId' + + def test_operation_id_override_base(self): + class CustomSchema(AutoSchema): + def get_operation_id_base(self, action): + return 'Item' + + path = '/' + method = 'GET' + view = create_view( + views.ExampleGenericAPIView, + method, + create_request(path), + ) + inspector = CustomSchema() + inspector.view = view + + operationId = inspector.get_operation_id(path, method) + assert operationId == 'listItem' + def test_repeat_operation_ids(self): router = routers.SimpleRouter() router.register('account', views.ExampleGenericViewSet, basename="account")