From 8c6b1924df2d738f34cbba505754a8ef9c2235b5 Mon Sep 17 00:00:00 2001 From: "yexuan.li" Date: Mon, 6 Dec 2021 20:51:24 +0800 Subject: [PATCH] fix(utility): add lock for cached file between processes When using PyTorch dataloader with multiple workers, the cache file will be accessed by different processes. And the cache file can be written and read at the same time, which would lead to the read content being empty. Thus, use a temp filename while writing the cache file and then change it to the cache path. And add a lock to avoid repeatedly writing cache. PR Closed: https://github.com/Graviti-AI/tensorbay-python-sdk/pull/1151 --- tensorbay/utility/common.py | 71 +++++++++++++++++++++++++++++++------ tensorbay/utility/file.py | 18 ++++++---- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/tensorbay/utility/common.py b/tensorbay/utility/common.py index 56dbabec6..fce1f0e37 100644 --- a/tensorbay/utility/common.py +++ b/tensorbay/utility/common.py @@ -5,10 +5,13 @@ """Common tools.""" + from collections import defaultdict +from contextlib import contextmanager from functools import wraps +from multiprocessing import Manager from threading import Lock -from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union +from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union import numpy as np @@ -48,7 +51,16 @@ def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ -locks: DefaultDict[int, Lock] = defaultdict(Lock) +@contextmanager +def _acquire(lock: Lock) -> Iterator[bool]: + acquire = lock.acquire(blocking=False) + yield acquire + if not acquire: + lock.acquire() + lock.release() + + +thread_locks: DefaultDict[int, Lock] = defaultdict(Lock) def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: @@ -65,15 +77,52 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: @wraps(func) def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None: key = id(self) - lock = locks[key] - acquire = lock.acquire(blocking=False) - try: - if acquire: + lock = thread_locks[key] + with _acquire(lock) as success: + if success: func(self, *arg, **kwargs) - del locks[key] - else: - lock.acquire() - finally: - lock.release() + del thread_locks[key] return wrapper # type: ignore[return-value] + + +class ProcessLocked: # pylint: disable=too-few-public-methods + """A decorator to add lock for methods called from different processes. + + Arguments: + attr_name: The name of the attr to be taken as the key of the lock. + + """ + + _manager = Manager() + + process_locks: Dict[str, Lock] = _manager.dict() + + def __init__(self, attr_name: str) -> None: + self._attr_name = attr_name + + def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: + """Return the locked function. + + Arguments: + func: The function to add lock. + + Returns: + The locked function. + + """ + + @wraps(func) + def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None: + key = getattr(func_self, self._attr_name) + # https://github.com/PyCQA/pylint/issues/3313 + lock = self.process_locks.setdefault( + key, self._manager.Lock() # pylint: disable=no-member + ) + + with _acquire(lock) as success: + if success: + func(func_self, *arg, **kwargs) + del self.process_locks[key] + + return wrapper # type: ignore[return-value] diff --git a/tensorbay/utility/file.py b/tensorbay/utility/file.py index d3977173c..7481279e1 100644 --- a/tensorbay/utility/file.py +++ b/tensorbay/utility/file.py @@ -14,6 +14,7 @@ from _io import BufferedReader from tensorbay.exception import ResponseError +from tensorbay.utility.common import ProcessLocked from tensorbay.utility.repr import ReprMixin from tensorbay.utility.requests import UserResponse, config, get_session @@ -179,6 +180,16 @@ def _urlopen(self) -> UserResponse: ) raise + @ProcessLocked("cache_path") + def _write_cache(self, cache_path: str) -> None: + dirname = os.path.dirname(cache_path) + os.makedirs(dirname, exist_ok=True) + temp_path = f"{cache_path}.tensorbay.downloading" + with self._urlopen() as fp: + with open(temp_path, "wb") as cache: + cache.write(fp.read()) + os.rename(temp_path, cache_path) + def get_url(self) -> str: """Return the url of the data hosted by tensorbay. @@ -208,11 +219,6 @@ def open(self) -> Union[UserResponse, BufferedReader]: return self._urlopen() if not os.path.exists(cache_path): - dirname = os.path.dirname(cache_path) - os.makedirs(dirname, exist_ok=True) - - with self._urlopen() as fp: - with open(cache_path, "wb") as cache: - cache.write(fp.read()) + self._write_cache(cache_path) return open(cache_path, "rb")