Skip to content

Commit

Permalink
fix(utility): add lock for cached file between processes
Browse files Browse the repository at this point in the history
When using PyTorch dataloader with multiple workers, the cache file will
be accessed by different processes. And the cache file can be written
and read at the same time, which would lead to the read content is
empty.

Thus, lock is added when writing the cache file to fix this
problem.
  • Loading branch information
Lee-000 committed Dec 15, 2021
1 parent 0561368 commit a7b26ef
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 18 deletions.
80 changes: 69 additions & 11 deletions tensorbay/utility/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
"""


from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from multiprocessing import Manager
from threading import Lock
from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union
from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -56,7 +59,16 @@ def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__


locks: DefaultDict[int, Lock] = defaultdict(Lock)
@contextmanager
def _acquire(lock: Lock) -> Iterator[bool]:
acquire = lock.acquire(blocking=False)
yield acquire
if not acquire:
lock.acquire()
lock.release()


thread_locks: DefaultDict[int, Lock] = defaultdict(Lock)


def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
Expand All @@ -73,15 +85,61 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
@wraps(func)
def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None:
key = id(self)
lock = locks[key]
acquire = lock.acquire(blocking=False)
try:
if acquire:
lock = thread_locks[key]
with _acquire(lock) as success:
if success:
func(self, *arg, **kwargs)
del locks[key]
else:
lock.acquire()
finally:
lock.release()
del thread_locks[key]

return wrapper # type: ignore[return-value]


class ProcessLocked: # pylint: disable=too-few-public-methods
"""A decorator to add lock for methods called from different processes.
Arguments:
attr_name: The name of the attr to be taken as the key of the lock.
"""

_manager = Manager()
_lock = _manager.Lock() # pylint: disable=no-member

process_locks: Dict[str, Lock] = _manager.dict()

def __init__(
self, attr_name: str, exit_condition: Callable[[str], bool] = lambda _: False
) -> None:
self._attr_name = attr_name
self._exit_condition = exit_condition

def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
"""Return the locked function.
Arguments:
func: The function to add lock.
Returns:
The locked function.
"""

@wraps(func)
def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None:
key = getattr(func_self, self._attr_name)
with self._lock:
# https://github.com/PyCQA/pylint/issues/3313
lock = self.process_locks.setdefault(
key, self._manager.Lock() # pylint: disable=no-member
)

with _acquire(lock) as success:
if success and not self._exit_condition(key):
func(func_self, *arg, **kwargs)
del self.process_locks[key]

with self._lock:
if key in self.process_locks:
del self.process_locks[key]

return wrapper # type: ignore[return-value]
18 changes: 11 additions & 7 deletions tensorbay/utility/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from _io import BufferedReader

from tensorbay.exception import ResponseError
from tensorbay.utility.common import ProcessLocked
from tensorbay.utility.repr import ReprMixin
from tensorbay.utility.requests import UserResponse, config, get_session

Expand Down Expand Up @@ -176,6 +177,14 @@ def _urlopen(self) -> UserResponse:
)
raise

@ProcessLocked("cache_path", os.path.exists)
def _write_cache(self, cache_path: str) -> None:
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)
with self._urlopen() as fp:
with open(cache_path, "wb") as cache:
cache.write(fp.read())

def get_url(self) -> str:
"""Return the url of the data hosted by tensorbay.
Expand Down Expand Up @@ -204,12 +213,7 @@ def open(self) -> Union[UserResponse, BufferedReader]:
if not cache_path:
return self._urlopen()

if not os.path.exists(cache_path):
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)

with self._urlopen() as fp:
with open(cache_path, "wb") as cache:
cache.write(fp.read())
if not os.path.exists(cache_path) or cache_path in ProcessLocked.process_locks:
self._write_cache(cache_path)

return open(cache_path, "rb")

0 comments on commit a7b26ef

Please sign in to comment.