Skip to content

Commit

Permalink
Refactor once_per_instance to support async.
Browse files Browse the repository at this point in the history
It now re-uses the logic we have already implemented in _OnceBase to
avoid re-inventing the wheel.
  • Loading branch information
aebrahim committed Oct 12, 2023
1 parent a966c95 commit 8c5547c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 86 deletions.
112 changes: 34 additions & 78 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ def _is_method(func: collections.abc.Callable):


class _WrappedFunctionType(enum.Enum):
UNSUPPORTED = 0
SYNC_FUNCTION = 1
ASYNC_FUNCTION = 2
SYNC_GENERATOR = 3
ASYNC_GENERATOR = 4
SYNC_FUNCTION = 0
ASYNC_FUNCTION = 1
SYNC_GENERATOR = 2
ASYNC_GENERATOR = 3


def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType:
# The function inspect.isawaitable is a bit of a misnomer - it refers
# to the awaitable result of an async function, not the async function
# itself.
original_func = func
while isinstance(func, functools.partial):
# Work around inspect not functioning properly in python < 3.10 for partial functions.
func = func.func
Expand All @@ -42,10 +42,9 @@ def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionTy
return _WrappedFunctionType.SYNC_GENERATOR
if inspect.iscoroutinefunction(func):
return _WrappedFunctionType.ASYNC_FUNCTION
# We assume it is a callable sync function if it is callable.
if not callable(func):
return _WrappedFunctionType.UNSUPPORTED
return _WrappedFunctionType.SYNC_FUNCTION
if inspect.isfunction(func):
return _WrappedFunctionType.SYNC_FUNCTION
raise SyntaxError(f"Unable to determine function type for {repr(original_func)}")


class _ExecutionState(enum.Enum):
Expand All @@ -57,14 +56,10 @@ class _ExecutionState(enum.Enum):
class _OnceBase(abc.ABC):
"""Abstract Base Class for once function decorators."""

def __init__(self, func: collections.abc.Callable) -> None:
functools.update_wrapper(self, func)
self.func = self._inspect_function(func)
def __init__(self, fn_type: _WrappedFunctionType) -> None:
self.call_state = _ExecutionState.UNCALLED
self.return_value: typing.Any = None
self.fn_type = _wrapped_function_type(self.func)
if self.fn_type == _WrappedFunctionType.UNSUPPORTED:
raise SyntaxError(f"Unable to wrap a {type(func)}")
self.fn_type = fn_type
if (
self.fn_type == _WrappedFunctionType.ASYNC_FUNCTION
or self.fn_type == _WrappedFunctionType.ASYNC_GENERATOR
Expand All @@ -73,16 +68,6 @@ def __init__(self, func: collections.abc.Callable) -> None:
else:
self.lock = threading.Lock()

@abc.abstractmethod
def _inspect_function(self, func: collections.abc.Callable) -> collections.abc.Callable:
"""Inspect the passed-in function to ensure it can be wrapped.
This function should raise a SyntaxError if the passed-in function is
not suitable.
It should return the function which should be executed once.
"""

def _callable(self, func: collections.abc.Callable):
"""Generate a wrapped function appropriate to the function type.
Expand Down Expand Up @@ -110,6 +95,8 @@ def wrapped(*args, **kwargs):
return self._execute_call_once_sync(func, *args, **kwargs)

else:
if self.fn_type != _WrappedFunctionType.SYNC_GENERATOR:
print(self.fn_type)
assert self.fn_type == _WrappedFunctionType.SYNC_GENERATOR

def wrapped(*args, **kwargs):
Expand Down Expand Up @@ -189,16 +176,6 @@ def _execute_call_once_sync_iter(self, func: collections.abc.Callable, *args, **
yield from self.return_value.yield_results()


class _OnceFn(_OnceBase):
def _inspect_function(self, func: collections.abc.Callable):
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
return func


def once(func: collections.abc.Callable):
"""Decorator to ensure a function is only called once.
Expand All @@ -218,7 +195,12 @@ def once(func: collections.abc.Callable):
module and class level functions (i.e. non-closures), this means the return
value will never be deleted.
"""
once_obj = _OnceFn(func)
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
once_obj = _OnceBase(_wrapped_function_type(func))
return once_obj._callable(func)


Expand All @@ -228,6 +210,10 @@ class once_per_class(_OnceBase): # pylint: disable=invalid-name
is_classmethod: bool
is_staticmethod: bool

def __init__(self, func: collections.abc.Callable) -> None:
self.func = self._inspect_function(func)
super().__init__(_wrapped_function_type(self.func))

def _inspect_function(self, func: collections.abc.Callable):
if not _is_method(func):
raise SyntaxError(
Expand Down Expand Up @@ -260,15 +246,14 @@ class once_per_instance(_OnceBase): # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

def __init__(self, func: collections.abc.Callable) -> None:
super().__init__(func)
self.return_value: weakref.WeakKeyDictionary[
typing.Any, typing.Any
self.func = self._inspect_function(func)
super().__init__(_wrapped_function_type(self.func))
self.callables_lock = threading.Lock()
self.callables: weakref.WeakKeyDictionary[
typing.Any, collections.abc.Callable
] = weakref.WeakKeyDictionary()
self.inflight_lock: typing.Dict[typing.Any, threading.Lock] = {}

def _inspect_function(self, func: collections.abc.Callable):
if inspect.isasyncgenfunction(func) or inspect.iscoroutinefunction(func):
raise SyntaxError("once_per_instance not (yet) supported for async")
if isinstance(func, (classmethod, staticmethod)):
raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod")
if not _is_method(func):
Expand All @@ -282,39 +267,10 @@ def _inspect_function(self, func: collections.abc.Callable):
# bound version of the function to the object.
def __get__(self, obj, cls):
del cls
return functools.partial(self._execute_call_once_per_instance, obj)

def _execute_call_once_per_instance(self, obj, *args, **kwargs):
# We only append to the call history, and do not overwrite or remove keys.
# Therefore, we can check the call history without a lock for an early
# exit.
# Another concern might be the weakref dictionary for return_value
# getting garbage collected without a lock. However, because
# user_function references whichever key it matches, it cannot be
# garbage collected during this call.
if obj in self.return_value:
return self.return_value[obj]
with self.lock:
if obj in self.return_value:
return self.return_value[obj]
if obj in self.inflight_lock:
inflight_lock = self.inflight_lock[obj]
else:
inflight_lock = threading.Lock()
self.inflight_lock[obj] = inflight_lock
# Now we have a per-object lock. This means that we will not block
# other instances. In addition to better performance, this reduces the
# potential for deadlocks.
with inflight_lock:
if obj in self.return_value:
return self.return_value[obj]
result = self.func(obj, *args, **kwargs)
self.return_value[obj] = result
# At this point, any new call will find a cache hit before
# even grabbing a lock. It is now safe to clean up the inflight
# lock entry from the dictionary, as all subsequent will not need
# it. Any other previously called inflight requests already have
# their reference to the lock object, and do not need it present
# in this dict either.
self.inflight_lock.pop(obj)
return result
with self.callables_lock:
if callable := self.callables.get(obj):
return callable
once_obj = _OnceBase(self.fn_type)
callable = once_obj._callable(functools.partial(self.func, obj))
self.callables[obj] = callable
return callable
12 changes: 4 additions & 8 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,8 @@ def sample_sync_generator_fn(_1, _2):
once._WrappedFunctionType.SYNC_GENERATOR,
)
# The output of a sync generator is not a wrappable.
self.assertEqual(
once._wrapped_function_type(sample_sync_generator_fn(1, 2)),
once._WrappedFunctionType.UNSUPPORTED,
)
with self.assertRaises(SyntaxError):
once._wrapped_function_type(sample_sync_generator_fn(1, 2))

async def sample_async_generator_method(self, _):
yield 1
Expand Down Expand Up @@ -224,10 +222,8 @@ async def sample_async_generator_fn(_1, _2):
once._WrappedFunctionType.ASYNC_GENERATOR,
)
# The output of an async generator is not a wrappable.
self.assertEqual(
once._wrapped_function_type(sample_async_generator_fn(1, 2)),
once._WrappedFunctionType.UNSUPPORTED,
)
with self.assertRaises(SyntaxError):
once._wrapped_function_type(sample_async_generator_fn(1, 2))


class TestOnce(unittest.TestCase):
Expand Down

0 comments on commit 8c5547c

Please sign in to comment.