diff --git a/tensorbay/utility/common.py b/tensorbay/utility/common.py index 5709892ff..4a70695a9 100644 --- a/tensorbay/utility/common.py +++ b/tensorbay/utility/common.py @@ -12,10 +12,13 @@ """ + from collections import defaultdict +from contextlib import contextmanager from functools import wraps +from multiprocessing import Manager from threading import Lock -from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union +from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union import numpy as np @@ -56,7 +59,18 @@ def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ -locks: DefaultDict[int, Lock] = defaultdict(Lock) +@contextmanager +def _acquire(lock: Lock) -> Iterator[bool]: + acquire = lock.acquire(blocking=False) + try: + yield acquire + if not acquire: + lock.acquire() + finally: + lock.release() + + +thread_locks: DefaultDict[int, Lock] = defaultdict(Lock) def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: @@ -73,15 +87,50 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: @wraps(func) def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None: key = id(self) - lock = locks[key] - acquire = lock.acquire(blocking=False) - try: - if acquire: + lock = thread_locks[key] + with _acquire(lock) as success: + if success: func(self, *arg, **kwargs) - del locks[key] - else: - lock.acquire() - finally: - lock.release() + del thread_locks[key] return wrapper # type: ignore[return-value] + + +class ProcessLocked: # pylint: disable=too-few-public-methods + """A decorator to add lock for methods called from different processes. + + Arguments: + attr_name: The name of the attr to be taken as the key of the lock. + + """ + + _manager = Manager() + _process_locks: Dict[str, Lock] = _manager.dict() + + def __init__(self, attr_name: str) -> None: + self._attr_name = attr_name + + def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: + """Return the locked function. + + Arguments: + func: The function to add lock. + + Returns: + The locked function. + + """ + + @wraps(func) + def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None: + key = getattr(func_self, self._attr_name) + # https://github.com/PyCQA/pylint/issues/3313 + lock = self._process_locks.setdefault( + key, self._manager.Lock() # pylint: disable=no-member + ) + with _acquire(lock) as success: + if success: + func(func_self, *arg, **kwargs) + del self._process_locks[key] + + return wrapper # type: ignore[return-value] diff --git a/tensorbay/utility/file.py b/tensorbay/utility/file.py index b6d03ef07..81a1389a8 100644 --- a/tensorbay/utility/file.py +++ b/tensorbay/utility/file.py @@ -14,6 +14,7 @@ from _io import BufferedReader from tensorbay.exception import ResponseError +from tensorbay.utility.common import ProcessLocked from tensorbay.utility.repr import ReprMixin from tensorbay.utility.requests import UserResponse, config, get_session @@ -176,6 +177,15 @@ def _urlopen(self) -> UserResponse: ) raise + @ProcessLocked("cache_path") + def _write_cache(self, cache_path: str) -> None: + if not os.path.exists(cache_path): + dirname = os.path.dirname(cache_path) + os.makedirs(dirname, exist_ok=True) + with self._urlopen() as fp: + with open(cache_path, "wb") as cache: + cache.write(fp.read()) + def get_url(self) -> str: """Return the url of the data hosted by tensorbay. @@ -204,12 +214,5 @@ def open(self) -> Union[UserResponse, BufferedReader]: if not cache_path: return self._urlopen() - if not os.path.exists(cache_path): - dirname = os.path.dirname(cache_path) - os.makedirs(dirname, exist_ok=True) - - with self._urlopen() as fp: - with open(cache_path, "wb") as cache: - cache.write(fp.read()) - + self._write_cache(cache_path) return open(cache_path, "rb")