[Feature] Prevent loading existing mmap files in storages if they already exist

ghstack-source-id: 63bcb1e0420620d5dcd2b73d8e0a5b3bf137c8e1
Pull Request resolved: #2438
vmoens committed Sep 17, 2024
1 parent 36545af commit 605b4aa
Showing 2 changed files with 22 additions and 1 deletion.
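
In short, a second ``LazyMemmapStorage`` pointed at a ``scratch_dir`` that already contains memmap files now raises a ``RuntimeError`` unless ``existsok=True`` is passed, in which case the existing files are opened as-is. A minimal sketch of the resulting behavior, mirroring the test added below (the ``/tmp/rb_scratch`` path is illustrative):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, ReplayBuffer

    # First buffer writes its memmap files under the scratch dir.
    rb0 = ReplayBuffer(storage=LazyMemmapStorage(10, scratch_dir="/tmp/rb_scratch"))
    rb0.extend(TensorDict(a=torch.randn(3), batch_size=[3]))

    # A second storage on the same directory refuses to overwrite the files:
    rb1 = ReplayBuffer(storage=LazyMemmapStorage(10, scratch_dir="/tmp/rb_scratch"))
    # rb1.extend(...)  -> RuntimeError mentioning ``existsok``

    # Opting in with existsok=True reuses the existing files instead.
    rb2 = ReplayBuffer(
        storage=LazyMemmapStorage(10, scratch_dir="/tmp/rb_scratch", existsok=True)
    )
    rb2.extend(TensorDict(a=torch.randn(3), batch_size=[3]))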
14 changes: 14 additions & 0 deletions test/test_rb.py
@@ -546,6 +546,20 @@ def test_errors(self, storage_type):
         ):
             storage_type(data, max_size=4)
 
+    def test_existsok_lazymemmap(self, tmpdir):
+        storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storage0)
+        rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+        storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+        rb = ReplayBuffer(storage=storage1)
+        with pytest.raises(RuntimeError, match="existsok"):
+            rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+        storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True)
+        rb = ReplayBuffer(storage=storage2)
+        rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
     @pytest.mark.parametrize(
         "data_type", ["tensor", "tensordict", "tensorclass", "pytree"]
     )
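
The new test can be run on its own with:

    pytest test/test_rb.py -k test_existsok_lazymemmap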
9 changes: 8 additions & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -923,6 +923,8 @@ class LazyMemmapStorage(LazyTensorStorage):
     Args:
         max_size (int): size of the storage, i.e. maximum number of elements stored
             in the buffer.
+
+    Keyword Args:
         scratch_dir (str or path): directory where memmap-tensors will be written.
         device (torch.device, optional): device where the sampled tensors will be
             stored and sent. Default is :obj:`torch.device("cpu")`.
@@ -933,6 +935,9 @@ class LazyMemmapStorage(LazyTensorStorage):
             measuring the storage size. For instance, a storage of shape ``[3, 4]``
             has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
             Defaults to ``1``.
+        existsok (bool, optional): if ``False`` (the default), an error is raised
+            whenever one of the memmap files already exists on disk. If ``True``,
+            the existing tensors are opened as-is and not overwritten.
 
     .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is
         already stored to avoid executing long copies of data that is already stored on disk.
@@ -1009,10 +1014,12 @@ def __init__(
         scratch_dir=None,
         device: torch.device = "cpu",
         ndim: int = 1,
+        existsok: bool = False,
     ):
         super().__init__(max_size, ndim=ndim)
         self.initialized = False
         self.scratch_dir = None
+        self.existsok = existsok
         if scratch_dir is not None:
             self.scratch_dir = str(scratch_dir)
             if self.scratch_dir[-1] != "/":
@@ -1108,7 +1115,7 @@ def max_size_along_dim0(data_shape):
         if is_tensor_collection(data):
             out = data.clone().to(self.device)
             out = out.expand(max_size_along_dim0(data.shape))
-            out = out.memmap_like(prefix=self.scratch_dir)
+            out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             for key, tensor in sorted(
                 out.items(include_nested=True, leaves_only=True), key=str
             ):
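
The storage simply forwards the flag to ``TensorDict.memmap_like``, which performs the actual on-disk check. A standalone sketch of that lower-level call, assuming a ``tensordict`` version whose ``memmap_like`` accepts ``existsok`` (as the diff above requires); the ``/tmp/memmap_demo`` prefix is illustrative:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.randn(10, 3)}, batch_size=[10])

    # First call writes one memmap file per leaf tensor under the prefix.
    td.memmap_like(prefix="/tmp/memmap_demo")

    # A second call on the same prefix fails because the files already exist:
    # td.memmap_like(prefix="/tmp/memmap_demo")

    # With existsok=True the existing files are opened as-is, not overwritten.
    td.memmap_like(prefix="/tmp/memmap_demo", existsok=True)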
