diff --git a/asgiref/sync.py b/asgiref/sync.py index 55cce779..6d19faad 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -348,6 +348,15 @@ class SyncToAsync: else: thread_sensitive_context: None = None + # Contextvar that is used to detect if the single thread executor + # would be awaited on while already being used in the same context + if sys.version_info >= (3, 7): + deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar( + "deadlock_context" + ) + else: + deadlock_context: None = None + # Maintaining a weak reference to the context ensures that thread pools are # erased once the context goes out of scope. This terminates the thread pool. context_to_thread_executor: "weakref.WeakKeyDictionary[object, ThreadPoolExecutor]" = ( @@ -396,9 +405,15 @@ async def __call__(self, *args, **kwargs): # Create new thread executor in current context executor = ThreadPoolExecutor(max_workers=1) self.context_to_thread_executor[thread_sensitive_context] = executor + elif self.deadlock_context and self.deadlock_context.get(False): + raise RuntimeError( + "Single thread executor already being used, would deadlock" + ) else: # Otherwise, we run it in a fixed single thread executor = self.single_thread_executor + if self.deadlock_context: + self.deadlock_context.set(True) else: # Use the passed in executor, or the loop's default if it is None executor = self._executor @@ -429,6 +444,8 @@ async def __call__(self, *args, **kwargs): if contextvars is not None: _restore_context(context) + if self.deadlock_context: + self.deadlock_context.set(False) return ret diff --git a/tests/test_sync.py b/tests/test_sync.py index 8ed76a79..84daa55f 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,6 +1,7 @@ import asyncio import functools import multiprocessing +import sys import threading import time from concurrent.futures import ThreadPoolExecutor @@ -669,3 +670,27 @@ def sync_func(): thread_sensitive=True, executor=custom_executor, ) + + +@pytest.mark.skipif(sys.version_info < (3, 7), reason="Issue persists with 3.6") +def test_sync_to_async_deadlock_raises(): + def db_write(): + pass + + async def io_task(): + await sync_to_async(db_write)() + + async def do_io_tasks(): + t = asyncio.create_task(io_task()) + await t + # await asyncio.gather(io_task()) # Also deadlocks + # await io_task() # Works + + def view(): + async_to_sync(do_io_tasks)() + + async def server_entry(): + await sync_to_async(view)() + + with pytest.raises(RuntimeError): + asyncio.run(server_entry())