Skip to content

Commit

Permalink
fix(utility): add lock for cached file between processes
Browse files Browse the repository at this point in the history
When using PyTorch dataloader with multiple workers, the cache file will
be accessed by different processes. And the cache file can be written
and read at the same time, which would lead to the read content is
empty.

Thus, lock is added when writing the cache file to fix this
problem.
  • Loading branch information
Lee-000 committed Dec 7, 2021
1 parent 3e6e2ee commit b50bae0
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 19 deletions.
71 changes: 60 additions & 11 deletions tensorbay/utility/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
"""


from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from multiprocessing import Manager
from threading import Lock
from typing import Any, Callable, DefaultDict, Sequence, Type, TypeVar, Union
from typing import Any, Callable, DefaultDict, Dict, Iterator, Sequence, Type, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -56,7 +59,18 @@ def __eq__(self, other: object) -> bool:
return self.__dict__ == other.__dict__


locks: DefaultDict[int, Lock] = defaultdict(Lock)
@contextmanager
def _acquire(lock: Lock) -> Iterator[bool]:
acquire = lock.acquire(blocking=False)
try:
yield acquire
if not acquire:
lock.acquire()
finally:
lock.release()


thread_locks: DefaultDict[int, Lock] = defaultdict(Lock)


def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
Expand All @@ -73,15 +87,50 @@ def locked(func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
@wraps(func)
def wrapper(self: Any, *arg: Any, **kwargs: Any) -> None:
key = id(self)
lock = locks[key]
acquire = lock.acquire(blocking=False)
try:
if acquire:
lock = thread_locks[key]
with _acquire(lock) as success:
if success:
func(self, *arg, **kwargs)
del locks[key]
else:
lock.acquire()
finally:
lock.release()
del thread_locks[key]

return wrapper # type: ignore[return-value]


class ProcessLocked: # pylint: disable=too-few-public-methods
"""A decorator to add lock for methods called from different processes.
Arguments:
attr_name: The name of the attr to be taken as the key of the lock.
"""

_manager = Manager()
_process_locks: Dict[str, Lock] = _manager.dict()

def __init__(self, attr_name: str) -> None:
self._attr_name = attr_name

def __call__(self, func: _CallableWithoutReturnValue) -> _CallableWithoutReturnValue:
"""Return the locked function.
Arguments:
func: The function to add lock.
Returns:
The locked function.
"""

@wraps(func)
def wrapper(func_self: Any, *arg: Any, **kwargs: Any) -> None:
key = getattr(func_self, self._attr_name)
# https://github.com/PyCQA/pylint/issues/3313
lock = self._process_locks.setdefault(
key, self._manager.Lock() # pylint: disable=no-member
)
with _acquire(lock) as success:
if success:
func(func_self, *arg, **kwargs)
del self._process_locks[key]

return wrapper # type: ignore[return-value]
19 changes: 11 additions & 8 deletions tensorbay/utility/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from _io import BufferedReader

from tensorbay.exception import ResponseError
from tensorbay.utility.common import ProcessLocked
from tensorbay.utility.repr import ReprMixin
from tensorbay.utility.requests import UserResponse, config, get_session

Expand Down Expand Up @@ -176,6 +177,15 @@ def _urlopen(self) -> UserResponse:
)
raise

@ProcessLocked("cache_path")
def _write_cache(self, cache_path: str) -> None:
if not os.path.exists(cache_path):
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)
with self._urlopen() as fp:
with open(cache_path, "wb") as cache:
cache.write(fp.read())

def get_url(self) -> str:
"""Return the url of the data hosted by tensorbay.
Expand Down Expand Up @@ -204,12 +214,5 @@ def open(self) -> Union[UserResponse, BufferedReader]:
if not cache_path:
return self._urlopen()

if not os.path.exists(cache_path):
dirname = os.path.dirname(cache_path)
os.makedirs(dirname, exist_ok=True)

with self._urlopen() as fp:
with open(cache_path, "wb") as cache:
cache.write(fp.read())

self._write_cache(cache_path)
return open(cache_path, "rb")

0 comments on commit b50bae0

Please sign in to comment.