Skip to content

Commit

Permalink
[BugFix] Track sub-tds in memmap (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 25, 2024
1 parent 2dc0285 commit 62348af
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,8 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
dest._is_shared = False # since they are mutually exclusive

for key, value in self.items():
if _is_tensor_collection(value.__class__):
type_value = type(value)
if _is_tensor_collection(type_value):
dest._tensordict[key] = value._memmap_(
prefix=prefix / key if prefix is not None else None,
copy_existing=copy_existing,
Expand All @@ -1894,6 +1895,10 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
like=like,
share_non_tensor=share_non_tensor,
)
if prefix is not None:
metadata[key] = {
"type": type_value.__name__,
}
continue
else:
# user did specify location and memmap is in wrong place, so we copy
Expand Down Expand Up @@ -1944,10 +1949,15 @@ def _load_memmap(cls, prefix: str, metadata: dict) -> T:

out = cls({}, batch_size=metadata.pop("shape"), device=metadata.pop("device"))

paths = set()
for key, entry_metadata in metadata.items():
if not isinstance(entry_metadata, dict):
# there can be other metadata
continue
type_value = entry_metadata.get("type", None)
if type_value is not None:
paths.add(key)
continue
dtype = entry_metadata.get("dtype", None)
shape = entry_metadata.get("shape", None)
if (
Expand All @@ -1970,7 +1980,7 @@ def _load_memmap(cls, prefix: str, metadata: dict) -> T:
)
# iterate over folders and load them
for path in prefix.iterdir():
if path.is_dir():
if path.is_dir() and path.parts[-1] in paths:
key = path.parts[len(prefix.parts) :]
out.set(key, TensorDict.load_memmap(path))
return out
Expand Down
2 changes: 2 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3101,6 +3101,8 @@ def test_memmap_(self, td_name, device, use_dir, tmpdir, num_threads):
)
assert td.is_memmap(), (td, td._is_memmap)
if use_dir:
# This would fail if we were not filtering out unregistered sub-folders
os.mkdir(Path(tmpdir) / "some_other_path")
assert_allclose_td(TensorDict.load_memmap(tmpdir), td)

@pytest.mark.parametrize("copy_existing", [False, True])
Expand Down

0 comments on commit 62348af

Please sign in to comment.