Skip to content

Commit

Permalink
Added the possibility of lazy initialization of asynchronous parts of…
Browse files Browse the repository at this point in the history
… the Queue.

The behavior of all functions after initialization is unchanged. In addition, the duplicated code in `_notify_(a)sync_not_(empty/full)` was moved into the shared `_notify_(a)sync_condition` functions. As a side effect, this also fixes the bug where `_notify_sync_not_empty` did not add its handler to `_pending`.

Prior to full initialization, some `Queue` attributes are replaced with dummies. In particular, `async_q` is an instance of the `PreInitDummyAsyncQueue` class until the `Queue` is fully initialized. After full initialization `Queue.async_q` is replaced by the real object, but the user may still hold a reference to the `PreInitDummyAsyncQueue` instance. This is not a problem: once initialization has completed, the dummy works as a proxy to the real `async_q`.
  • Loading branch information
s0d3s committed Mar 30, 2023
1 parent dc2fb08 commit 01424a6
Showing 1 changed file with 238 additions and 38 deletions.
276 changes: 238 additions & 38 deletions janus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import asyncio
import sys
import threading
from concurrent.futures import ThreadPoolExecutor, Future, CancelledError
from time import time as time_time
from asyncio import AbstractEventLoop
from asyncio import QueueEmpty as AsyncQueueEmpty
from asyncio import QueueFull as AsyncQueueFull
from collections import deque
from heapq import heappop, heappush
from queue import Empty as SyncQueueEmpty
from queue import Full as SyncQueueFull
from typing import Any, Callable, Deque, Generic, List, Optional, Set, TypeVar
from typing import Any, Callable, Deque, Generic, List, Optional, Set, TypeVar, Union, Tuple

from typing_extensions import Protocol

Expand All @@ -25,6 +27,137 @@

# Generic item type used to parameterize the queue classes in this module.
T = TypeVar("T")
# Shorthand for an optional float.
OptFloat = Optional[float]
# Annotation for attributes that hold None until the async parts are initialized.
PostAsyncInit = Optional[T]


class InitAsyncPartsMixin:
    """Mixin that adds lazy initialization of an object's asynchronous parts.

    Subclasses are expected to override 'already_initialized' and
    '_async_init'. They may also override '_also_initialize_when_triggered'
    and '_list_of_methods_to_patch'.
    """

    @property
    def already_initialized(self) -> bool:
        """Indicate that the instance is already initialized"""
        raise NotImplementedError()

    @property
    def _also_initialize_when_triggered(self) -> List["InitAsyncPartsMixin"]:
        """Return a list of objects whose async parts must also be initialized."""
        return []

    @property
    def _list_of_methods_to_patch(self) -> List[Tuple[str, str]]:
        """Return a list of ('cur_method', 'new_method') name pairs for monkey-patching

        These are the methods whose behavior was changed so that they can be
        used without initializing the async parts.
        """
        return []

    def _async_post_init_patch_methods(self) -> None:
        """Monkey-patch every method listed in '_list_of_methods_to_patch'"""
        for current_name, replacement_name in self._list_of_methods_to_patch:
            # Looking the replacement up on 'self' yields a bound method, which
            # is then stored as an instance attribute under the current name.
            setattr(self, current_name, getattr(self, replacement_name))

    def _async_post_init_handler(
        self, loop: Optional[AbstractEventLoop] = None, **params
    ) -> Optional[AbstractEventLoop]:
        """Handle initializing of asynchronous parts of the object

        :param loop: event loop to bind to; when None the running loop is used
        :param params: extra keyword arguments forwarded to '_async_init'
        :return: the loop used for initialization, or the 'loop' argument
            unchanged when the object was already initialized
        :raises RuntimeError: if 'loop' is None and no event loop is running
        :raises ValueError: if '_also_initialize_when_triggered' contains an
            object that is not an 'InitAsyncPartsMixin' instance
        """
        if self.already_initialized:
            return loop

        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # The message below replaces the original one, so suppress the
                # implicit exception context to keep the traceback readable.
                raise RuntimeError("Async parts of 'janus' must be initialized only from running loop. "
                                   "Do not call async from sync code") from None

        # 'already_initialized' must be True after call
        self._async_init(loop, **params)
        self._async_post_init_patch_methods()

        for need_init in self._also_initialize_when_triggered:
            if not isinstance(need_init, InitAsyncPartsMixin):
                raise ValueError("'_also_initialize_when_triggered' must contain only instances of "
                                 "class that inherited from 'InitAsyncPartsMixin'")
            # Dependents share the loop; 'params' are not forwarded to them.
            need_init._async_post_init_handler(loop)

        return loop

    def _async_init(self, loop: AbstractEventLoop, **params) -> None:
        """Override to change behavior

        The actions of this function should affect the "value" of
        'already_initialized' (set it to True).
        """
        ...

    def trigger_async_initialization(self, **params):
        """Trigger initialization of async parts

        Public alias for '_async_post_init_handler'.
        """
        return self._async_post_init_handler(**params)


class PreInitDummyLoop:
    """Replacement for a 'Queue._loop', until the async part is fully initialized

    Callbacks are executed on a private 'ThreadPoolExecutor'; every submitted
    future is tracked in 'pending' until it is done.
    """

    def __init__(self) -> None:
        self.executor = ThreadPoolExecutor(thread_name_prefix="PreInitDummyLoop-")
        self.pending = set()  # type: Set[Future[Any]]

    @staticmethod
    def time() -> float:
        """Replacement of '_loop.time' in '_SyncQueueProxy.get' and '_SyncQueueProxy.put'"""
        return time_time()

    def _submit(self, callback: Callable[..., None], *args: Any) -> None:
        """Run 'callback(*args)' on the executor and track its future in 'pending'"""
        future = self.executor.submit(callback, *args)
        # Register in 'pending' before attaching the done-callback: if the
        # future is already finished the callback fires immediately and must
        # find the future in the set to discard it.
        self.pending.add(future)
        future.add_done_callback(self.pending.discard)

    def call_soon_threadsafe(self, callback: Callable[..., None], *args: Any) -> None:
        """Schedule 'callback(*args)' on the executor"""
        self._submit(callback, *args)

    def run_in_executor(self, callback: Callable[..., None], *args: Any) -> None:
        """Schedule 'callback(*args)' on the executor

        Note: unlike 'AbstractEventLoop.run_in_executor', this takes no
        executor argument and returns nothing.
        """
        self._submit(callback, *args)

    def wait(self) -> None:
        """Block until every currently pending callback has finished

        Cancelled callbacks are skipped; any other exception raised by a
        callback propagates to the caller.
        """
        # Iterate over a snapshot: done-callbacks discard futures from
        # 'pending' while we wait, which would otherwise raise
        # "RuntimeError: Set changed size during iteration".
        for task in list(self.pending):
            try:
                task.result()
            except CancelledError:
                pass

    def cleanup(self) -> None:
        """Cancel every pending callback that has not started running yet"""
        # 'Future.cancel' invokes the done-callbacks synchronously, so the set
        # shrinks during the loop; iterate over a snapshot instead.
        for task in list(self.pending):
            task.cancel()


class PreInitDummyAsyncQueue:
    """Replacement of 'Queue.async_q'

    Triggers initialization of the async part on the first attribute access.
    If, after full initialization, someone still holds a reference to it, it
    works as a proxy, redirecting everything to the actual 'async_q'.
    """

    def __init__(self, trigger_obj: "Queue"):
        # Object whose 'trigger_async_initialization' is called and whose
        # 'async_q' attribute is read on attribute access.
        self.trigger_obj = trigger_obj
        # Set while/after an initialization attempt made through this dummy.
        self.already_triggered = threading.Event()

    def __getattribute__(self, item):
        """Initialize the async parts if needed and forward 'item' to 'async_q'

        :raises RuntimeError: if the attribute is accessed while an
            initialization is still in progress, or if initialization did
            not replace 'async_q' with the actual queue
        """
        # Go through 'super()' so reading our own state does not recurse
        # into this method.
        already_triggered = super().__getattribute__("already_triggered")  # type: threading.Event
        trigger_obj = super().__getattribute__("trigger_obj")

        if already_triggered.is_set():
            async_q = getattr(trigger_obj, "async_q")

            if not isinstance(async_q, _AsyncQueueProxy):
                raise RuntimeError("Async parts multi-initialization detected. You cannot access 'async_q' attrs "
                                   "until full initialization is complete")
            return getattr(async_q, item)

        already_triggered.set()
        try:
            trigger_obj.trigger_async_initialization()
        except BaseException:
            # A failed attempt (e.g. no running loop) must not poison the
            # dummy: without the reset every later access would report a
            # misleading "multi-initialization" error instead of retrying.
            already_triggered.clear()
            raise
        async_q = getattr(trigger_obj, "async_q")

        if isinstance(async_q, PreInitDummyAsyncQueue):
            raise RuntimeError("Error during post initializing. 'async_q' must be replaced with actual 'AsyncQueue'")
        return getattr(async_q, item)


class BaseQueue(Protocol[T]):
Expand Down Expand Up @@ -82,11 +215,13 @@ async def join(self) -> None:
...


class Queue(Generic[T]):
def __init__(self, maxsize: int = 0, loop: Optional[AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_running_loop()
class Queue(Generic[T], InitAsyncPartsMixin):
def __init__(self, maxsize: int = 0, init_async_part: bool = True) -> None:
self._maxsize = maxsize

# will be set after the async part is initialized
self.full_init = threading.Event()

self._init(maxsize)

self._unfinished_tasks = 0
Expand All @@ -96,20 +231,60 @@ def __init__(self, maxsize: int = 0, loop: Optional[AbstractEventLoop] = None) -
self._sync_not_full = threading.Condition(self._sync_mutex)
self._all_tasks_done = threading.Condition(self._sync_mutex)

self._closing = False
self._pending = set() # type: Set[asyncio.Future[Any]]

self._loop = PreInitDummyLoop() # type: Union[PreInitDummyLoop, AbstractEventLoop]

self._async_mutex = asyncio.Lock() # type: PostAsyncInit[asyncio.Lock]
if sys.version_info[:3] == (3, 10, 0):
# Workaround for Python 3.10 bug, see #358:
getattr(self._async_mutex, "_get_loop", lambda: None)()
self._async_not_empty = None # type: PostAsyncInit[asyncio.Condition]
self._async_not_full = None # type: PostAsyncInit[asyncio.Condition]
# set 'threading.Event' to not change behavior
self._finished = threading.Event() # type: Union[asyncio.Event, threading.Event]

def before_init_async_parts_dummy_handler(
callback: Callable[..., None], *args: Any
) -> None:
callback(*args)

self._call_soon_threadsafe = before_init_async_parts_dummy_handler

self._call_soon = before_init_async_parts_dummy_handler

self._sync_queue = _SyncQueueProxy(self)
self._async_queue = PreInitDummyAsyncQueue(self) # type: Union[PreInitDummyAsyncQueue, "_AsyncQueueProxy[T]"]

if init_async_part:
self.trigger_async_initialization()

@property
def already_initialized(self) -> bool:
"""Return True if all parts of 'Queue'(sync/async) are initialized"""
return self.full_init.is_set()

def _async_init(self, loop: AbstractEventLoop, **params):
self._loop = loop

self._async_queue = _AsyncQueueProxy(self)
self._async_mutex = asyncio.Lock()
if sys.version_info[:3] == (3, 10, 0):
# Workaround for Python 3.10 bug, see #358:
getattr(self._async_mutex, "_get_loop", lambda: None)()
self._async_not_empty = asyncio.Condition(self._async_mutex)
self._async_not_full = asyncio.Condition(self._async_mutex)

_finished = self._finished
self._finished = asyncio.Event()
self._finished.set()

self._closing = False
self._pending = set() # type: Set[asyncio.Future[Any]]
if not _finished.is_set():
_finished.set()

def checked_call_soon_threadsafe(
callback: Callable[..., None], *args: Any
callback: Callable[..., None], *args: Any
) -> None:
try:
self._loop.call_soon_threadsafe(callback, *args)
Expand All @@ -125,14 +300,23 @@ def checked_call_soon(callback: Callable[..., None], *args: Any) -> None:

self._call_soon = checked_call_soon

self._sync_queue = _SyncQueueProxy(self)
self._async_queue = _AsyncQueueProxy(self)
self.full_init.set()

@property
def _list_of_methods_to_patch(self) -> List[Tuple[str, str]]:
return [
("_notify_sync_condition", "_post_async_init_notify_sync_condition"),
("_notify_async_condition", "_post_async_init_notify_async_condition"),
]

def close(self) -> None:
with self._sync_mutex:
self._closing = True
for fut in self._pending:
fut.cancel()
if isinstance(self._loop, PreInitDummyLoop):
self._loop.cleanup()
else:
for fut in self._pending:
fut.cancel()
self._finished.set() # unblocks all async_q.join()
self._all_tasks_done.notify_all() # unblocks all sync_q.join()

Expand All @@ -147,9 +331,13 @@ async def wait_closed(self) -> None:
# _notify_async_not_empty, _notify_async_not_full
# methods.
await asyncio.sleep(0)
if not self._pending:
return
await asyncio.wait(self._pending)

if isinstance(self._loop, PreInitDummyLoop):
self._loop.wait()
else:
if not self._pending:
return
await asyncio.wait(self._pending)

@property
def closed(self) -> bool:
Expand All @@ -164,7 +352,7 @@ def sync_q(self) -> "_SyncQueueProxy[T]":
return self._sync_queue

@property
def async_q(self) -> "_AsyncQueueProxy[T]":
def async_q(self) -> Union[PreInitDummyAsyncQueue, "_AsyncQueueProxy[T]"]:
return self._async_queue

# Override these methods to implement other queue organizations
Expand All @@ -190,26 +378,37 @@ def _put_internal(self, item: T) -> None:
self._unfinished_tasks += 1
self._finished.clear()

def _notify_sync_not_empty(self) -> None:
def _post_async_init_notify_sync_condition(self, condition: asyncio.Condition) -> None:
""" Replacement for '_notify_sync_condition', after initializing the async parts """
def f() -> None:
with self._sync_mutex:
self._sync_not_empty.notify()
condition.notify()

fut = asyncio.ensure_future(self._loop.run_in_executor(None, f), loop=self._loop)
fut.add_done_callback(self._pending.discard)
self._pending.add(fut)

self._loop.run_in_executor(None, f)
def _notify_sync_condition(self, condition: asyncio.Condition) -> None:
"""A single interface for notifying sync conditions"""
loop = self._loop # type: PreInitDummyLoop

def _notify_sync_not_full(self) -> None:
def f() -> None:
with self._sync_mutex:
self._sync_not_full.notify()
condition.notify()

fut = asyncio.ensure_future(self._loop.run_in_executor(None, f))
fut.add_done_callback(self._pending.discard)
self._pending.add(fut)
loop.run_in_executor(f)

def _notify_async_not_empty(self, *, threadsafe: bool) -> None:
def _notify_sync_not_empty(self) -> None:
self._notify_sync_condition(self._sync_not_empty)

def _notify_sync_not_full(self) -> None:
self._notify_sync_condition(self._sync_not_full)

def _post_async_init_notify_async_condition(self, condition: asyncio.Condition, threadsafe: bool):
""" Replacement for '_notify_async_condition', after initializing the async parts """
async def f() -> None:
async with self._async_mutex:
self._async_not_empty.notify()
condition.notify()

def task_maker() -> None:
task = self._loop.create_task(f())
Expand All @@ -221,20 +420,17 @@ def task_maker() -> None:
else:
self._call_soon(task_maker)

def _notify_async_not_full(self, *, threadsafe: bool) -> None:
async def f() -> None:
async with self._async_mutex:
self._async_not_full.notify()
def _notify_async_condition(self, condition: asyncio.Condition, threadsafe: bool):
"""A single interface for notifying async conditions
def task_maker() -> None:
task = self._loop.create_task(f())
task.add_done_callback(self._pending.discard)
self._pending.add(task)
Useless until async parts are not initialized"""
...

if threadsafe:
self._call_soon_threadsafe(task_maker)
else:
self._call_soon(task_maker)
def _notify_async_not_empty(self, *, threadsafe: bool) -> None:
self._notify_async_condition(self._async_not_empty, threadsafe)

def _notify_async_not_full(self, *, threadsafe: bool) -> None:
self._notify_async_condition(self._async_not_full, threadsafe)

def _check_closing(self) -> None:
if self._closing:
Expand Down Expand Up @@ -272,14 +468,18 @@ def task_done(self) -> None:
Raises a ValueError if called more times than there were items
placed in the queue.
"""
def f():
with self._parent._all_tasks_done:
self._parent._finished.set()

self._parent._check_closing()
with self._parent._all_tasks_done:
unfinished = self._parent._unfinished_tasks - 1
if unfinished <= 0:
if unfinished < 0:
raise ValueError("task_done() called too many times")
self._parent._all_tasks_done.notify_all()
self._parent._loop.call_soon_threadsafe(self._parent._finished.set)
self._parent._loop.call_soon_threadsafe(f)
self._parent._unfinished_tasks = unfinished

def join(self) -> None:
Expand Down

0 comments on commit 01424a6

Please sign in to comment.