diff --git a/starlette/background.py b/starlette/background.py index 1160baeed..d0a186d74 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,7 +1,7 @@ -import asyncio import typing from starlette.concurrency import run_in_threadpool +from starlette.utils import iscoroutinefunction class BackgroundTask: @@ -11,7 +11,7 @@ def __init__( self.func = func self.args = args self.kwargs = kwargs - self.is_async = asyncio.iscoroutinefunction(func) + self.is_async = iscoroutinefunction(func) async def __call__(self) -> None: if self.is_async: diff --git a/starlette/utils.py b/starlette/utils.py new file mode 100644 index 000000000..c993580e9 --- /dev/null +++ b/starlette/utils.py @@ -0,0 +1,7 @@ +import inspect + + +def iscoroutinefunction(obj: object) -> bool: + if inspect.iscoroutinefunction(obj): + return True + return callable(obj) and inspect.iscoroutinefunction(obj.__call__) diff --git a/tests/test_background.py b/tests/test_background.py index e299ec362..f59ad966c 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -21,6 +21,26 @@ async def app(scope, receive, send): assert TASK_COMPLETE +def test_async_callable_task(test_client_factory): + TASK_COMPLETE = False + + class async_callable_task: + async def __call__(self): + nonlocal TASK_COMPLETE + TASK_COMPLETE = True + + task = BackgroundTask(async_callable_task()) + + async def app(scope, receive, send): + response = Response("task initiated", media_type="text/plain", background=task) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "task initiated" + assert TASK_COMPLETE + + def test_sync_task(test_client_factory): TASK_COMPLETE = False diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..3ba81695c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,31 @@ +from starlette.utils import iscoroutinefunction + + +def test_is_coroutine_function(): + async def func(): + pass # pragma: no cover + + assert iscoroutinefunction(func) + + +def test_is_not_coroutine_function(): + def func(): + pass # pragma: no cover + + assert not iscoroutinefunction(func) + + +def test_is_async_callable(): + class async_callable_obj: + async def __call__(self): + pass # pragma: no cover + + assert iscoroutinefunction(async_callable_obj()) + + +def test_is_not_asnyc_callable(): + class callable_obj: + def __call__(self): + pass # pragma: no cover + + assert not iscoroutinefunction(callable_obj())