Commit b682d11: amend

vmoens committed Jan 24, 2024
1 parent 6921523 commit b682d11
Showing 3 changed files with 222 additions and 85 deletions.
4 changes: 2 additions & 2 deletions test/test_transforms.py
@@ -119,7 +119,7 @@
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import check_env_specs, step_mdp
from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
-
+from tensordict.utils import assert_allclose_td
TIMEOUT = 100.0

_has_gymnasium = importlib.util.find_spec("gymnasium") is not None
@@ -797,7 +797,7 @@ def test_transform_model(self, dim, N, padding):
        v1 = model(tdbase0)
        v2 = model(tdbase0_copy)
        # check that swapping dims and names leads to same result
-        assert (v1 == v2.transpose(0, 1)).all()
+        assert_allclose_td(v1, v2.transpose(0, 1))

    @pytest.mark.parametrize("dim", [-2, -1])
    @pytest.mark.parametrize("N", [3, 4])
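The updated assertion compares entries with allclose semantics rather than bitwise equality. A minimal sketch of the difference, assuming tensordict is installed (values illustrative, not from the test suite):

import torch
from tensordict import TensorDict
from tensordict.utils import assert_allclose_td

a = TensorDict({"x": torch.ones(3)}, [])
b = TensorDict({"x": torch.ones(3) + 1e-6}, [])

assert_allclose_td(a, b)  # passes: entries match within tolerance
assert not (a["x"] == b["x"]).all()  # strict equality would have failed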
7 changes: 4 additions & 3 deletions torchrl/_utils.py
@@ -271,9 +271,10 @@ def __init__(
        implement_for._setters.append(self)

    @staticmethod
-    def check_version(version, from_version, to_version):
-        return (from_version is None or parse(version) >= parse(from_version)) and (
-            to_version is None or parse(version) < parse(to_version)
+    def check_version(version: str, from_version: str | None, to_version: str | None):
+        version = parse(".".join([str(v) for v in parse(version).release]))
+        return (from_version is None or version >= parse(from_version)) and (
+            to_version is None or version < parse(to_version)
        )

@staticmethod
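The reworked check_version truncates a version to its release tuple before comparing, so pre-release and dev builds (e.g. nightlies) satisfy the same bounds as the corresponding release. A minimal sketch of the effect, using packaging.version.parse (version strings illustrative):

from packaging.version import parse

raw = parse("2.3.0.dev20240124")  # e.g. a torch nightly
print(raw >= parse("2.3"))        # False: dev releases sort before 2.3
normalized = parse(".".join(str(v) for v in raw.release))
print(normalized)                 # 2.3.0
print(normalized >= parse("2.3")) # True: the nightly now passes the gate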
296 changes: 216 additions & 80 deletions torchrl/data/replay_buffers/storages.py
@@ -13,7 +13,7 @@
from copy import copy
from multiprocessing.context import get_spawning_popen
from pathlib import Path
-from typing import Any, Dict, List, Sequence, Union
+from typing import Any, Dict, List, Sequence, Union, Callable

import numpy as np
import tensordict
@@ -23,17 +23,17 @@
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import _STRDTYPE2DTYPE, expand_right
from torch import multiprocessing as mp
-from torch.utils._pytree import (
-    MappingKey,
-    SequenceKey,
-    tree_flatten,
-    tree_map,
-    tree_map_with_path,
-)
from packaging import version

from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE
from torchrl.data.replay_buffers.utils import INT_CLASSES

+from torch.utils._pytree import (
+    tree_flatten,
+    tree_map,
+    tree_unflatten,
+    LeafSpec,
+)

try:
    from torchsnapshot.serialization import tensor_from_memoryview

@@ -355,46 +355,7 @@ def dumps(self, path):
            )
            is_pytree = False
        else:
-
-            def save_tensor(
-                tensor_path: tuple, tensor: torch.Tensor, metadata=metadata
-            ):
-                tensor_path = _path2str(tensor_path)
-                if "." in tensor_path:
-                    tensor_path.replace(".", "_<dot>_")
-                total_tensor_path = path / (tensor_path + ".memmap")
-                if os.path.exists(total_tensor_path):
-                    MemoryMappedTensor.from_filename(
-                        shape=tensor.shape,
-                        filename=total_tensor_path,
-                        dtype=tensor.dtype,
-                    ).copy_(tensor)
-                else:
-                    os.makedirs(total_tensor_path.parent, exist_ok=True)
-                    MemoryMappedTensor.from_tensor(
-                        tensor,
-                        filename=total_tensor_path,
-                        copy_existing=True,
-                        copy_data=True,
-                    )
-                t = MemoryMappedTensor.from_filename(
-                    filename=total_tensor_path,
-                    dtype=tensor.dtype,
-                    shape=tensor.shape,
-                )
-                assert (t == tensor).all()
-                key = tensor_path.replace("/", ".")
-                if key in metadata:
-                    raise KeyError(
-                        "At least two values have conflicting representations in "
-                        f"the data structure to be serialized: {key}."
-                    )
-                metadata[key] = {
-                    "dtype": str(tensor.dtype),
-                    "shape": list(tensor.shape),
-                }
-
-            tree_map_with_path(save_tensor, self._storage)
+            _save_pytree(self._storage, metadata, path)
            is_pytree = True

        with open(path / "storage_metadata.json", "w") as file:
@@ -919,38 +880,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
)
else:
# If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
def save_tensor(tensor_path: tuple, tensor: torch.Tensor):
tensor_path = _path2str(tensor_path)
if "." in tensor_path:
tensor_path.replace(".", "_<dot>_")
if self.scratch_dir is not None:
total_tensor_path = Path(self.scratch_dir) / (
tensor_path + ".memmap"
)
if os.path.exists(total_tensor_path):
raise RuntimeError(
f"The storage of tensor {total_tensor_path} already exists. "
f"To load an existing replay buffer, use storage.loads. "
f"Choose a different path to store your buffer or delete the existing files."
)
os.makedirs(total_tensor_path.parent, exist_ok=True)
else:
total_tensor_path = None
out = MemoryMappedTensor.empty(
shape=(self.max_size, *tensor.shape),
filename=total_tensor_path,
dtype=tensor.dtype,
)
if VERBOSE:
filesize = os.path.getsize(out.filename) / 1024 / 1024
logging.info(
f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
)
return out

out = tree_map_with_path(save_tensor, data)
out = _init_pytree(self.scratch_dir, self.max_size, data)
self._storage = out
self.initialized = True

@@ -1223,7 +1153,13 @@ def _make_empty_memmap(shape, dtype, path):
    return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path)


+@implement_for("torch", "2.3", None)
def _path2str(path, default_name=None):
+    from torch.utils._pytree import (
+        MappingKey,
+        SequenceKey,
+    )
+
    if default_name is None:
        default_name = SINGLE_TENSOR_BUFFER_NAME
    if not path:
@@ -1243,3 +1179,203 @@
        return result
    if isinstance(path, SequenceKey):
        return str(path.idx)
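A minimal sketch of the key paths _path2str consumes (assumes torch>=2.3; data illustrative): tree_map_with_path hands each leaf a tuple of MappingKey/SequenceKey entries describing its position in the pytree.

import torch
from torch.utils._pytree import tree_map_with_path

data = {"obs": torch.zeros(3), "extra": [torch.ones(2)]}

def show(keypath, leaf):
    print(keypath, "->", tuple(leaf.shape))
    return leaf

tree_map_with_path(show, data)
# (MappingKey(key='obs'),) -> (3,)
# (MappingKey(key='extra'), SequenceKey(idx=0)) -> (2,)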

@implement_for("torch", None, "2.3")
def _path2str(path, default_name=None):
raise RuntimeError
def get_paths(spec, cumulpath=""):
if isinstance(spec, LeafSpec):
yield cumulpath
contexts = spec.context
children_specs = spec.children_specs
if contexts is None:
contexts = range(len(children_specs))
for context, spec in zip(contexts, children_specs):
cpath = ".".join(
(cumulpath, str(context))
) if cumulpath else context
yield from get_paths(spec, cpath)


@implement_for("torch", "2.3", None)
def _save_pytree(_storage, metadata, path):
from torch.utils._pytree import tree_map_with_path

def save_tensor(
tensor_path: tuple, tensor: torch.Tensor, metadata=metadata
):
tensor_path = _path2str(tensor_path)
if "." in tensor_path:
tensor_path.replace(".", "_<dot>_")
total_tensor_path = path / (tensor_path + ".memmap")
if os.path.exists(total_tensor_path):
MemoryMappedTensor.from_filename(
shape=tensor.shape,
filename=total_tensor_path,
dtype=tensor.dtype,
).copy_(tensor)
else:
os.makedirs(total_tensor_path.parent, exist_ok=True)
MemoryMappedTensor.from_tensor(
tensor,
filename=total_tensor_path,
copy_existing=True,
copy_data=True,
)
t = MemoryMappedTensor.from_filename(
filename=total_tensor_path,
dtype=tensor.dtype,
shape=tensor.shape,
)
assert (t == tensor).all()
key = tensor_path.replace("/", ".")
if key in metadata:
raise KeyError(
"At least two values have conflicting representations in "
f"the data structure to be serialized: {key}."
)
metadata[key] = {
"dtype": str(tensor.dtype),
"shape": list(tensor.shape),
}

tree_map_with_path(save_tensor, _storage)
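For reference, a sketch of the metadata mapping save_tensor accumulates (leaf names hypothetical): one entry per leaf, with "/" separators in the path rewritten to "." for the key.

# for a pytree {"obs": float32 of shape (10, 3), "extra": [int64 of shape (10,)]}:
metadata = {
    "obs": {"dtype": "torch.float32", "shape": [10, 3]},
    "extra.0": {"dtype": "torch.int64", "shape": [10]},
}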

def get_paths(spec, cumulpath=""):
if isinstance(spec, LeafSpec):
yield cumulpath
contexts = spec.context
children_specs = spec.children_specs
if contexts is None:
contexts = range(len(children_specs))
for context, spec in zip(contexts, children_specs):
cpath = "/".join(
(cumulpath, str(context))
) if cumulpath else context
yield from get_paths(spec, cpath)
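A minimal usage sketch (data illustrative): get_paths yields one "/"-joined path per leaf, in the same order tree_flatten returns the leaves.

import torch
from torch.utils._pytree import tree_flatten

data = {"a": torch.zeros(1), "b": [torch.zeros(2), torch.zeros(3)]}
leaves, spec = tree_flatten(data)
print(list(get_paths(spec)))  # ['a', 'b/0', 'b/1']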

@implement_for("torch", None, "2.3")
def _save_pytree(_storage, metadata, path):

flat_storage, storage_specs = tree_flatten(_storage)
storage_paths = get_paths(storage_specs)

def save_tensor(
tensor_path: str, tensor: torch.Tensor, metadata=metadata
):
if "." in tensor_path:
tensor_path.replace(".", "_<dot>_")
total_tensor_path = path / (tensor_path + ".memmap")
if os.path.exists(total_tensor_path):
MemoryMappedTensor.from_filename(
shape=tensor.shape,
filename=total_tensor_path,
dtype=tensor.dtype,
).copy_(tensor)
else:
os.makedirs(total_tensor_path.parent, exist_ok=True)
MemoryMappedTensor.from_tensor(
tensor,
filename=total_tensor_path,
copy_existing=True,
copy_data=True,
)
t = MemoryMappedTensor.from_filename(
filename=total_tensor_path,
dtype=tensor.dtype,
shape=tensor.shape,
)
assert (t == tensor).all()
key = tensor_path.replace("/", ".")
if key in metadata:
raise KeyError(
"At least two values have conflicting representations in "
f"the data structure to be serialized: {key}."
)
metadata[key] = {
"dtype": str(tensor.dtype),
"shape": list(tensor.shape),
}

for tensor, path in flat_storage, storage_paths:
save_tensor(path, tensor)

@implement_for("torch", "2.3", None)
def _init_pytree(scratch_dir, max_size, data):
from torch.utils._pytree import tree_map_with_path

# If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
def save_tensor(tensor_path: tuple, tensor: torch.Tensor):
tensor_path = _path2str(tensor_path)
if "." in tensor_path:
tensor_path.replace(".", "_<dot>_")
if scratch_dir is not None:
total_tensor_path = Path(scratch_dir) / (
tensor_path + ".memmap"
)
if os.path.exists(total_tensor_path):
raise RuntimeError(
f"The storage of tensor {total_tensor_path} already exists. "
f"To load an existing replay buffer, use storage.loads. "
f"Choose a different path to store your buffer or delete the existing files."
)
os.makedirs(total_tensor_path.parent, exist_ok=True)
else:
total_tensor_path = None
out = MemoryMappedTensor.empty(
shape=(max_size, *tensor.shape),
filename=total_tensor_path,
dtype=tensor.dtype,
)
if VERBOSE:
filesize = os.path.getsize(out.filename) / 1024 / 1024
logging.info(
f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
)
return out

out = tree_map_with_path(save_tensor, data)
return out
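
A minimal sketch of how _init_pytree is consumed by _init above (shapes hypothetical; scratch_dir=None backs the buffers with temporary files):

import torch

sample = {"obs": torch.zeros(4, 84, 84), "action": torch.zeros(6)}
storage = _init_pytree(None, 1000, sample)
# the result mirrors the input structure, with a leading buffer dimension:
# storage["obs"].shape == (1000, 4, 84, 84)
# storage["action"].shape == (1000, 6)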


@implement_for("torch", None, "2.3")
def _init_pytree(scratch_dir, max_size, data):

flat_data, data_specs = tree_flatten(data)
data_paths = get_paths(data_specs)

# If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
def save_tensor(tensor_path: str, tensor: torch.Tensor):
if "." in tensor_path:
tensor_path.replace(".", "_<dot>_")
if scratch_dir is not None:
total_tensor_path = Path(scratch_dir) / (
tensor_path + ".memmap"
)
if os.path.exists(total_tensor_path):
raise RuntimeError(
f"The storage of tensor {total_tensor_path} already exists. "
f"To load an existing replay buffer, use storage.loads. "
f"Choose a different path to store your buffer or delete the existing files."
)
os.makedirs(total_tensor_path.parent, exist_ok=True)
else:
total_tensor_path = None
out = MemoryMappedTensor.empty(
shape=(max_size, *tensor.shape),
filename=total_tensor_path,
dtype=tensor.dtype,
)
if VERBOSE:
filesize = os.path.getsize(out.filename) / 1024 / 1024
logging.info(
f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
)
return out
out = []
for tensor, path in flat_data, data_paths:
out.append(save_tensor(path, tensor))

return tree_flatten(out, data_specs)
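The torch<2.3 branch relies on the tree_flatten/tree_unflatten round-trip: leaves are replaced in flattening order and the original structure is rebuilt around them. A minimal sketch:

from torch.utils._pytree import tree_flatten, tree_unflatten

leaves, spec = tree_flatten({"a": 1, "b": [2, 3]})
print(leaves)  # [1, 2, 3]
print(tree_unflatten([x * 10 for x in leaves], spec))  # {'a': 10, 'b': [20, 30]}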
