From b50bae0862e7410a12075559aeb6a06414e4dd2f Mon Sep 17 00:00:00 2001 From: "yexuan.li" Date: Mon, 6 Dec 2021 20:51:24 +0800 Subject: [PATCH] fix(utility): add lock for cached file between processes When using PyTorch dataloader with multiple workers, the cache file will be accessed by different processes. And the cache file can be written and read at the same time, which would lead to the read content is empty. Thus, lock is added when writing the cache file to fix this problem. --- tensorbay/utility/common.py | 71 +++++++++++++++++++++++++++++++------ tensorbay/utility/file.py | 19 +++++----- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/tensorbay/utility/common.py b/tensorbay/utility/common.py index 5709892ff..4a70695a9 100644 --- a/tensorbay/utility/common.py +++ b/tensorbay/utility/common.py @@ -12,10 +12,13 @@ """ + from collections import defaultdict +from contextlib import contextmanager from functools import wraps +from multiprocessing import Manager from threading import Lock -from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union +from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union import numpy as np @@ -56,7 +59,18 @@ def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ -locks: DefaultDict[int, Lock] = defaultdict(Lock) +@contextmanager +def _acquire(lock: Lock) -> Iterator[bool]: + acquire = lock.acquire(blocking=False) + try: + yield acquire + if not acquire: + lock.acquire() + finally: + lock.release() + + +thread_locks: DefaultDict[int, Lock] = defaultdict(Lock) def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: @@ -73,15 +87,50 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: @wraps(func) def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None: key = id(self) - lock = locks[key] - acquire = lock.acquire(blocking=False) - try: - if acquire: + lock = thread_locks[key] + with _acquire(lock) as success: + if success: func(self, *arg, **kwargs) - del locks[key] - else: - lock.acquire() - finally: - lock.release() + del thread_locks[key] return wrapper # type: ignore[return-value] + + +class ProcessLocked: # pylint: disable=too-few-public-methods + """A decorator to add lock for methods called from different processes. + + Arguments: + attr_name: The name of the attr to be taken as the key of the lock. + + """ + + _manager = Manager() + _process_locks: Dict[str, Lock] = _manager.dict() + + def __init__(self, attr_name: str) -> None: + self._attr_name = attr_name + + def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue: + """Return the locked function. + + Arguments: + func: The function to add lock. + + Returns: + The locked function. + + """ + + @wraps(func) + def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None: + key = getattr(func_self, self._attr_name) + # https://github.com/PyCQA/pylint/issues/3313 + lock = self._process_locks.setdefault( + key, self._manager.Lock() # pylint: disable=no-member + ) + with _acquire(lock) as success: + if success: + func(func_self, *arg, **kwargs) + del self._process_locks[key] + + return wrapper # type: ignore[return-value] diff --git a/tensorbay/utility/file.py b/tensorbay/utility/file.py index b6d03ef07..81a1389a8 100644 --- a/tensorbay/utility/file.py +++ b/tensorbay/utility/file.py @@ -14,6 +14,7 @@ from _io import BufferedReader from tensorbay.exception import ResponseError +from tensorbay.utility.common import ProcessLocked from tensorbay.utility.repr import ReprMixin from tensorbay.utility.requests import UserResponse, config, get_session @@ -176,6 +177,15 @@ def _urlopen(self) -> UserResponse: ) raise + @ProcessLocked("cache_path") + def _write_cache(self, cache_path: str) -> None: + if not os.path.exists(cache_path): + dirname = os.path.dirname(cache_path) + os.makedirs(dirname, exist_ok=True) + with self._urlopen() as fp: + with open(cache_path, "wb") as cache: + cache.write(fp.read()) + def get_url(self) -> str: """Return the url of the data hosted by tensorbay. @@ -204,12 +214,5 @@ def open(self) -> Union[UserResponse, BufferedReader]: if not cache_path: return self._urlopen() - if not os.path.exists(cache_path): - dirname = os.path.dirname(cache_path) - os.makedirs(dirname, exist_ok=True) - - with self._urlopen() as fp: - with open(cache_path, "wb") as cache: - cache.write(fp.read()) - + self._write_cache(cache_path) return open(cache_path, "rb")