Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(utility): add lock for cached file between processes #1151

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions tensorbay/utility/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

"""Common tools."""


from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from multiprocessing import Manager
from threading import Lock
from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union
from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -48,7 +51,16 @@ def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__


locks: DefaultDict[int, Lock] = defaultdict(Lock)
@contextmanager
def _acquire(lock: Lock) -> Iterator[bool]:
acquire = lock.acquire(blocking=False)
yield acquire
if not acquire:
lock.acquire()
lock.release()


thread_locks: DefaultDict[int, Lock] = defaultdict(Lock)


def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
Expand All @@ -65,15 +77,52 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
@wraps(func)
def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None:
key = id(self)
lock = locks[key]
acquire = lock.acquire(blocking=False)
try:
if acquire:
lock = thread_locks[key]
with _acquire(lock) as success:
if success:
func(self, *arg, **kwargs)
del locks[key]
else:
lock.acquire()
finally:
lock.release()
del thread_locks[key]

return wrapper # type: ignore[return-value]


class ProcessLocked: # pylint: disable=too-few-public-methods
"""A decorator to add lock for methods called from different processes.
Arguments:
attr_name: The name of the attr to be taken as the key of the lock.
"""

_manager = Manager()

process_locks: Dict[str, Lock] = _manager.dict()

def __init__(self, attr_name: str) -> None:
self._attr_name = attr_name

def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
"""Return the locked function.
Arguments:
func: The function to add lock.
Returns:
The locked function.
"""

@wraps(func)
def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None:
Lee-000 marked this conversation as resolved.
Show resolved Hide resolved
key = getattr(func_self, self._attr_name)
# https://github.com/PyCQA/pylint/issues/3313
lock = self.process_locks.setdefault(
key, self._manager.Lock() # pylint: disable=no-member
)

with _acquire(lock) as success:
if success:
func(func_self, *arg, **kwargs)
del self.process_locks[key]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It that possible one process is deleting the lock from self.process_locks, and at the same time another process is checking whether the key in the self.process_locks?


return wrapper # type: ignore[return-value]
18 changes: 12 additions & 6 deletions tensorbay/utility/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from _io import BufferedReader

from tensorbay.exception import ResponseError
from tensorbay.utility.common import ProcessLocked
from tensorbay.utility.repr import ReprMixin
from tensorbay.utility.requests import UserResponse, config, get_session

Expand Down Expand Up @@ -179,6 +180,16 @@ def _urlopen(self) -> UserResponse:
)
raise

@ProcessLocked("cache_path")
def _write_cache(self, cache_path: str) -> None:
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)
temp_path = f"{cache_path}.tensorbay.downloading"
with self._urlopen() as fp:
with open(temp_path, "wb") as cache:
cache.write(fp.read())
os.rename(temp_path, cache_path)

def get_url(self) -> str:
"""Return the url of the data hosted by tensorbay.

Expand Down Expand Up @@ -208,11 +219,6 @@ def open(self) -> Union[UserResponse, BufferedReader]:
return self._urlopen()

if not os.path.exists(cache_path):
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)

with self._urlopen() as fp:
with open(cache_path, "wb") as cache:
cache.write(fp.read())
self._write_cache(cache_path)

return open(cache_path, "rb")