From c6015c0f14db228c1c37ed03ac83cf2262038b6d Mon Sep 17 00:00:00 2001 From: DPR Date: Fri, 30 Jul 2021 01:59:56 +0800 Subject: [PATCH 1/4] Added - custom iscoroutinefunction func --- starlette/background.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/starlette/background.py b/starlette/background.py index 1160baeed..3a0e5d819 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,9 +1,17 @@ -import asyncio +import inspect import typing from starlette.concurrency import run_in_threadpool +def iscoroutinefunction(obj): + if inspect.iscoroutinefunction(obj): + return True + if hasattr(obj, '__call__') and inspect.iscoroutinefunction(obj.__call__): + return True + return False + + class BackgroundTask: def __init__( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any @@ -11,7 +19,7 @@ def __init__( self.func = func self.args = args self.kwargs = kwargs - self.is_async = asyncio.iscoroutinefunction(func) + self.is_async = iscoroutinefunction(func) async def __call__(self) -> None: if self.is_async: From 88d2496b2280fd8f031fcc37d0136335b4a8ea6e Mon Sep 17 00:00:00 2001 From: dpr-0 Date: Sat, 21 Aug 2021 15:06:22 +0800 Subject: [PATCH 2/4] Refactor - custom iscoroutinefunction --- starlette/background.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/starlette/background.py b/starlette/background.py index 3a0e5d819..139b78b45 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -4,12 +4,10 @@ from starlette.concurrency import run_in_threadpool -def iscoroutinefunction(obj): +def iscoroutinefunction(obj: object) -> bool: if inspect.iscoroutinefunction(obj): return True - if hasattr(obj, '__call__') and inspect.iscoroutinefunction(obj.__call__): - return True - return False + return callable(obj) and inspect.iscoroutinefunction(obj.__call__) class BackgroundTask: From 1d9d82fad883b3371fee0c2b8e38253ba4194c6f Mon Sep 17 00:00:00 2001 From: dpr-0 Date: Sat, 21 Aug 2021 15:06:25 +0800 Subject: [PATCH 3/4] Add - test case for async callable obj --- tests/test_background.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_background.py b/tests/test_background.py index e299ec362..f59ad966c 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -21,6 +21,26 @@ async def app(scope, receive, send): assert TASK_COMPLETE +def test_async_callable_task(test_client_factory): + TASK_COMPLETE = False + + class async_callable_task: + async def __call__(self): + nonlocal TASK_COMPLETE + TASK_COMPLETE = True + + task = BackgroundTask(async_callable_task()) + + async def app(scope, receive, send): + response = Response("task initiated", media_type="text/plain", background=task) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "task initiated" + assert TASK_COMPLETE + + def test_sync_task(test_client_factory): TASK_COMPLETE = False From d41ad1317bf0cfc312ce48b9114eb5bddba938a0 Mon Sep 17 00:00:00 2001 From: dpr-0 Date: Thu, 23 Sep 2021 11:14:33 +0800 Subject: [PATCH 4/4] Move iscoroutinefunction to utils.py (#1256) --- starlette/background.py | 8 +------- starlette/utils.py | 7 +++++++ tests/test_utils.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 starlette/utils.py create mode 100644 tests/test_utils.py diff --git a/starlette/background.py b/starlette/background.py index 139b78b45..d0a186d74 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,13 +1,7 @@ -import inspect import typing from starlette.concurrency import run_in_threadpool - - -def iscoroutinefunction(obj: object) -> bool: - if inspect.iscoroutinefunction(obj): - return True - return callable(obj) and inspect.iscoroutinefunction(obj.__call__) +from starlette.utils import iscoroutinefunction class BackgroundTask: diff --git a/starlette/utils.py b/starlette/utils.py new file mode 100644 index 000000000..c993580e9 --- /dev/null +++ b/starlette/utils.py @@ -0,0 +1,7 @@ +import inspect + + +def iscoroutinefunction(obj: object) -> bool: + if inspect.iscoroutinefunction(obj): + return True + return callable(obj) and inspect.iscoroutinefunction(obj.__call__) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..3ba81695c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,31 @@ +from starlette.utils import iscoroutinefunction + + +def test_is_coroutine_function(): + async def func(): + pass # pragma: no cover + + assert iscoroutinefunction(func) + + +def test_is_not_coroutine_function(): + def func(): + pass # pragma: no cover + + assert not iscoroutinefunction(func) + + +def test_is_async_callable(): + class async_callable_obj: + async def __call__(self): + pass # pragma: no cover + + assert iscoroutinefunction(async_callable_obj()) + + +def test_is_not_asnyc_callable(): + class callable_obj: + def __call__(self): + pass # pragma: no cover + + assert not iscoroutinefunction(callable_obj())