diff --git a/sentry_sdk/integrations/asyncio.py b/sentry_sdk/integrations/asyncio.py index ab07ffc3cb..c18089a492 100644 --- a/sentry_sdk/integrations/asyncio.py +++ b/sentry_sdk/integrations/asyncio.py @@ -16,39 +16,40 @@ from typing import Any -def _sentry_task_factory(loop, coro): - # type: (Any, Any) -> Task[None] +def patch_asyncio(): + # type: () -> None + orig_task_factory = None + try: + loop = asyncio.get_running_loop() + orig_task_factory = loop.get_task_factory() - async def _coro_creating_hub_and_span(): - # type: () -> None - hub = Hub(Hub.current) - with hub: - with hub.start_span(op=OP.FUNCTION, description=coro.__qualname__): - await coro + def _sentry_task_factory(loop, coro): + # type: (Any, Any) -> Any - # Trying to use user set task factory (if there is one) - orig_factory = loop.get_task_factory() - if orig_factory: - return orig_factory(loop, _coro_creating_hub_and_span) + async def _coro_creating_hub_and_span(): + # type: () -> None + hub = Hub(Hub.current) + with hub: + with hub.start_span(op=OP.FUNCTION, description=coro.__qualname__): + await coro - # The default task factory in `asyncio` does not have its own function - # but is just a couple of lines in `asyncio.base_events.create_task()` - # Those lines are copied here. + # Trying to use user set task factory (if there is one) + if orig_task_factory: + return orig_task_factory(loop, _coro_creating_hub_and_span()) # type: ignore - # WARNING: - # If the default behavior of the task creation in asyncio changes, - # this will break! - task = Task(_coro_creating_hub_and_span, loop=loop) # type: ignore - if task._source_traceback: # type: ignore - del task._source_traceback[-1] # type: ignore + # The default task factory in `asyncio` does not have its own function + # but is just a couple of lines in `asyncio.base_events.create_task()` + # Those lines are copied here. - return task + # WARNING: + # If the default behavior of the task creation in asyncio changes, + # this will break! + task = Task(_coro_creating_hub_and_span(), loop=loop) + if task._source_traceback: # type: ignore + del task._source_traceback[-1] # type: ignore + return task -def patch_asyncio(): - # type: () -> None - try: - loop = asyncio.get_running_loop() loop.set_task_factory(_sentry_task_factory) except RuntimeError: # When there is no running loop, we have nothing to patch.