diff --git a/once/__init__.py b/once/__init__.py index 6838ba5..66f7e10 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -5,6 +5,7 @@ import enum import functools import inspect +import time import threading import typing import weakref @@ -12,10 +13,6 @@ from . import _iterator_wrappers -def _new_lock() -> threading.Lock: - return threading.Lock() - - def _is_method(func: collections.abc.Callable): """Determine if a function is a method on a class.""" if isinstance(func, (classmethod, staticmethod)): @@ -36,16 +33,25 @@ def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionTy # The function inspect.isawaitable is a bit of a misnomer - it refers # to the awaitable result of an async function, not the async function # itself. + while isinstance(func, functools.partial): + # Work around inspect not functioning properly in python < 3.10 for partial functions. + func = func.func if inspect.isasyncgenfunction(func): return _WrappedFunctionType.ASYNC_GENERATOR if inspect.isgeneratorfunction(func): return _WrappedFunctionType.SYNC_GENERATOR if inspect.iscoroutinefunction(func): return _WrappedFunctionType.ASYNC_FUNCTION - # This must come last, because it would return True for all the other types - if inspect.isfunction(func): - return _WrappedFunctionType.SYNC_FUNCTION - return _WrappedFunctionType.UNSUPPORTED + # We assume it is a callable sync function if it is callable. + if not callable(func): + return _WrappedFunctionType.UNSUPPORTED + return _WrappedFunctionType.SYNC_FUNCTION + + +class _ExecutionState(enum.Enum): + UNCALLED = 0 + WAITING = 1 + COMPLETED = 2 class _OnceBase(abc.ABC): @@ -54,15 +60,18 @@ class _OnceBase(abc.ABC): def __init__(self, func: collections.abc.Callable) -> None: functools.update_wrapper(self, func) self.func = self._inspect_function(func) - self.called = False + self.call_state = _ExecutionState.UNCALLED self.return_value: typing.Any = None self.fn_type = _wrapped_function_type(self.func) if self.fn_type == _WrappedFunctionType.UNSUPPORTED: raise SyntaxError(f"Unable to wrap a {type(func)}") - if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION: + if ( + self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION + or self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR + ): self.async_lock = asyncio.Lock() else: - self.lock = _new_lock() + self.lock = threading.Lock() @abc.abstractmethod def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable: @@ -74,61 +83,129 @@ def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.C It should return the function which should be executed once. """ + def _callable(self, func: collections.abc.Callable): + """Generate a wrapped function appropriate to the function type. + + This wrapped function will call the correct _execute_call_once function. + """ + if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR: + + async def wrapped(*args, **kwargs): + next_value = None + iterator = self._execute_call_once_async_iter(func, *args, **kwargs) + while True: + try: + next_value = yield await iterator.asend(next_value) + except StopAsyncIteration: + return + + elif self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION: + + async def wrapped(*args, **kwargs): + return await self._execute_call_once_async(func, *args, **kwargs) + + elif self.fn_type == _WrappedFunctionType.SYNC_FUNCTION: + + def wrapped(*args, **kwargs): + return self._execute_call_once_sync(func, *args, **kwargs) + + else: + assert self.fn_type == _WrappedFunctionType.SYNC_GENERATOR + + def wrapped(*args, **kwargs): + yield from self._execute_call_once_sync_iter(func, *args, **kwargs) + + functools.update_wrapper(wrapped, func) + return wrapped + async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs): - if self.called: - return self.return_value async with self.async_lock: - if self.called: - return self.return_value - else: - self.return_value = await func(*args, **kwargs) - self.called = True - return self.return_value - - # This cannot be an async function! - def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs): - if self.called: - return self.return_value.yield_results() - with self.lock: - if not self.called: - self.called = True + call_state = self.call_state + while call_state != _ExecutionState.COMPLETED: + if call_state == _ExecutionState.WAITING: + # Allow another thread to grab the GIL. + await asyncio.sleep(0) + async with self.async_lock: + call_state = self.call_state + if call_state == _ExecutionState.UNCALLED: + self.call_state = _ExecutionState.WAITING + # Only one thread will be allowed into this state. + if call_state == _ExecutionState.UNCALLED: + try: + return_value = await func(*args, **kwargs) + except Exception as exc: + async with self.async_lock: + self.call_state = _ExecutionState.UNCALLED + raise exc + async with self.async_lock: + self.return_value = return_value + self.call_state = _ExecutionState.COMPLETED + return self.return_value + + async def _execute_call_once_async_iter(self, func: collections.abc.Callable, *args, **kwargs): + async with self.async_lock: + if self.call_state == _ExecutionState.UNCALLED: self.return_value = _iterator_wrappers.AsyncGeneratorWrapper(func, *args, **kwargs) - return self.return_value.yield_results() - - def _sync_return(self): - if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR: - return self.return_value.yield_results().__iter__() - else: - return self.return_value + self.call_state = _ExecutionState.COMPLETED + next_value = None + iterator = self.return_value.yield_results() + while True: + try: + next_value = yield await iterator.asend(next_value) + except StopAsyncIteration: + return def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwargs): - if self.called: - return self._sync_return() with self.lock: - if self.called: - return self._sync_return() - if self.fn_type == _WrappedFunctionType.SYNC_GENERATOR: + call_state = self.call_state + while call_state != _ExecutionState.COMPLETED: + # We only hit this state in multi-threded code. To reduce contention, we invoke + # time.sleep so another thread an pick up the GIL. + if call_state == _ExecutionState.WAITING: + time.sleep(0) + with self.lock: + call_state = self.call_state + if call_state == _ExecutionState.UNCALLED: + self.call_state = _ExecutionState.WAITING + # Only one thread will be allowed into this state. + if call_state == _ExecutionState.UNCALLED: + try: + return_value = func(*args, **kwargs) + except Exception as exc: + with self.lock: + self.call_state = _ExecutionState.UNCALLED + raise exc + else: + with self.lock: + self.return_value = return_value + self.call_state = _ExecutionState.COMPLETED + return self.return_value + + def _execute_call_once_sync_iter(self, func: collections.abc.Callable, *args, **kwargs): + with self.lock: + if self.call_state == _ExecutionState.UNCALLED: self.return_value = _iterator_wrappers.GeneratorWrapper(func, *args, **kwargs) - else: - self.return_value = func(*args, **kwargs) - self.called = True - return self._sync_return() + self.call_state = _ExecutionState.COMPLETED + yield from self.return_value.yield_results() - def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs): - """Choose the appropriate call_once based on the function type.""" - if self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR: - return self._execute_call_once_async_iter(func, *args, **kwargs) - if self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION: - return self._execute_call_once_async(func, *args, **kwargs) - return self._execute_call_once_sync(func, *args, **kwargs) +class _OnceFn(_OnceBase): + def _inspect_function(self, func: collections.abc.Callable): + if _is_method(func): + raise SyntaxError( + "Attempting to use @once.once decorator on method " + "instead of @once.once_per_class or @once.once_per_instance" + ) + return func -class once(_OnceBase): # pylint: disable=invalid-name + +def once(func: collections.abc.Callable): """Decorator to ensure a function is only called once. The restriction of only one call also holds across threads. However, this restriction does not apply to unsuccessful function calls. If the function - raises an exception, the next call will invoke a new call to the function. + raises an exception, the next call will invoke a new call to the function, + unless it is in iterator, in which case the failure will be cached. If the function is called with multiple arguments, it will still only be called only once. @@ -141,17 +218,8 @@ class once(_OnceBase): # pylint: disable=invalid-name module and class level functions (i.e. non-closures), this means the return value will never be deleted. """ - - def _inspect_function(self, func: collections.abc.Callable): - if _is_method(func): - raise SyntaxError( - "Attempting to use @once.once decorator on method " - "instead of @once.once_per_class or @once.once_per_instance" - ) - return func - - def __call__(self, *args, **kwargs): - return self._execute_call_once(self.func, *args, **kwargs) + once_obj = _OnceFn(func) + return once_obj._callable(func) class once_per_class(_OnceBase): # pylint: disable=invalid-name @@ -182,11 +250,10 @@ def _inspect_function(self, func: collections.abc.Callable): # bound version of the function to the object or class. def __get__(self, obj, cls): if self.is_classmethod: - func = functools.partial(self.func, cls) - return functools.partial(self._execute_call_once, func) + return self._callable(functools.partial(self.func, cls)) if self.is_staticmethod: - return functools.partial(self._execute_call_once, self.func) - return functools.partial(self._execute_call_once, self.func, obj) + return self._callable(self.func) + return self._callable(functools.partial(self.func, obj)) class once_per_instance(_OnceBase): # pylint: disable=invalid-name @@ -233,7 +300,7 @@ def _execute_call_once_per_instance(self, obj, *args, **kwargs): if obj in self.inflight_lock: inflight_lock = self.inflight_lock[obj] else: - inflight_lock = _new_lock() + inflight_lock = threading.Lock() self.inflight_lock[obj] = inflight_lock # Now we have a per-object lock. This means that we will not block # other instances. In addition to better performance, this reduces the diff --git a/once/_iterator_wrappers.py b/once/_iterator_wrappers.py index 4ae533f..423298d 100644 --- a/once/_iterator_wrappers.py +++ b/once/_iterator_wrappers.py @@ -2,6 +2,7 @@ import collections.abc import enum import threading +import time # Before we begin, a note on the assert statements in this file: # Why are we using assert in here, you might ask, instead of implementing "proper" error handling? @@ -114,24 +115,26 @@ def __init__(self, func, *args, **kwargs) -> None: self.results: list = [] self.generating = False self.lock = threading.Lock() + self.exception: Exception | None = None def yield_results(self) -> collections.abc.Generator: # Fast path for subsequent repeated call: with self.lock: - finished = self.finished - if finished: + fast_path = self.finished and self.exception is None + if fast_path: yield from self.results return i = 0 yield_value = None next_send = None - # Fast path for subsequent calls will not require a lock while True: action: _IteratorAction | None = None # With a lock, we figure out which action to take, and then we take it after release. with self.lock: if i == len(self.results): if self.finished: + if self.exception: + raise self.exception return if self.generating: action = _IteratorAction.WAITING @@ -142,6 +145,7 @@ def yield_results(self) -> collections.abc.Generator: action = _IteratorAction.YIELDING yield_value = self.results[i] if action == _IteratorAction.WAITING: + time.sleep(0) continue if action == _IteratorAction.YIELDING: next_send = yield yield_value @@ -154,8 +158,14 @@ def yield_results(self) -> collections.abc.Generator: except StopIteration: with self.lock: self.finished = True + self.generating = False self.generator = None # Allow this to be GCed. + except Exception as e: + with self.lock: + self.finished = True self.generating = False + self.exception = e + self.generator = None # Allow this to be GCed. else: with self.lock: self.generating = False diff --git a/once_test.py b/once_test.py index 10fe922..470939d 100644 --- a/once_test.py +++ b/once_test.py @@ -2,14 +2,13 @@ # pylint: disable=missing-function-docstring import asyncio import concurrent.futures +import functools +import gc import inspect import sys import time import unittest -from unittest import mock -import threading import weakref -import gc import once @@ -26,17 +25,27 @@ async def anext(iter, default=StopAsyncIteration): return await iter.__anext__() +# This is a "large" number of workers to schedule function calls in parallel. +_N_WORKERS = 16 + + class Counter: """Holding object for a counter. If we return an integer directly, it will simply return a copy and will not update as the number of calls increases. + + The counter can also be paused by setting its paused attribute, which will be convenient to + start multiple runs to execute concurrently. """ def __init__(self) -> None: self.value = 0 + self.paused = False def get_incremented(self) -> int: + while self.paused: + pass self.value += 1 return self.value @@ -59,44 +68,98 @@ def counting_fn(*args) -> int: class TestFunctionInspection(unittest.TestCase): """Unit tests for function inspection""" - def sample_sync_method(self): + def sample_sync_method(self, _): return 1 + def test_sync_method(self): + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_sync_method), + once._WrappedFunctionType.SYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(TestFunctionInspection.sample_sync_method, 1) + ), + once._WrappedFunctionType.SYNC_FUNCTION, + ) + def test_sync_function(self): - def sample_sync_fn(): + def sample_sync_fn(_1, _2): return 1 self.assertEqual( once._wrapped_function_type(sample_sync_fn), once._WrappedFunctionType.SYNC_FUNCTION ) self.assertEqual( - once._wrapped_function_type(TestFunctionInspection.sample_sync_method), + once._wrapped_function_type(once.once(sample_sync_fn)), + once._WrappedFunctionType.SYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type(functools.partial(sample_sync_fn, 1)), + once._WrappedFunctionType.SYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type(functools.partial(functools.partial(sample_sync_fn, 1), 2)), once._WrappedFunctionType.SYNC_FUNCTION, ) self.assertEqual( once._wrapped_function_type(lambda x: x + 1), once._WrappedFunctionType.SYNC_FUNCTION ) - async def sample_async_method(self): + async def sample_async_method(self, _): return 1 + def test_async_method(self): + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_async_method), + once._WrappedFunctionType.ASYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(TestFunctionInspection.sample_async_method, 1) + ), + once._WrappedFunctionType.ASYNC_FUNCTION, + ) + def test_async_function(self): - async def sample_async_fn(): + async def sample_async_fn(_1, _2): return 1 self.assertEqual( once._wrapped_function_type(sample_async_fn), once._WrappedFunctionType.ASYNC_FUNCTION ) self.assertEqual( - once._wrapped_function_type(TestFunctionInspection.sample_async_method), + once._wrapped_function_type(once.once(sample_async_fn)), + once._WrappedFunctionType.ASYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type(functools.partial(sample_async_fn, 1)), + once._WrappedFunctionType.ASYNC_FUNCTION, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(functools.partial(sample_async_fn, 1), 2) + ), once._WrappedFunctionType.ASYNC_FUNCTION, ) - def sample_sync_generator_method(self): + def sample_sync_generator_method(self, _): yield 1 + def test_sync_generator_method(self): + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_sync_generator_method), + once._WrappedFunctionType.SYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(TestFunctionInspection.sample_sync_generator_method, 1) + ), + once._WrappedFunctionType.SYNC_GENERATOR, + ) + def test_sync_generator_function(self): - def sample_sync_generator_fn(): + def sample_sync_generator_fn(_1, _2): yield 1 self.assertEqual( @@ -104,20 +167,42 @@ def sample_sync_generator_fn(): once._WrappedFunctionType.SYNC_GENERATOR, ) self.assertEqual( - once._wrapped_function_type(TestFunctionInspection.sample_sync_generator_method), + once._wrapped_function_type(once.once(sample_sync_generator_fn)), + once._WrappedFunctionType.SYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type(functools.partial(sample_sync_generator_fn, 1)), + once._WrappedFunctionType.SYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(functools.partial(sample_sync_generator_fn, 1), 2) + ), once._WrappedFunctionType.SYNC_GENERATOR, ) # The output of a sync generator is not a wrappable. self.assertEqual( - once._wrapped_function_type(sample_sync_generator_fn()), + once._wrapped_function_type(sample_sync_generator_fn(1, 2)), once._WrappedFunctionType.UNSUPPORTED, ) - async def sample_async_generator_method(self): + async def sample_async_generator_method(self, _): yield 1 - def test_sync_agenerator_function(self): - async def sample_async_generator_fn(): + def test_async_generator_method(self): + self.assertEqual( + once._wrapped_function_type(TestFunctionInspection.sample_async_generator_method), + once._WrappedFunctionType.ASYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(TestFunctionInspection.sample_async_generator_method, 1) + ), + once._WrappedFunctionType.ASYNC_GENERATOR, + ) + + def test_async_generator_function(self): + async def sample_async_generator_fn(_1, _2): yield 1 self.assertEqual( @@ -125,12 +210,22 @@ async def sample_async_generator_fn(): once._WrappedFunctionType.ASYNC_GENERATOR, ) self.assertEqual( - once._wrapped_function_type(TestFunctionInspection.sample_async_generator_method), + once._wrapped_function_type(once.once(sample_async_generator_fn)), + once._WrappedFunctionType.ASYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type(functools.partial(sample_async_generator_fn, 1)), + once._WrappedFunctionType.ASYNC_GENERATOR, + ) + self.assertEqual( + once._wrapped_function_type( + functools.partial(functools.partial(sample_async_generator_fn, 1)) + ), once._WrappedFunctionType.ASYNC_GENERATOR, ) # The output of an async generator is not a wrappable. self.assertEqual( - once._wrapped_function_type(sample_async_generator_fn()), + once._wrapped_function_type(sample_async_generator_fn(1, 2)), once._WrappedFunctionType.UNSUPPORTED, ) @@ -138,6 +233,14 @@ async def sample_async_generator_fn(): class TestOnce(unittest.TestCase): """Unit tests for once decorators.""" + def test_inspect_iterator(self): + @once.once + def yielding_iterator(): + for i in range(3): + yield i + + self.assertTrue(inspect.isgeneratorfunction(yielding_iterator)) + def test_counter_works(self): """Ensure the counter text fixture works.""" counter = Counter() @@ -155,6 +258,30 @@ def test_different_args_same_result(self): self.assertEqual(counting_fn(2), 1) self.assertEqual(counter.value, 1) + def test_partial(self): + counter = Counter() + func = once.once(functools.partial(lambda _: counter.get_incremented(), None)) + self.assertEqual(func(), 1) + self.assertEqual(func(), 1) + + def test_failing_function(self): + counter = Counter() + + @once.once + def sample_failing_fn(): + if counter.get_incremented() < 4: + raise ValueError("expected failure") + return 1 + + with self.assertRaises(ValueError): + sample_failing_fn() + self.assertEqual(counter.get_incremented(), 2) + with self.assertRaises(ValueError): + sample_failing_fn() + # This ensures that this was a new function call, not a cached result. + self.assertEqual(counter.get_incremented(), 4) + self.assertEqual(sample_failing_fn(), 1) + def test_iterator(self): counter = Counter() @@ -167,8 +294,24 @@ def yielding_iterator(): self.assertEqual(list(yielding_iterator()), [1, 2, 3]) self.assertEqual(list(yielding_iterator()), [1, 2, 3]) + def test_failing_generator(self): + counter = Counter() + + @once.once + def sample_failing_fn(): + yield counter.get_incremented() + raise ValueError("expected failure") + + with self.assertRaises(ValueError): + list(sample_failing_fn()) + with self.assertRaises(ValueError): + list(sample_failing_fn()) + self.assertEqual(next(sample_failing_fn()), 1) + self.assertEqual(next(sample_failing_fn()), 1) + def test_iterator_parallel_execution(self): counter = Counter() + counter.paused = True @once.once def yielding_iterator(): @@ -176,16 +319,44 @@ def yielding_iterator(): for _ in range(3): yield counter.get_incremented() - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: - results = list(executor.map(lambda _: list(yielding_iterator()), range(32))) + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: + results = executor.map(lambda _: list(yielding_iterator()), range(_N_WORKERS * 2)) + counter.paused = False # starter pistol, the race is off! for result in results: self.assertEqual(result, [1, 2, 3]) + def test_iterator_lock_not_held_during_evaluation(self): + counter = Counter() + counter.paused = False + + @once.once + def yielding_iterator(): + nonlocal counter + for _ in range(3): + yield counter.get_incremented() + + gen1 = yielding_iterator() + gen2 = yielding_iterator() + self.assertEqual(next(gen1), 1) + counter.paused = True + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + gen1_updater = executor.submit(next, gen1) + self.assertEqual(next(gen2), 1) + gen2_updater = executor.submit(next, gen2) + self.assertTrue(gen1_updater.running()) + self.assertTrue(gen2_updater.running()) + counter.paused = False + self.assertEqual(gen1_updater.result(), 2) + self.assertEqual(gen2_updater.result(), 2) + def test_threaded_single_function(self): counting_fn, counter = generate_once_counter_fn() - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: - results = list(executor.map(counting_fn, range(32))) - self.assertEqual(len(results), 32) + counter.paused = True + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: + results_generator = executor.map(counting_fn, range(_N_WORKERS * 2)) + counter.paused = False # starter pistol, the race is off! + results = list(results_generator) + self.assertEqual(len(results), _N_WORKERS * 2) for r in results: self.assertEqual(r, 1) self.assertEqual(counter.value, 1) @@ -196,14 +367,17 @@ def test_threaded_multiple_functions(self): for _ in range(4): cfn, counter = generate_once_counter_fn() + counter.paused = True counters.append(counter) fns.append(cfn) promises = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: for cfn in fns: - for _ in range(16): + for _ in range(_N_WORKERS): promises.append(executor.submit(cfn)) + for counter in counters: + counter.paused = False del cfn fns.clear() for promise in promises: @@ -270,45 +444,6 @@ def closure(): gc.collect() self.assertIsNone(ephemeral_ref()) - @mock.patch.object(once, "_new_lock") - def test_lock_bypass(self, lock_mocker) -> None: - """Test both with and without lock bypass cache lookup.""" - - # We mock the lock to return our specific lock, so we can specifically - # test behavior with it held. - lock = threading.Lock() - lock_mocker.return_value = lock - - counter = Counter() - - @once.once - def sample_fn() -> int: - nonlocal counter - return counter.get_incremented() - - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - with lock: - potential_first_call_promises = [executor.submit(sample_fn) for i in range(32)] - # Give the promises enough time to finish, if they were not blocked. - # The test will still pass without this, but we wouldn't be - # testing anything. - time.sleep(0.01) - # At this point, all of the promises will be waiting for the lock, - # and none of them will have completed. - for promise in potential_first_call_promises: - self.assertFalse(promise.done()) - # Now that we have released the lock, all of these should complete. - for promise in potential_first_call_promises: - self.assertEqual(promise.result(), 1) - self.assertEqual(counter.value, 1) - # Now that we know the function has already been called, we should - # be able to get a result without waiting for a lock. - with lock: - bypass_lock_promises = [executor.submit(sample_fn) for i in range(32)] - for promise in bypass_lock_promises: - self.assertEqual(promise.result(), 1) - self.assertEqual(counter.value, 1) - def test_function_signature_preserved(self): @once.once def type_annotated_fn(arg: float) -> int: @@ -386,9 +521,9 @@ def value(self): # pylint: disable=inconsistent-return-statements a = _CallOnceClass("a", self) # pylint: disable=invalid-name b = _CallOnceClass("b", self) # pylint: disable=invalid-name - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - a_jobs = [executor.submit(a.value) for _ in range(16)] - b_jobs = [executor.submit(b.value) for _ in range(16)] + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: + a_jobs = [executor.submit(a.value) for _ in range(_N_WORKERS // 2)] + b_jobs = [executor.submit(b.value) for _ in range(_N_WORKERS // 2)] for a_job in a_jobs: self.assertEqual(a_job.result(), "a") for b_job in b_jobs: @@ -402,7 +537,6 @@ def value(self): # pylint: disable=inconsistent-return-statements def test_once_per_instance_do_not_block_each_other(self): class _BlockableClass: def __init__(self, test: unittest.TestCase): - self.lock = threading.Lock() self.test = test self.started = False self.counter = Counter() @@ -410,24 +544,23 @@ def __init__(self, test: unittest.TestCase): @once.once_per_instance def run(self) -> int: self.started = True - with self.lock: - pass return self.counter.get_incremented() a = _BlockableClass(self) + a.counter.paused = True b = _BlockableClass(self) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - with a.lock: - a_job = executor.submit(a.run) - while not a.started: - pass - # At this point, the A job has started. However, it cannot - # complete while we hold its lock. Despite this, we want to ensure - # that B can still run. - b_job = executor.submit(b.run) - # The b_job will deadlock and this will fail if different - # object executions block each other. - self.assertEqual(b_job.result(timeout=5), 1) + a_job = executor.submit(a.run) + while not a.started: + time.sleep(0) + # At this point, the A job has started. However, it cannot + # complete while paused. Despite this, we want to ensure + # that B can still run. + b_job = executor.submit(b.run) + # The b_job will deadlock and this will fail if different + # object executions block each other. + self.assertEqual(b_job.result(timeout=5), 1) + a.counter.paused = False self.assertEqual(a_job.result(timeout=5), 1) def test_once_per_class_classmethod(self): @@ -478,30 +611,45 @@ def receiving_iterator(): self.assertEqual(list(receiving_iterator()), [0, 1, 2, 5]) def test_receiving_iterator_parallel_execution(self): + # Pause so we actually are able to test parallel execution, by ensuring that each exec + # does not complete before the next one is scheduled. + paused = True + @once.once def receiving_iterator(): + nonlocal paused next = yield 0 while next is not None: + while paused: + pass next = yield next def call_iterator(_): gen = receiving_iterator() result = [] result.append(gen.send(None)) - for i in range(1, 32): + for i in range(1, _N_WORKERS * 4): result.append(gen.send(i)) return result - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: - results = list(executor.map(call_iterator, range(32))) + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: + results = executor.map(call_iterator, range(_N_WORKERS * 2)) + paused = False # starter pistol, the race is off! for result in results: - self.assertEqual(result, list(range(32))) + self.assertEqual(result, list(range(_N_WORKERS * 4))) def test_receiving_iterator_parallel_execution_halting(self): + # Pause so we actually are able to test parallel execution, by ensuring that each exec + # does not complete before the next one is scheduled. + paused = True + @once.once def receiving_iterator(): + nonlocal paused next = yield 0 while next is not None: + while paused: + pass next = yield next def call_iterator(n): @@ -515,8 +663,9 @@ def call_iterator(n): # Unlike the previous test, each execution should yield lists of different lengths. # This ensures that the iterator does not hang, even if not exhausted - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: - results = list(executor.map(call_iterator, range(1, 32))) + with concurrent.futures.ThreadPoolExecutor(max_workers=_N_WORKERS) as executor: + results = executor.map(call_iterator, range(1, _N_WORKERS * 2)) + paused = False # starter pistol, the race is off! for i, result in enumerate(results): self.assertEqual(result, list(range(i + 1))) @@ -542,14 +691,31 @@ async def counting_fn2(): self.assertEqual(await counting_fn2(), 2) self.assertEqual(await counting_fn2(), 2) + async def test_failing_function(self): + counter = Counter() + + @once.once + async def sample_failing_fn(): + if counter.get_incremented() < 4: + raise ValueError("expected failure") + return 1 + + with self.assertRaises(ValueError): + await sample_failing_fn() + self.assertEqual(counter.get_incremented(), 2) + with self.assertRaises(ValueError): + await sample_failing_fn() + # This ensures that this was a new function call, not a cached result. + self.assertEqual(counter.get_incremented(), 4) + self.assertEqual(await sample_failing_fn(), 1) + async def test_inspect_func(self): @once.once async def async_func(): return True - # Unfortunately these are corrupted by our @once.once. - # self.assertFalse(inspect.isasyncgenfunction(async_func)) - # self.assertTrue(inspect.iscoroutinefunction(async_func)) + self.assertFalse(inspect.isasyncgenfunction(async_func)) + self.assertTrue(inspect.iscoroutinefunction(async_func)) coroutine = async_func() self.assertTrue(inspect.iscoroutine(coroutine)) @@ -565,9 +731,8 @@ async def async_yielding_iterator(): for i in range(3): yield i - # Unfortunately these are corrupted by our @once.once. - # self.assertTrue(inspect.isasyncgenfunction(async_yielding_iterator)) - # self.assertTrue(inspect.iscoroutinefunction(async_yielding_iterator)) + self.assertTrue(inspect.isasyncgenfunction(async_yielding_iterator)) + self.assertFalse(inspect.iscoroutinefunction(async_yielding_iterator)) coroutine = async_yielding_iterator() self.assertFalse(inspect.iscoroutine(coroutine)) @@ -589,6 +754,22 @@ async def async_yielding_iterator(): self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3]) self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3]) + @unittest.skip("This currently hangs and needs to be fixed, GitHub Issue #12") + async def test_failing_generator(self): + counter = Counter() + + @once.once + async def sample_failing_fn(): + yield counter.get_incremented() + raise ValueError("expected failure") + + with self.assertRaises(ValueError): + [i async for i in sample_failing_fn()] + with self.assertRaises(ValueError): + [i async for i in sample_failing_fn()] + self.assertEqual(await anext(sample_failing_fn()), 1) + self.assertEqual(await anext(sample_failing_fn()), 1) + async def test_iterator_is_lazily_evaluted(self): counter = Counter() @@ -636,6 +817,45 @@ async def async_receiving_iterator(): self.assertEqual(await anext(gen_1, None), None) self.assertEqual([i async for i in async_receiving_iterator()], [0, 1, 2, 5]) + async def test_receiving_iterator_parallel_execution(self): + @once.once + async def receiving_iterator(): + next = yield 0 + while next is not None: + next = yield next + + async def call_iterator(_): + gen = receiving_iterator() + result = [] + result.append(await gen.asend(None)) + for i in range(1, _N_WORKERS): + result.append(await gen.asend(i)) + return result + + results = map(call_iterator, range(_N_WORKERS)) + for result in results: + self.assertEqual(await result, list(range(_N_WORKERS))) + + async def test_receiving_iterator_parallel_execution_halting(self): + @once.once + async def receiving_iterator(): + next = yield 0 + while next is not None: + next = yield next + + async def call_iterator(n): + """Call the iterator but end early""" + gen = receiving_iterator() + result = [] + result.append(await gen.asend(None)) + for i in range(1, n): + result.append(await gen.asend(i)) + return result + + results = map(call_iterator, range(1, _N_WORKERS)) + for i, result in enumerate(results): + self.assertEqual(await result, list(range(i + 1))) + @unittest.skipIf(not hasattr(asyncio, "Barrier"), "Requires Barrier to evaluate") async def test_iterator_lock_not_held_during_evaluation(self): counter = Counter() @@ -696,6 +916,7 @@ async def value(cls): nonlocal counter return counter.get_incremented() + self.assertTrue(inspect.iscoroutinefunction(_CallOnceClass.value)) self.assertEqual(await _CallOnceClass.value(), 1) self.assertEqual(await _CallOnceClass.value(), 1)