From fe908b1c29f2f890110a1e64813083ae314a2d9f Mon Sep 17 00:00:00 2001 From: Vlad Stefan Munteanu Date: Tue, 2 Feb 2021 12:30:30 +0200 Subject: [PATCH] Fix functools.partial async handlers for classmethods (#1106) * Showcase the bug * Fixed functools.partial usage with classmethods * Updated comment * Updated docstring according to suggestion Co-authored-by: Jamie Hewland Co-authored-by: Jamie Hewland --- starlette/routing.py | 8 +++----- tests/test_routing.py | 23 +++++++++++++++++++++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index ce5e4d192..1e6ae0b55 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -2,7 +2,6 @@ import functools import inspect import re -import sys import traceback import typing from enum import Enum @@ -33,11 +32,10 @@ class Match(Enum): def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: """ Correctly determines if an object is a coroutine function, - with a fix for partials on Python < 3.8. + including those wrapped in functools.partial objects. """ - if sys.version_info < (3, 8): # pragma: no cover - while isinstance(obj, functools.partial): - obj = obj.func + while isinstance(obj, functools.partial): + obj = obj.func return inspect.iscoroutinefunction(obj) diff --git a/tests/test_routing.py b/tests/test_routing.py index 27640efe4..8927c60cd 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -590,16 +590,35 @@ def run_shutdown(): pass # pragma: nocover +class AsyncEndpointClassMethod: + @classmethod + async def async_endpoint(cls, arg, request): + return JSONResponse({"arg": arg}) + + async def _partial_async_endpoint(arg, request): return JSONResponse({"arg": arg}) partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") +partial_cls_async_endpoint = functools.partial( + AsyncEndpointClassMethod.async_endpoint, "foo" +) -partial_async_app = Router(routes=[Route("/", partial_async_endpoint)]) +partial_async_app = Router( + routes=[ + Route("/", partial_async_endpoint), + Route("/cls", partial_cls_async_endpoint), + ] +) def test_partial_async_endpoint(): - response = TestClient(partial_async_app).get("/") + test_client = TestClient(partial_async_app) + response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"arg": "foo"} + + cls_method_response = test_client.get("/cls") + assert cls_method_response.status_code == 200 + assert cls_method_response.json() == {"arg": "foo"}