Skip to content

Commit

Permalink
fix(utility): add lock for cached file between processes
Browse files Browse the repository at this point in the history
When using a PyTorch dataloader with multiple workers, the cache file is
accessed by different processes. The cache file can then be written and
read at the same time, in which case the reader may get empty content.

Thus, write the cache to a temporary file first and then rename it to the
cache path. Also add a lock to avoid writing the same cache file repeatedly.

PR Closed: #1151
  • Loading branch information
Lee-000 committed Dec 28, 2021
1 parent 960c629 commit abb7b9c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 17 deletions.
71 changes: 60 additions & 11 deletions tensorbay/utility/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

"""Common tools."""


from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from multiprocessing import Manager
from threading import Lock
from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union
from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -48,7 +51,16 @@ def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__


locks: DefaultDict[int, Lock] = defaultdict(Lock)
@contextmanager
def _acquire(lock: Lock) -> Iterator[bool]:
    """Try to take ``lock`` without blocking and report whether that worked.

    The yielded flag tells the ``with`` body whether this caller got the lock.
    On exit, a caller that did not get the lock blocks until it can acquire
    it, and the lock is released afterwards in every case.

    Arguments:
        lock: The lock to acquire.

    Yields:
        Whether the non-blocking acquire succeeded.

    """
    acquired = lock.acquire(blocking=False)
    try:
        yield acquired
    finally:
        # Without ``finally`` an exception raised in the ``with`` body would
        # skip the release below and leave the lock held forever.
        if not acquired:
            # Block until the current holder releases the lock.
            lock.acquire()
        lock.release()


thread_locks: DefaultDict[int, Lock] = defaultdict(Lock)


def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
Expand All @@ -65,15 +77,52 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
@wraps(func)
def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None:
key = id(self)
lock = locks[key]
acquire = lock.acquire(blocking=False)
try:
if acquire:
lock = thread_locks[key]
with _acquire(lock) as success:
if success:
func(self, *arg, **kwargs)
del locks[key]
else:
lock.acquire()
finally:
lock.release()
del thread_locks[key]

return wrapper # type: ignore[return-value]


class ProcessLocked:  # pylint: disable=too-few-public-methods
    """A decorator to add lock for methods called from different processes.

    The lock is selected by the value of one attribute of the instance whose
    method is called. The locks are stored in a dict created by a
    ``multiprocessing`` manager.

    Arguments:
        attr_name: The name of the attr to be taken as the key of the lock.

    """

    _manager = Manager()

    process_locks: Dict[str, Lock] = _manager.dict()

    def __init__(self, attr_name: str) -> None:
        self._attr_name = attr_name

    def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
        """Return the locked function.

        Arguments:
            func: The function to add lock.

        Returns:
            The locked function.

        """

        @wraps(func)
        def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None:
            key = getattr(func_self, self._attr_name)
            # https://github.com/PyCQA/pylint/issues/3313
            lock = self.process_locks.setdefault(
                key, self._manager.Lock()  # pylint: disable=no-member
            )

            # Only the caller whose non-blocking acquire succeeded runs ``func``.
            with _acquire(lock) as success:
                if success:
                    try:
                        func(func_self, *arg, **kwargs)
                    finally:
                        # Remove the entry even when ``func`` raises, so a
                        # failed call does not leave its lock registered.
                        # ``pop`` with a default is used because another
                        # process that registered a newer lock for the same
                        # key may already have removed the entry.
                        self.process_locks.pop(key, None)

        return wrapper  # type: ignore[return-value]
18 changes: 12 additions & 6 deletions tensorbay/utility/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from _io import BufferedReader

from tensorbay.exception import ResponseError
from tensorbay.utility.common import ProcessLocked
from tensorbay.utility.repr import ReprMixin
from tensorbay.utility.requests import UserResponse, config, get_session

Expand Down Expand Up @@ -179,6 +180,16 @@ def _urlopen(self) -> UserResponse:
)
raise

@ProcessLocked("cache_path")
def _write_cache(self, cache_path: str) -> None:
    """Download the remote content and store it at ``cache_path``.

    The content is written to a temporary file next to the cache file and
    moved into place afterwards, so the cache path only ever refers to a
    completely written file.

    Arguments:
        cache_path: The path of the cache file to create.

    """
    dirname = os.path.dirname(cache_path)
    os.makedirs(dirname, exist_ok=True)
    temp_path = f"{cache_path}.tensorbay.downloading"
    with self._urlopen() as fp:
        with open(temp_path, "wb") as cache:
            cache.write(fp.read())
    # ``os.replace`` overwrites an existing destination on every platform,
    # while ``os.rename`` raises on Windows when ``cache_path`` already exists.
    os.replace(temp_path, cache_path)

def get_url(self) -> str:
"""Return the url of the data hosted by tensorbay.
Expand Down Expand Up @@ -208,11 +219,6 @@ def open(self) -> Union[UserResponse, BufferedReader]:
return self._urlopen()

if not os.path.exists(cache_path):
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)

with self._urlopen() as fp:
with open(cache_path, "wb") as cache:
cache.write(fp.read())
self._write_cache(cache_path)

return open(cache_path, "rb")

0 comments on commit abb7b9c

Please sign in to comment.