From 8c5547cf60bd065ebb1a43681b746e98d078206d Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Wed, 11 Oct 2023 21:59:27 -0700 Subject: [PATCH] Refactor once_per_instance to support async. It now re-uses the logic we have already implemented in _OnceBase to avoid re-inventing the wheel. --- once/__init__.py | 112 ++++++++++++++--------------------------------- once_test.py | 12 ++--- 2 files changed, 38 insertions(+), 86 deletions(-) diff --git a/once/__init__.py b/once/__init__.py index 66f7e10..0a9e64b 100644 --- a/once/__init__.py +++ b/once/__init__.py @@ -22,17 +22,17 @@ def _is_method(func: collections.abc.Callable): class _WrappedFunctionType(enum.Enum): - UNSUPPORTED = 0 - SYNC_FUNCTION = 1 - ASYNC_FUNCTION = 2 - SYNC_GENERATOR = 3 - ASYNC_GENERATOR = 4 + SYNC_FUNCTION = 0 + ASYNC_FUNCTION = 1 + SYNC_GENERATOR = 2 + ASYNC_GENERATOR = 3 def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType: # The function inspect.isawaitable is a bit of a misnomer - it refers # to the awaitable result of an async function, not the async function # itself. + original_func = func while isinstance(func, functools.partial): # Work around inspect not functioning properly in python < 3.10 for partial functions. func = func.func @@ -42,10 +42,9 @@ def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionTy return _WrappedFunctionType.SYNC_GENERATOR if inspect.iscoroutinefunction(func): return _WrappedFunctionType.ASYNC_FUNCTION - # We assume it is a callable sync function if it is callable. - if not callable(func): - return _WrappedFunctionType.UNSUPPORTED - return _WrappedFunctionType.SYNC_FUNCTION + if inspect.isfunction(func): + return _WrappedFunctionType.SYNC_FUNCTION + raise SyntaxError(f"Unable to determine function type for {repr(original_func)}") class _ExecutionState(enum.Enum): @@ -57,14 +56,10 @@ class _ExecutionState(enum.Enum): class _OnceBase(abc.ABC): """Abstract Base Class for once function decorators.""" - def __init__(self, func: collections.abc.Callable) -> None: - functools.update_wrapper(self, func) - self.func = self._inspect_function(func) + def __init__(self, fn_type: _WrappedFunctionType) -> None: self.call_state = _ExecutionState.UNCALLED self.return_value: typing.Any = None - self.fn_type = _wrapped_function_type(self.func) - if self.fn_type == _WrappedFunctionType.UNSUPPORTED: - raise SyntaxError(f"Unable to wrap a {type(func)}") + self.fn_type = fn_type if ( self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION or self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR @@ -73,16 +68,6 @@ def __init__(self, func: collections.abc.Callable) -> None: else: self.lock = threading.Lock() - @abc.abstractmethod - def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable: - """Inspect the passed-in function to ensure it can be wrapped. - - This function should raise a SyntaxError if the passed-in function is - not suitable. - - It should return the function which should be executed once. - """ - def _callable(self, func: collections.abc.Callable): """Generate a wrapped function appropriate to the function type. @@ -110,6 +95,8 @@ def wrapped(*args, **kwargs): return self._execute_call_once_sync(func, *args, **kwargs) else: + if self.fn_type != _WrappedFunctionType.SYNC_GENERATOR: + print(self.fn_type) assert self.fn_type == _WrappedFunctionType.SYNC_GENERATOR def wrapped(*args, **kwargs): @@ -189,16 +176,6 @@ def _execute_call_once_sync_iter(self, func: collections.abc.Callable, *args, ** yield from self.return_value.yield_results() -class _OnceFn(_OnceBase): - def _inspect_function(self, func: collections.abc.Callable): - if _is_method(func): - raise SyntaxError( - "Attempting to use @once.once decorator on method " - "instead of @once.once_per_class or @once.once_per_instance" - ) - return func - - def once(func: collections.abc.Callable): """Decorator to ensure a function is only called once. @@ -218,7 +195,12 @@ def once(func: collections.abc.Callable): module and class level functions (i.e. non-closures), this means the return value will never be deleted. """ - once_obj = _OnceFn(func) + if _is_method(func): + raise SyntaxError( + "Attempting to use @once.once decorator on method " + "instead of @once.once_per_class or @once.once_per_instance" + ) + once_obj = _OnceBase(_wrapped_function_type(func)) return once_obj._callable(func) @@ -228,6 +210,10 @@ class once_per_class(_OnceBase): # pylint: disable=invalid-name is_classmethod: bool is_staticmethod: bool + def __init__(self, func: collections.abc.Callable) -> None: + self.func = self._inspect_function(func) + super().__init__(_wrapped_function_type(self.func)) + def _inspect_function(self, func: collections.abc.Callable): if not _is_method(func): raise SyntaxError( @@ -260,15 +246,14 @@ class once_per_instance(_OnceBase): # pylint: disable=invalid-name """A version of once for class methods which runs once per instance.""" def __init__(self, func: collections.abc.Callable) -> None: - super().__init__(func) - self.return_value: weakref.WeakKeyDictionary[ - typing.Any, typing.Any + self.func = self._inspect_function(func) + super().__init__(_wrapped_function_type(self.func)) + self.callables_lock = threading.Lock() + self.callables: weakref.WeakKeyDictionary[ + typing.Any, collections.abc.Callable ] = weakref.WeakKeyDictionary() - self.inflight_lock: typing.Dict[typing.Any, threading.Lock] = {} def _inspect_function(self, func: collections.abc.Callable): - if inspect.isasyncgenfunction(func) or inspect.iscoroutinefunction(func): - raise SyntaxError("once_per_instance not (yet) supported for async") if isinstance(func, (classmethod, staticmethod)): raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod") if not _is_method(func): @@ -282,39 +267,10 @@ def _inspect_function(self, func: collections.abc.Callable): # bound version of the function to the object. def __get__(self, obj, cls): del cls - return functools.partial(self._execute_call_once_per_instance, obj) - - def _execute_call_once_per_instance(self, obj, *args, **kwargs): - # We only append to the call history, and do not overwrite or remove keys. - # Therefore, we can check the call history without a lock for an early - # exit. - # Another concern might be the weakref dictionary for return_value - # getting garbage collected without a lock. However, because - # user_function references whichever key it matches, it cannot be - # garbage collected during this call. - if obj in self.return_value: - return self.return_value[obj] - with self.lock: - if obj in self.return_value: - return self.return_value[obj] - if obj in self.inflight_lock: - inflight_lock = self.inflight_lock[obj] - else: - inflight_lock = threading.Lock() - self.inflight_lock[obj] = inflight_lock - # Now we have a per-object lock. This means that we will not block - # other instances. In addition to better performance, this reduces the - # potential for deadlocks. - with inflight_lock: - if obj in self.return_value: - return self.return_value[obj] - result = self.func(obj, *args, **kwargs) - self.return_value[obj] = result - # At this point, any new call will find a cache hit before - # even grabbing a lock. It is now safe to clean up the inflight - # lock entry from the dictionary, as all subsequent will not need - # it. Any other previously called inflight requests already have - # their reference to the lock object, and do not need it present - # in this dict either. - self.inflight_lock.pop(obj) - return result + with self.callables_lock: + if callable := self.callables.get(obj): + return callable + once_obj = _OnceBase(self.fn_type) + callable = once_obj._callable(functools.partial(self.func, obj)) + self.callables[obj] = callable + return callable diff --git a/once_test.py b/once_test.py index 470939d..18c0fa7 100644 --- a/once_test.py +++ b/once_test.py @@ -181,10 +181,8 @@ def sample_sync_generator_fn(_1, _2): once._WrappedFunctionType.SYNC_GENERATOR, ) # The output of a sync generator is not a wrappable. - self.assertEqual( - once._wrapped_function_type(sample_sync_generator_fn(1, 2)), - once._WrappedFunctionType.UNSUPPORTED, - ) + with self.assertRaises(SyntaxError): + once._wrapped_function_type(sample_sync_generator_fn(1, 2)) async def sample_async_generator_method(self, _): yield 1 @@ -224,10 +222,8 @@ async def sample_async_generator_fn(_1, _2): once._WrappedFunctionType.ASYNC_GENERATOR, ) # The output of an async generator is not a wrappable. - self.assertEqual( - once._wrapped_function_type(sample_async_generator_fn(1, 2)), - once._WrappedFunctionType.UNSUPPORTED, - ) + with self.assertRaises(SyntaxError): + once._wrapped_function_type(sample_async_generator_fn(1, 2)) class TestOnce(unittest.TestCase):