diff --git a/aiohttp_apispec/aiohttp_apispec.py b/aiohttp_apispec/aiohttp_apispec.py index 7d6380c..f717864 100644 --- a/aiohttp_apispec/aiohttp_apispec.py +++ b/aiohttp_apispec/aiohttp_apispec.py @@ -6,7 +6,7 @@ from aiohttp.hdrs import METH_ALL, METH_ANY from apispec import APISpec from apispec.core import VALID_METHODS_OPENAPI_V2 -from apispec.ext.marshmallow import MarshmallowPlugin +from apispec.ext.marshmallow import MarshmallowPlugin, common from jinja2 import Template from webargs.aiohttpparser import parser @@ -17,6 +17,16 @@ VALID_RESPONSE_FIELDS = {"description", "headers", "examples"} +def resolver(schema): + schema_instance = common.resolve_schema_instance(schema) + prefix = "Partial-" if schema_instance.partial else "" + schema_cls = common.resolve_schema_cls(schema) + name = prefix + schema_cls.__name__ + if name.endswith("Schema"): + return name[:-6] or name + return name + + class AiohttpApiSpec: def __init__( self, @@ -31,7 +41,7 @@ def __init__( **kwargs ): - self.plugin = MarshmallowPlugin() + self.plugin = MarshmallowPlugin(schema_name_resolver=resolver) self.spec = APISpec(plugins=(self.plugin,), openapi_version="2.0", **kwargs) self.url = url diff --git a/tests/conftest.py b/tests/conftest.py index 3415126..cfaf4dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,6 +82,10 @@ async def handler_get(request): async def handler_post(request): return web.json_response({"msg": "done", "data": {}}) + @request_schema(RequestSchema(partial=True)) + async def handler_post_partial(request): + return web.json_response({"msg": "done", "data": {}}) + @request_schema(RequestSchema()) async def handler_post_callable_schema(request): return web.json_response({"msg": "done", "data": {}}) @@ -162,6 +166,7 @@ async def validated_view(request: web.Request): [ web.get("/test", handler_get), web.post("/test", handler_post), + web.post("/test_partial", handler_post_partial), web.post("/test_call", handler_post_callable_schema), web.get("/other", other), web.get("/echo", handler_get_echo), @@ -181,6 +186,7 @@ async def validated_view(request: web.Request): [ web.get("/v1/test", handler_get), web.post("/v1/test", handler_post), + web.post("/v1/test_partial", handler_post_partial), web.post("/v1/test_call", handler_post_callable_schema), web.get("/v1/other", other), web.get("/v1/echo", handler_get_echo), diff --git a/tests/test_documentation.py b/tests/test_documentation.py index 68d6414..8f6918f 100644 --- a/tests/test_documentation.py +++ b/tests/test_documentation.py @@ -114,20 +114,22 @@ async def test_app_swagger_json(aiohttp_app): sort_keys=True, ) + _request_properties = { + "properties": { + "bool_field": {"type": "boolean"}, + "id": {"format": "int32", "type": "integer"}, + "list_field": { + "items": {"format": "int32", "type": "integer"}, + "type": "array", + }, + "name": {"description": "name", "type": "string"}, + }, + "type": "object", + } assert json.dumps(docs["definitions"], sort_keys=True) == json.dumps( { - "Request": { - "properties": { - "bool_field": {"type": "boolean"}, - "id": {"format": "int32", "type": "integer"}, - "list_field": { - "items": {"format": "int32", "type": "integer"}, - "type": "array", - }, - "name": {"description": "name", "type": "string"}, - }, - "type": "object", - }, + "Request": _request_properties, + "Partial-Request": _request_properties, "Response": { "properties": {"data": {"type": "object"}, "msg": {"type": "string"}}, "type": "object",