From 059f53992d424a6324f2b4aa04c3d4e65dab5040 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Mar 2024 15:08:46 +0000 Subject: [PATCH] [Performance] Faster update_ (#705) --- tensordict/_lazy.py | 3 + tensordict/_td.py | 7 +- tensordict/base.py | 217 +++++++++++---- tensordict/tensorclass.py | 550 ++++++++++++++++++++++++++++++++++++-- test/test_tensordict.py | 242 ++++++++++++++++- 5 files changed, 936 insertions(+), 83 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 68cb14e45..78abd97e5 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1413,6 +1413,7 @@ def _apply_nest( nested_keys: bool = False, prefix: tuple = (), filter_empty: bool | None = None, + is_leaf: Callable | None = None, **constructor_kwargs, ) -> T | None: if inplace and any( @@ -1438,6 +1439,7 @@ def _apply_nest( prefix=prefix, inplace=inplace, filter_empty=filter_empty, + is_leaf=is_leaf, **constructor_kwargs, ) @@ -1455,6 +1457,7 @@ def _apply_nest( prefix=prefix, # + (i,), inplace=inplace, filter_empty=filter_empty, + is_leaf=is_leaf, ) for i, (td, *oth) in enumerate(zip(self.tensordicts, *others)) ] diff --git a/tensordict/_td.py b/tensordict/_td.py index 87fa115bf..9430d2015 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -661,6 +661,7 @@ def _apply_nest( nested_keys: bool = False, prefix: tuple = (), filter_empty: bool | None = None, + is_leaf: Callable = None, **constructor_kwargs, ) -> T | None: if inplace: @@ -696,10 +697,13 @@ def make_result(): is_locked = False any_set = False + if is_leaf is None: + is_leaf = _default_is_leaf + for key, item in self.items(): if ( not call_on_nested - and _is_tensor_collection(item.__class__) + and not is_leaf(item.__class__) # and not is_non_tensor(item) ): if default is not NO_DEFAULT: @@ -725,6 +729,7 @@ def make_result(): default=default, prefix=prefix + (key,), filter_empty=filter_empty, + is_leaf=is_leaf, **constructor_kwargs, ) else: diff --git a/tensordict/base.py b/tensordict/base.py index 9c97098a7..6eaf4abf9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -18,6 +18,7 @@ from concurrent.futures import ThreadPoolExecutor from copy import copy +from functools import wraps from pathlib import Path from textwrap import indent from typing import ( @@ -38,6 +39,7 @@ import numpy as np import torch from tensordict.utils import ( + _CloudpickleWrapper, _GENERIC_NESTED_ERR, _get_shape_from_args, _is_tensorclass, @@ -58,7 +60,7 @@ IndexType, infer_size_impl, int_generator, - KeyedJaggedTensor, + is_non_tensor, lazy_legacy, lock_blocked, NestedKey, @@ -104,12 +106,13 @@ def __bool__(self): class TensorDictBase(MutableMapping): """TensorDictBase is an abstract parent class for TensorDicts, a torch.Tensor data container.""" - _safe = False - _lazy = False - _inplace_set = False - is_meta = False - _is_locked = False - _cache = None + _safe: bool = False + _lazy: bool = False + _inplace_set: bool = False + is_meta: bool = False + _is_locked: bool = False + _cache: bool = None + _non_tensor: bool = False def __bool__(self) -> bool: raise RuntimeError("Converting a tensordict to boolean value is not permitted") @@ -240,10 +243,11 @@ def __getitem__(self, index: IndexType) -> T: idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: result = self._get_tuple(idx_unravel, NO_DEFAULT) - from .tensorclass import NonTensorData - - if isinstance(result, NonTensorData): - return result.data + if is_non_tensor(result): + result_data = getattr(result, "data", NO_DEFAULT) + if result_data is NO_DEFAULT: + return 
result.tolist() + return result_data return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -377,6 +381,7 @@ def from_module( """Copies the params and buffers of a module in a tensordict. Args: + module (nn.Module): the module to get the parameters from. as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams` instance will be returned which can be used to store parameters within a :class:`torch.nn.Module`. Defaults to ``False``. @@ -386,7 +391,7 @@ def from_module( module will be used and unflattened into a TensorDict with the tree structure of the model. Defaults to ``False``. .. note:: - This is particularily useful when state-dict hooks have to be + This is particularly useful when state-dict hooks have to be used. Examples: @@ -847,7 +852,7 @@ def unsqueeze(self, *args, **kwargs): version of the unsqueezed tensordict. Up until v0.3 included, a lazy unsqueezed tensordict was returned. From v0.4 onward, a dense unsqueeze will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -936,7 +941,7 @@ def squeeze(self, *args, **kwargs): version of the squeezed tensordict. Up until v0.3 included, a lazy squeezed tensordict was returned. From v0.4 onward, a dense squeeze will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1148,7 +1153,7 @@ def view( version of the viewed tensordict. Up until v0.3 included, a lazy view of the tensordict was returned. From v0.4 onward, a proper view will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1217,7 +1222,7 @@ def transpose(self, dim0, dim1): version of the transposed tensordict. Up until v0.3 included, a lazy transpose of the tensordict was returned. From v0.4 onward, a proper transpose will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. 
- set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1332,7 +1337,7 @@ def permute(self, *args, **kwargs): version of the permuted tensordict. Up until v0.3 included, a lazy permute of the tensordict was returned. From v0.4 onward, a proper permute will be returned. To silence this warning, choose one of the following options: -- set the LAZY_LEGACY_OP environment variable to 'False' (recommended) or 'True' depending on +- set the LAZY_LEGACY_OP environment variable to 'False' (recommended, default) or 'True' depending on the behaviour you want to use. Another way to achieve this is to call `tensordict.set_lazy_legacy(False).set()` at the beginning of your script. - set the decorator/context manager `tensordict.set_lazy_legacy(False)` (recommended) around @@ -1652,6 +1657,14 @@ def cuda(self, device: int = None) -> T: return self.to(torch.device("cuda")) return self.to(f"cuda:{device}") + @property + def is_cuda(self): + return self.device is not None and self.device.type == "cuda" + + @property + def is_cpu(self): + return self.device is not None and self.device.type == "cpu" + # Serialization functionality def state_dict( self, @@ -2146,7 +2159,8 @@ def load_metadata(filepath): return other_cls._load_memmap(prefix, metadata) else: raise RuntimeError( - f"Could not find name {type_name} in {tensordict.base._ACCEPTED_CLASSES}. Did you call _register_tensor_class(cls) on {type_name}?" + f"Could not find name {type_name} in {tensordict.base._ACCEPTED_CLASSES}. " + f"Did you call _register_tensor_class(cls) on {type_name}?" ) return cls._load_memmap(prefix, metadata) @@ -2313,23 +2327,24 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from tensordict.tensorclass import NonTensorData - if isinstance(value, NonTensorData): - return value.data + if is_non_tensor(value): + data = getattr(value, "data", None) + if data is None: + return value.tolist() + return data return value def filter_non_tensor_data(self) -> T: """Filters out all non-tensor-data.""" - from tensordict.tensorclass import NonTensorData def _filter(x): - if not isinstance(x, NonTensorData): + if not is_non_tensor(x): if is_tensor_collection(x): return x.filter_non_tensor_data() return x - return self._apply_nest(_filter, call_on_nested=True) + return self._apply_nest(_filter, call_on_nested=True, filter_empty=False) def _convert_inplace(self, inplace, key): if inplace is not False: @@ -2674,41 +2689,41 @@ def update_( if len(keys_to_update) == 0: return self keys_to_update = [_unravel_key_to_tuple(key) for key in keys_to_update] - if keys_to_update: + named = True def inplace_update(name, dest, source): if source is None: - return dest + return None name = _unravel_key_to_tuple(name) for key in keys_to_update: if key == name[: len(key)]: - return dest.copy_(source, non_blocking=True) - else: - return dest + dest.copy_(source, non_blocking=True) else: named = False def inplace_update(dest, source): if source is None: - return dest - return dest.copy_(source, non_blocking=True) + return None + dest.copy_(source, non_blocking=True) - if not is_tensor_collection(input_dict_or_td): + if not _is_tensor_collection(type(input_dict_or_td)): from tensordict import TensorDict input_dict_or_td = TensorDict.from_dict( input_dict_or_td, batch_dims=self.batch_dims ) - return self._apply_nest( + self._apply_nest( inplace_update, input_dict_or_td, nested_keys=True, 
            default=None,
-            inplace=True,
+            filter_empty=True,
            named=named,
+            is_leaf=_is_leaf_nontensor,
        )
+        return self

    def update_at_(
        self,
@@ -3737,6 +3752,8 @@ def apply_(self, fn: Callable, *others, **kwargs) -> T:
            *others (sequence of TensorDictBase, optional): the other
                tensordicts to be used.

+        Keyword Args: See :meth:`~.apply`.
+
        Returns:
            self or a copy of self with the function applied

@@ -3752,8 +3769,9 @@ def apply(
        names: Sequence[str] | None = None,
        inplace: bool = False,
        default: Any = NO_DEFAULT,
+        filter_empty: bool | None = None,
        **constructor_kwargs,
-    ) -> T:
+    ) -> T | None:
        """Applies a callable to all values stored in the tensordict and sets them in a new tensordict.

        The callable signature must be ``Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``.

@@ -3781,6 +3799,12 @@ def apply(
            default (Any, optional): default value for missing entries in the
                other tensordicts. If not provided, missing entries will
                raise a `KeyError`.
+            filter_empty (bool, optional): if ``True``, empty tensordicts will be
+                filtered out. This also comes with a lower computational cost, as
+                empty data structures won't be created and destroyed. Non-tensor data
+                is considered a leaf and will therefore be kept in the tensordict even
+                if left untouched by the function.
+                Defaults to ``False`` for backward compatibility.
            **constructor_kwargs: additional keyword arguments to be passed to the
                TensorDict constructor.

@@ -3838,6 +3862,7 @@ def apply(
            inplace=inplace,
            checked=False,
            default=default,
+            filter_empty=filter_empty,
            **constructor_kwargs,
        )

@@ -3851,8 +3876,9 @@ def named_apply(
        names: Sequence[str] | None = None,
        inplace: bool = False,
        default: Any = NO_DEFAULT,
+        filter_empty: bool | None = None,
        **constructor_kwargs,
-    ) -> T:
+    ) -> T | None:
        """Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new tensordict.

        The callable signature must be ``Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``.

@@ -3882,6 +3908,10 @@ def named_apply(
            default (Any, optional): default value for missing entries in the
                other tensordicts. If not provided, missing entries will
                raise a `KeyError`.
+            filter_empty (bool, optional): if ``True``, empty tensordicts will be
+                filtered out. This also comes with a lower computational cost, as
+                empty data structures won't be created and destroyed. Defaults to
+                ``False`` for backward compatibility.
            **constructor_kwargs: additional keyword arguments to be passed to the
                TensorDict constructor.

@@ -3966,6 +3996,7 @@ def named_apply(
            default=default,
            named=True,
            nested_keys=nested_keys,
+            filter_empty=filter_empty,
            **constructor_kwargs,
        )

@@ -3984,8 +4015,10 @@ def _apply_nest(
        named: bool = False,
        nested_keys: bool = False,
        prefix: tuple = (),
+        filter_empty: bool | None = None,
+        is_leaf: Callable = None,
        **constructor_kwargs,
-    ) -> T:
+    ) -> T | None:
        ...

    def _fast_apply(
@@ -4000,8 +4033,12 @@ def _fast_apply(
        default: Any = NO_DEFAULT,
        named: bool = False,
        nested_keys: bool = False,
+        # filter_empty must be False because _fast_apply is used for all sorts of
+        # ops (expand and the like), and non-tensor data would disappear if we
+        # used True by default.
+        filter_empty: bool | None = False,
+        is_leaf: Callable = None,
        **constructor_kwargs,
-    ) -> T:
+    ) -> T | None:
        """A faster apply method.

        This method does not run any check after performing the func.
This @@ -4021,21 +4058,27 @@ def _fast_apply( named=named, default=default, nested_keys=nested_keys, + filter_empty=filter_empty, + is_leaf=is_leaf, **constructor_kwargs, ) def map( self, - fn: Callable, + fn: Callable[[TensorDictBase], TensorDictBase | None], dim: int = 0, num_workers: int | None = None, + *, + out: TensorDictBase = None, chunksize: int | None = None, num_chunks: int | None = None, pool: mp.Pool | None = None, generator: torch.Generator | None = None, max_tasks_per_child: int | None = None, worker_threads: int = 1, + index_with_generator: bool = False, pbar: bool = False, + mp_start_method: str | None = None, ): """Maps a function to splits of the tensordict across one dimension. @@ -4055,6 +4098,15 @@ def map( num_workers (int, optional): the number of workers. Exclusive with ``pool``. If none is provided, the number of workers will be set to the number of cpus available. + + Keyword Args: + out (TensorDictBase, optional): an optional container for the output. + Its batch-size along the ``dim`` provided must match ``self.ndim``. + If it is shared or memmap (:meth:`~.is_shared` or :meth:`~.is_memmap` + returns ``True``) it will be populated within the remote processes, + avoiding data inward transfers. Otherwise, the data from the ``self`` + slice will be sent to the process, collected on the current process + and written inplace into ``out``. chunksize (int, optional): The size of each chunk of data. A ``chunksize`` of 0 will unbind the tensordict along the desired dimension and restack it after the function is applied, @@ -4102,8 +4154,20 @@ def map( on the number of jobs. worker_threads (int, optional): the number of threads for the workers. Defaults to ``1``. + index_with_generator (bool, optional): if ``True``, the splitting / chunking + of the tensordict will be done during the query, sparing init time. + Note that :meth:`~.chunk` and :meth:`~.split` are much more + efficient than indexing (which is used within the generator) + so a gain of processing time at init time may have a negative + impact on the total runtime. Defaults to ``False``. pbar (bool, optional): if ``True``, a progress bar will be displayed. Requires tqdm to be available. Defaults to ``False``. + mp_start_method (str, optional): the start method for multiprocessing. + If not provided, the default start method will be used. + Accepted strings are ``"fork"`` and ``"spawn"``. Keep in mind that + ``"cuda"`` tensors cannot be shared between processes with the + ``"fork"`` start method. This is without effect if the ``pool`` + is passed to the ``map`` method. 
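+
+        .. note::
+            A minimal sketch combining ``out`` with ``mp_start_method`` (the
+            worker ``double_a`` is hypothetical, and ``"fork"`` is POSIX-only
+            and, as noted above, unsafe with CUDA tensors):
+
+                >>> import torch
+                >>> from tensordict import TensorDict
+                >>> def double_a(td):
+                ...     td["a"] = td["a"] * 2
+                ...     return td
+                >>> if __name__ == "__main__":
+                ...     td = TensorDict({"a": torch.arange(16)}, [16])
+                ...     out = td.clone().memmap_()  # memory-mapped: workers write into it in place
+                ...     td.map(double_a, num_workers=2, out=out, mp_start_method="fork")
+                ...     assert (out["a"] == torch.arange(16) * 2).all()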
Examples: >>> import torch @@ -4137,11 +4201,15 @@ def map( seed = ( torch.empty((), dtype=torch.int64).random_(generator=generator).item() ) + if mp_start_method is not None: + ctx = mp.get_context(mp_start_method) + else: + ctx = mp.get_context() - queue = mp.Queue(maxsize=num_workers) + queue = ctx.Queue(maxsize=num_workers) for i in range(num_workers): queue.put(i) - with mp.Pool( + with ctx.Pool( processes=num_workers, initializer=_proc_init, initargs=(seed, queue, worker_threads), @@ -4154,6 +4222,7 @@ def map( num_chunks=num_chunks, pool=pool, pbar=pbar, + out=out, ) num_workers = pool._processes dim_orig = dim @@ -4162,18 +4231,66 @@ def map( if dim < 0 or dim >= self.ndim: raise ValueError(f"Got incompatible dimension {dim_orig}") - self_split = _split_tensordict(self, chunksize, num_chunks, num_workers, dim) + self_split = _split_tensordict( + self, + chunksize, + num_chunks, + num_workers, + dim, + use_generator=index_with_generator, + ) + if not index_with_generator: + length = len(self_split) + else: + length = None call_chunksize = 1 + + if out is not None and (out.is_shared() or out.is_memmap()): + + def wrap_fn_with_out(fn, out): + @wraps(fn) + def newfn(item_and_out): + item, out = item_and_out + result = fn(item) + out.update_(result) + return + + out_split = _split_tensordict( + out, + chunksize, + num_chunks, + num_workers, + dim, + use_generator=index_with_generator, + ) + return _CloudpickleWrapper(newfn), zip(self_split, out_split) + + fn, self_split = wrap_fn_with_out(fn, out) + out = None + imap = pool.imap(fn, self_split, call_chunksize) + if pbar and importlib.util.find_spec("tqdm", None) is not None: import tqdm - imap = tqdm.tqdm(imap, total=len(self_split)) + imap = tqdm.tqdm(imap, total=length) imaplist = [] + start = 0 + base_index = (slice(None),) * dim for item in imap: if item is not None: - imaplist.append(item) + if out is not None: + if chunksize == 0: + out[base_index + (start,)].update_(item) + start += 1 + else: + end = start + item.shape[dim] + chunk = base_index + (slice(start, end),) + out[chunk].update_(item) + start = end + else: + imaplist.append(item) del imap # support inplace modif @@ -4182,7 +4299,7 @@ def map( out = torch.stack(imaplist, dim) else: out = torch.cat(imaplist, dim) - return out + return out # Functorch compatibility @abc.abstractmethod @@ -4437,6 +4554,7 @@ def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) - device=None, is_shared=False) """ + keys = unravel_key_list(keys) result = self._select(*keys, inplace=inplace, strict=strict) if not inplace and (result._is_memmap or result._is_shared): result.lock_() @@ -4491,6 +4609,7 @@ def exclude(self, *keys: NestedKey, inplace: bool = False) -> T: is_shared=False) """ + keys = unravel_key_list(keys) result = self._exclude(*keys, inplace=inplace) if not inplace and (result._is_memmap or result._is_shared): result.lock_() @@ -5294,10 +5413,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import NonTensorData - - if issubclass(cls, KeyedJaggedTensor): - return False if _is_tensor_collection(cls): - return issubclass(cls, NonTensorData) + return cls._non_tensor + # if issubclass(cls, KeyedJaggedTensor): + # return False return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a07e8dae7..c597c9046 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -5,17 +5,21 @@ from __future__ import annotations +import ctypes + 
import dataclasses import functools import inspect import json +import multiprocessing.managers +import multiprocessing.sharedctypes import numbers import os import pickle import re import sys import warnings -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass from pathlib import Path from textwrap import indent @@ -24,22 +28,24 @@ import tensordict as tensordict_lib import torch +from tensordict import LazyStackedTensorDict from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor - from tensordict.utils import ( _get_repr, _is_json_serializable, _LOCK_ERROR, DeviceType, IndexType, + is_non_tensor, is_tensorclass, NestedKey, ) -from torch import Tensor +from torch import multiprocessing as mp, Tensor +from torch.multiprocessing import Manager T = TypeVar("T", bound=TensorDictBase) PY37 = sys.version_info < (3, 8) @@ -56,6 +62,9 @@ torch.full_like, torch.zeros_like, torch.ones_like, + torch.rand_like, + torch.empty_like, + torch.randn_like, torch.clone, torch.squeeze, torch.unsqueeze, @@ -158,11 +167,13 @@ def __torch_function__( ) return _from_tensordict_with_copy(tensorclass_instance, result) + _non_tensor = getattr(cls, "_non_tensor", False) + cls = dataclass(cls) expected_keys = set(cls.__dataclass_fields__) for attr in cls.__dataclass_fields__: - if attr in dir(TensorDict): + if attr in dir(TensorDict) and attr != "_non_tensor": raise AttributeError( f"Attribute name {attr} can't be used with @tensorclass" ) @@ -186,25 +197,41 @@ def __torch_function__( cls.__ne__ = _ne cls.__or__ = _or cls.__xor__ = _xor - cls.set = _set - cls.set_at_ = _set_at_ - cls.del_ = _del_ - cls.get = _get - cls.get_at = _get_at - cls.unbind = _unbind - cls.state_dict = _state_dict - cls.load_state_dict = _load_state_dict - cls._memmap_ = _memmap_ + if not hasattr(cls, "set"): + cls.set = _set + if not hasattr(cls, "set_at_"): + cls.set_at_ = _set_at_ + if not hasattr(cls, "del_"): + cls.del_ = _del_ + if not hasattr(cls, "get"): + cls.get = _get + if not hasattr(cls, "get_at"): + cls.get_at = _get_at + if not hasattr(cls, "unbind"): + cls.unbind = _unbind + if not hasattr(cls, "state_dict"): + cls.state_dict = _state_dict + if not hasattr(cls, "load_state_dict"): + cls.load_state_dict = _load_state_dict + if not hasattr(cls, "_memmap_"): + cls._memmap_ = _memmap_ + if not hasattr(cls, "share_memory_"): + cls.share_memory_ = _share_memory_ cls.__enter__ = __enter__ cls.__exit__ = __exit__ # Memmap - cls.memmap_like = TensorDictBase.memmap_like - cls.memmap_ = TensorDictBase.memmap_ - cls.memmap = TensorDictBase.memmap - cls.load_memmap = TensorDictBase.load_memmap - cls._load_memmap = classmethod(_load_memmap) + if not hasattr(cls, "memmap_like"): + cls.memmap_like = TensorDictBase.memmap_like + if not hasattr(cls, "memmap_"): + cls.memmap_ = TensorDictBase.memmap_ + if not hasattr(cls, "memmap"): + cls.memmap = TensorDictBase.memmap + if not hasattr(cls, "load_memmap"): + cls.load_memmap = TensorDictBase.load_memmap + if not hasattr(cls, "_load_memmap"): + cls._load_memmap = classmethod(_load_memmap) for attr in TensorDict.__dict__.keys(): func = getattr(TensorDict, attr) @@ -213,13 +240,21 @@ def 
__torch_function__( if issubclass(tdcls, TensorDictBase): # detects classmethods setattr(cls, attr, _wrap_classmethod(tdcls, cls, func)) - cls.to_tensordict = _to_tensordict - cls.device = property(_device, _device_setter) - cls.batch_size = property(_batch_size, _batch_size_setter) + if not hasattr(cls, "_to_tensordict"): + cls.to_tensordict = _to_tensordict + if not hasattr(cls, "device"): + cls.device = property(_device, _device_setter) + if not hasattr(cls, "batch_size"): + cls.batch_size = property(_batch_size, _batch_size_setter) + if not hasattr(cls, "names"): + cls.names = property(_names, _names_setter) cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" _register_tensor_class(cls) + + cls._non_tensor = _non_tensor + return cls @@ -382,6 +417,7 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): metadata = {"_type": str(cls)} to_pickle = {} for key, value in _non_tensordict.items(): + value = _from_shared_nontensor(value) if _is_json_serializable(value): metadata[key] = value else: @@ -414,6 +450,12 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): return result +# TODO: test +def _share_memory_(self): + self._tensordict.share_memory_() + return self + + def _load_memmap(cls, prefix: Path, metadata: dict): non_tensordict = copy(metadata) del non_tensordict["_type"] @@ -477,13 +519,19 @@ def wrapper(self, item: str) -> Any: and item in self.__dict__["_non_tensordict"] ): out = self._non_tensordict[item] + if ( + isinstance(self, NonTensorData) + and item == "data" + and (self._is_shared or self._is_memmap) + ): + return _from_shared_nontensor(out) return out return getattribute(self, item) return wrapper -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts") +SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -722,6 +770,9 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False): __dict__ = self.__dict__ if __dict__["_tensordict"].is_locked: raise RuntimeError(_LOCK_ERROR) + if key in ("batch_size", "names", "device"): + # handled by setattr + return expected_keys = self.__dataclass_fields__ if key not in expected_keys: raise AttributeError( @@ -838,6 +889,26 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 self._tensordict._batch_size_setter(new_size) +def _names(self) -> torch.Size: + """Retrieves the dim names for the tensor class. + + Returns: + names (list of str) + + """ + return self._tensordict.names + + +def _names_setter(self, names: str) -> None: # noqa: D417 + """Set the value of ``tensorclass.names``. + + Args: + names (sequence of str) + + """ + self._tensordict.names = names + + def _state_dict( self, destination=None, prefix="", keep_vars=False, flatten=False ) -> dict[str, Any]: @@ -1130,6 +1201,15 @@ def _unbind(self, dim: int): NONTENSOR_HANDLED_FUNCTIONS = [] +_MP_MANAGER = None + + +def _mp_manager(): + global _MP_MANAGER + if _MP_MANAGER is None: + _MP_MANAGER = Manager() + return _MP_MANAGER + @tensorclass class NonTensorData: @@ -1249,6 +1329,114 @@ class NonTensorData: meta.json >>> assert loaded.get_non_tensor("pickable").value == 10 + .. note:: __Preallocation__ is also possible with ``NonTensorData``. 
+      This class can handle conversion from ``NonTensorData`` to
+      ``NonTensorStack`` where appropriate, as the following example
+      demonstrates:
+
+        >>> td = TensorDict({"val": NonTensorData(data=0, batch_size=[10])}, [10])
+        >>> print(td)
+        TensorDict(
+            fields={
+                val: NonTensorData(
+                    data=0,
+                    _metadata=None,
+                    _non_tensor=True,
+                    batch_size=torch.Size([10]),
+                    device=None,
+                    is_shared=False)},
+            batch_size=torch.Size([10]),
+            device=None,
+            is_shared=False)
+        >>> print(td["val"])
+        0
+        >>> newdata = TensorDict({"val": NonTensorData(data=1, batch_size=[5])}, [5])
+        >>> td[1::2] = newdata
+        >>> print(td)
+        TensorDict(
+            fields={
+                val: NonTensorStack(
+                    [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+                    batch_size=torch.Size([10]),
+                    device=None)},
+            batch_size=torch.Size([10]),
+            device=None,
+            is_shared=False)
+        >>> print(td["val"])  # the stack is automatically converted to a list
+        [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+
+      If a single value is stored, the ``NonTensorData`` container is kept and
+      retrieving the value simply returns it. If a ``NonTensorStack``
+      is used, ``__getitem__`` will return the list of values instead.
+      This makes the two operations not exactly interchangeable. The reason
+      for this inconsistency is that a single ``NonTensorData`` with a non-empty
+      batch-size is intended to be used as a metadata carrier for bigger
+      tensordicts, whereas ``NonTensorStack`` usage is aimed at allocating
+      one metadata atom to each corresponding batch element.
+
+    .. note::
+      ``NonTensorData`` can be shared between processes. In fact, both
+      :meth:`~tensordict.TensorDict.memmap_` (and the like) and
+      :meth:`~tensordict.TensorDict.share_memory_` will produce sharable
+      instances.
+
+      Valid methods to write data are :meth:`~tensordict.TensorDictBase.update`
+      with the `inplace=True` flag, :meth:`~tensordict.TensorDictBase.update_`
+      and :meth:`~tensordict.TensorDictBase.update_at_`.
+
+        >>> if __name__ == "__main__":
+        ...     td = TensorDict({"val": NonTensorData(data=0, batch_size=[])}, [])
+        ...     td.share_memory_()
+        ...     td.update_(TensorDict({"val": NonTensorData(data=1, batch_size=[])}, []))  # works
+        ...     td.update(TensorDict({"val": NonTensorData(data=1, batch_size=[])}, []), inplace=True)  # works
+        ...     td["val"] = 1  # breaks
+
+      A shared ``NonTensorData`` is writable whenever its content is a ``str``,
+      ``int``, ``float``, ``bool``, ``dict`` or ``list`` instance. Other types
+      (e.g., dataclasses) will not raise an exception during the call to
+      ``memmap_`` or ``share_memory_``, but they will cause the code to break
+      when the data is overwritten.
+
+        >>> @dataclass
+        ... class MyClass:
+        ...     string: str
+        ...
+        >>> if __name__ == "__main__":
+        ...     td = TensorDict({"val": MyClass("a string!")}, [])
+        ...     td.share_memory_()  # works and can be shared between processes
+        ...     td.update_(TensorDict({"val": MyClass("another string!")}, []))  # breaks!
+
+      :class:`~tensordict.tensorclass.NonTensorStack` instances are also sharable
+      in a similar way. Crucially, preallocation must be properly handled for
+      this to work.
+
+        >>> td = TensorDict({"val": NonTensorData(data=0, batch_size=[10])}, [10])
+        >>> newdata = TensorDict({"val": NonTensorData(data=1, batch_size=[5])}, [5])
+        >>> td[1::2] = newdata
+        >>> # If TD is properly preallocated, we can share it and change its content
+        >>> td.share_memory_()
+        >>> newdata = TensorDict({"val": NonTensorData(data=2, batch_size=[5])}, [5])
+        >>> td[1::2] = newdata  # Works!
+        >>> # In contrast, not preallocating the tensordict properly will break when assigning values
+        >>> td = TensorDict({"val": NonTensorData(data=0, batch_size=[10])}, [10])
+        >>> td.share_memory_()
+        >>> newdata = TensorDict({"val": NonTensorData(data=2, batch_size=[5])}, [5])
+        >>> td[1::2] = newdata  # breaks!
+
+      Writable memmapped ``NonTensorData`` instances will update the underlying
+      metadata if required. This involves writing to a JSON file, which can
+      introduce some overhead. We advise against this usage whenever one seeks
+      performance and long-lasting data sharing isn't required (``share_memory_``
+      should be preferred in these cases).
+
+        >>> if __name__ == "__main__":
+        ...     td = TensorDict({"val": NonTensorData(data=0, batch_size=[])}, [])
+        ...     td.memmap_(dest_folder)
+        ...     td.update_(TensorDict({"val": NonTensorData(data=1, batch_size=[])}, []))
+        ...     # The underlying metadata on disk is updated during calls to update_
+        ...     td_load = TensorDict.load_memmap(dest_folder)
+        ...     assert (td == td_load).all()
+
    """

    # Used to carry non-tensor data in a tensordict. Keeping this logic in its
    # own class means there is no need to patch tensordict with additional
    # checks that would incur unwanted overhead; all the overhead falls back
    # on this class.
    data: Any
+    _metadata: dict | None = None
+
+    _non_tensor: bool = True

    def __post_init__(self):
-        if isinstance(self.data, NonTensorData):
-            self.data = self.data.data
+        if is_non_tensor(self.data):
+            data = getattr(self.data, "data", None)
+            if data is None:
+                data = self.data.tolist()
+            self.data = data

        old_eq = self.__class__.__eq__
        if old_eq is _eq:
@@ -1314,6 +1508,85 @@ def __or__(self, other):

        self.__class__.__or__ = __or__

+    def update(
+        self,
+        input_dict_or_td: dict[str, CompatibleType] | T,
+        clone: bool = False,
+        inplace: bool = False,
+        *,
+        keys_to_update: Sequence[NestedKey] | None = None,
+    ) -> T:
+        if isinstance(input_dict_or_td, NonTensorData):
+            data = input_dict_or_td.data
+            if inplace and self._tensordict._is_shared:
+                _update_shared_nontensor(self._non_tensordict["data"], data)
+                return self
+            elif inplace and self._tensordict._is_memmap:
+                _update_shared_nontensor(self._non_tensordict["data"], data)
+                # Force a JSON update by resetting _is_memmap before re-memmapping
+                self._tensordict._is_memmap = False
+                self._memmap_(
+                    prefix=self._metadata["memmap_prefix"],
+                    copy_existing=False,
+                    executor=None,
+                    futures=None,
+                    inplace=True,
+                    like=False,
+                )
+                return self
+            elif not inplace and self.is_locked:
+                raise RuntimeError(_LOCK_ERROR)
+            if clone:
+                data = deepcopy(data)
+            self.data = data
+        elif not input_dict_or_td.is_empty():
+            raise RuntimeError(f"Unexpected type {type(input_dict_or_td)}")
+        return self
+
+    def update_(
+        self,
+        input_dict_or_td: dict[str, CompatibleType] | T,
+        clone: bool = False,
+        *,
+        keys_to_update: Sequence[NestedKey] | None = None,
+    ) -> T:
+
+        if isinstance(input_dict_or_td, NonTensorStack):
+            raise RuntimeError(
+                "Cannot update a NonTensorData with a NonTensorStack object."
+            )
+        if not isinstance(input_dict_or_td, NonTensorData):
+            raise RuntimeError(
+                "NonTensorData.copy_ / update_ requires the source to be a NonTensorData object."
+            )

+        if isinstance(input_dict_or_td, NonTensorData):
+            data = input_dict_or_td.data
+            if self._tensordict._is_shared:
+                _update_shared_nontensor(self._non_tensordict["data"], data)
+                return self
+            if self._tensordict._is_memmap:
+                _update_shared_nontensor(self._non_tensordict["data"], data)
+                # Force a JSON update by resetting _is_memmap before re-memmapping
+                self._tensordict._is_memmap = False
+                self._memmap_(
+                    prefix=self._metadata["memmap_prefix"],
+                    copy_existing=False,
+                    executor=None,
+                    futures=None,
+                    inplace=True,
+                    like=False,
+                )
+                return self
+            if self._tensordict._is_memmap:
+                raise NotImplementedError
+            if clone:
+                data = deepcopy(data)
+            self.data = data
+        elif not input_dict_or_td.is_empty():
+            raise RuntimeError(f"Unexpected type {type(input_dict_or_td)}")
+        return self
+
    def empty(self, recurse=False):
        return NonTensorData(
            data=self.data,
@@ -1340,7 +1613,9 @@ def _check_equal(a, b):
                iseq = False
            return iseq

-        if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]):
+        if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all(
+            _check_equal(data.data, first.data) for data in list_of_non_tensor[1:]
+        ):
            batch_size = list(first.batch_size)
            batch_size.insert(dim, len(list_of_non_tensor))
            return NonTensorData(
@@ -1350,9 +1625,7 @@ def _check_equal(a, b):
                device=first.device,
            )

-        from tensordict._lazy import LazyStackedTensorDict
-
-        return LazyStackedTensorDict(*list_of_non_tensor, stack_dim=dim)
+        return NonTensorStack(*list_of_non_tensor, stack_dim=dim)

    @classmethod
    def __torch_function__(
@@ -1394,3 +1667,222 @@ def __torch_function__(
        if not escape_conversion:
            return _from_tensordict_with_copy(tensorclass_instance, result)
        return result
+
+    def _apply_nest(self, *args, **kwargs):
+        kwargs["filter_empty"] = False
+        return _wrap_method(self, "_apply_nest", self._tensordict._apply_nest)(
+            *args, **kwargs
+        )
+
+    def _fast_apply(self, *args, **kwargs):
+        kwargs["filter_empty"] = False
+        return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)(
+            *args, **kwargs
+        )
+
+    def tolist(self):
+        """Converts the data to a list if the batch-size is non-empty.
+
+        If the batch-size is empty, returns the data.
+ + """ + if not self.batch_size: + return self.data + return [ntd.tolist() for ntd in self.unbind(0)] + + def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False): + return self.update_(src) + + def clone(self, recurse: bool = True): + if recurse: + return type(self)( + data=deepcopy(self.data), + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + return type(self)( + data=self.data, + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + + def share_memory_(self): + if self._tensordict._is_shared: + return self + with self.unlock_(): + self._non_tensordict["data"] = _share_memory_nontensor( + self.data, manager=_mp_manager() + ) + self._tensordict.share_memory_() + return self + + def _memmap_( + self, + prefix: str | None = None, + copy_existing: bool = False, + executor=None, + futures=None, + inplace=True, + like=False, + ): + if self._tensordict._is_memmap: + return self + + if prefix is not None: + _metadata = copy(self._metadata) + if _metadata is None: + self._non_tensordict["_metadata"] = {} + self._metadata["memmap_prefix"] = prefix + + out = _memmap_( + self, + prefix=prefix, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=inplace, + like=like, + ) + if prefix is not None and not inplace: + self._non_tensordict["_metadata"] = _metadata + out._non_tensordict["data"] = _share_memory_nontensor( + out.data, manager=_mp_manager() + ) + return out + + +class NonTensorStack(LazyStackedTensorDict): + """A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable. + + A ``NonTensorStack`` is returned whenever :func:`~torch.stack` is called on + a list of :class:`~tensordict.NonTensorData` or ``NonTensorStack``. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> print(data) + NonTensorStack( + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, ..., + batch_size=torch.Size([3, 2]), + device=None) + + To obtain the values stored in a ``NonTensorStack``, call :class:`~.tolist`. + + """ + + _non_tensor: bool = True + + def tolist(self): + """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> data.tolist() + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]] + + """ + iterator = self.tensordicts if self.stack_dim == 0 else self.unbind(0) + return [td.tolist() for td in iterator] + + @classmethod + def from_nontensordata(cls, non_tensor: NonTensorData): + data = non_tensor.data + prev = NonTensorData(data, batch_size=[], device=non_tensor.device) + for dim in reversed(non_tensor.shape): + prev = cls(*[prev.clone(False) for _ in range(dim)], stack_dim=0) + return prev + + def __repr__(self): + selfrepr = str(self.tolist()) + if len(selfrepr) > 50: + selfrepr = f"{selfrepr[:50]}..." 
+        selfrepr = indent(selfrepr, prefix=4 * " ")
+        batch_size = indent(f"batch_size={self.batch_size}", prefix=4 * " ")
+        device = indent(f"device={self.device}", prefix=4 * " ")
+        return f"NonTensorStack(\n{selfrepr}," f"\n{batch_size}," f"\n{device})"
+
+    def to_dict(self) -> dict[str, Any]:
+        return self.tolist()
+
+
+_register_tensor_class(NonTensorStack)
+
+
+def _share_memory_nontensor(data, manager: Manager):
+    # bool must be checked before int: bool is a subclass of int, so the int
+    # branch would otherwise capture booleans and round-trip them as ints.
+    if isinstance(data, bool):
+        return mp.Value(ctypes.c_bool, data)
+    if isinstance(data, int):
+        return mp.Value(ctypes.c_int, data)
+    if isinstance(data, float):
+        return mp.Value(ctypes.c_double, data)
+    if isinstance(data, bytes):
+        # NB: ctypes.c_byte expects an integer, so genuine multi-byte ``bytes``
+        # payloads are not supported by this branch.
+        return mp.Value(ctypes.c_byte, data)
+    if isinstance(data, dict):
+        result = manager.dict()
+        result.update(data)
+        return result
+    if isinstance(data, str):
+        # Fixed-size buffer: strings longer than 100 bytes once utf-8 encoded
+        # will raise on the slice assignment below.
+        result = mp.Array(ctypes.c_char, 100)
+        data = data.encode("utf-8")
+        result[: len(data)] = data
+        return result
+    if isinstance(data, list):
+        result = manager.list()
+        result.extend(data)
+        return result
+    # In all other cases, we just return the data unchanged. This is ok because
+    # the content will be passed to the remote process using regular
+    # serialization. We will lock the update in _update_shared_nontensor though.
+    return data
+
+
+def _from_shared_nontensor(nontensor):
+    if isinstance(nontensor, multiprocessing.managers.ListProxy):
+        return list(nontensor)
+    if isinstance(nontensor, multiprocessing.managers.DictProxy):
+        return dict(nontensor)
+    if isinstance(nontensor, multiprocessing.sharedctypes.Synchronized):
+        return nontensor.value
+    if isinstance(nontensor, multiprocessing.sharedctypes.SynchronizedArray):
+        # Decode the NUL-padded char buffer back into a string
+        byte_list = []
+        for byte in nontensor:
+            if byte == b"\x00":
+                break
+            byte_list.append(byte)
+        return b"".join(byte_list).decode("utf-8")
+    return nontensor
+
+
+def _update_shared_nontensor(nontensor, val):
+    if isinstance(nontensor, multiprocessing.managers.ListProxy):
+        nontensor[:] = []
+        nontensor.extend(val)
+    elif isinstance(nontensor, multiprocessing.managers.DictProxy):
+        nontensor.clear()
+        nontensor.update(val)
+    elif isinstance(nontensor, multiprocessing.sharedctypes.Synchronized):
+        nontensor.value = val
+    elif isinstance(nontensor, multiprocessing.sharedctypes.SynchronizedArray):
+        # Overwrite the char buffer, NUL-padding whatever remains of the
+        # previous (possibly longer) value.
+        val = val.encode("utf-8")
+        for i, byte in enumerate(nontensor):
+            if i < len(val):
+                v = val[i]
+                nontensor[i] = v
+            elif byte == b"\x00":
+                break
+            else:
+                nontensor[i] = b"\x00"
+    else:
+        raise NotImplementedError(
+            f"Updating {type(nontensor).__name__} within a shared/memmaped structure is not supported."
+ ) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 5801235a7..adac91c79 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -16,6 +16,7 @@ import platform import re import uuid +from dataclasses import dataclass from pathlib import Path import numpy as np @@ -3161,7 +3162,7 @@ def test_memmap_like(self, td_name, device, use_dir, tmpdir, num_threads): v2 = tdmemmap[key] if isinstance(v1, str): # non-tensor data storing strings share the same id in python - assert v1 is v2 + assert v1 == v2 else: assert v1 is not v2 assert (tdmemmap == 0).all() @@ -7786,7 +7787,8 @@ def selectfn(input): @pytest.mark.parametrize("chunksize", [0, 5]) @pytest.mark.parametrize("mmap", [True, False]) - def test_map_with_out(self, mmap, chunksize, tmpdir): + @pytest.mark.parametrize("start_method", [None, "fork"]) + def test_map_with_out(self, mmap, chunksize, tmpdir, start_method): tmpdir = Path(tmpdir) input = TensorDict({"a": torch.arange(10), "b": torch.arange(10)}, [10]) if mmap: @@ -7794,7 +7796,13 @@ def test_map_with_out(self, mmap, chunksize, tmpdir): out = TensorDict({"a": torch.zeros(10, dtype=torch.int)}, [10]) if mmap: out.memmap_(tmpdir / "output") - input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out) + input.map( + self.selectfn, + num_workers=2, + chunksize=chunksize, + out=out, + mp_start_method=start_method, + ) assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap) @classmethod @@ -7933,6 +7941,234 @@ def test_ignore_lock(self): assert td[0]["a", "b"] == "0" assert td[1]["a", "b"] == "1" + PAIRS = [ + ("something", "something else"), + (0, 1), + (0.0, 1.0), + ([0, "something", 2], [9, "something else", 11]), + ({"key1": 1, 2: 3}, {"key1": 4, 5: 6}), + ] + + @pytest.mark.parametrize("pair", PAIRS) + @pytest.mark.parametrize("strategy", ["shared", "memmap"]) + @pytest.mark.parametrize("update", ["update_", "update-inplace", "update"]) + def test_shared_memmap_single(self, pair, strategy, update, tmpdir): + val0, val1 = pair + td = TensorDict({"val": NonTensorData(data=val0, batch_size=[])}, []) + if strategy == "shared": + td.share_memory_() + elif strategy == "memmap": + td.memmap_(tmpdir) + else: + raise RuntimeError + + # Test that the Value is unpacked + assert td.get("val").data == val0 + assert td["val"] == val0 + + # Check shared status + if strategy == "shared": + assert td._is_shared + assert td.get("val")._is_shared + assert td.get("val")._tensordict._is_shared + elif strategy == "memmap": + assert td._is_memmap + assert td.get("val")._is_memmap + assert td.get("val")._tensordict._is_memmap + + # check that the json has been updated + td_load = TensorDict.load_memmap(tmpdir) + assert td["val"] == td_load["val"] + # with open(Path(tmpdir) / "val" / "meta.json") as file: + # print(json.load(file)) + + # Update in place + if update == "setitem": + td["val"] = val1 + elif update == "update_": + td.get("val").update_(NonTensorData(data=val1, batch_size=[])) + elif update == "update-inplace": + td.get("val").update(NonTensorData(data=val1, batch_size=[]), inplace=True) + elif update == "update": + with pytest.raises(RuntimeError, match="lock"): + td.get("val").update( + NonTensorData(data="something else", batch_size=[]) + ) + return + + # Test that the Value is unpacked + assert td.get("val").data == val1 + assert td["val"] == val1 + + # Check shared status + if strategy == "shared": + assert td._is_shared + assert td.get("val")._is_shared + assert td.get("val")._tensordict._is_shared + elif strategy == "memmap": + assert 
td._is_memmap + assert td.get("val")._is_memmap + assert td.get("val")._tensordict._is_memmap + + # check that the json has been updated + td_load = TensorDict.load_memmap(tmpdir) + assert td["val"] == td_load["val"] + # with open(Path(tmpdir) / "val" / "meta.json") as file: + # print(json.load(file)) + + @staticmethod + def _run_worker(td, val1, update): + # Update in place + if update == "setitem": + td["val"] = val1 + elif update == "update_": + td.get("val").update_(NonTensorData(data=val1, batch_size=[])) + elif update == "update-inplace": + td.get("val").update(NonTensorData(data=val1, batch_size=[]), inplace=True) + else: + raise NotImplementedError + # Test that the Value is unpacked + assert td.get("val").data == val1 + assert td["val"] == val1 + + @pytest.mark.parametrize("pair", PAIRS) + @pytest.mark.parametrize("strategy", ["shared", "memmap"]) + @pytest.mark.parametrize("update", ["update_", "update-inplace"]) + def test_shared_memmap_mult(self, pair, strategy, update, tmpdir): + from tensordict.tensorclass import _from_shared_nontensor + + val0, val1 = pair + td = TensorDict({"val": NonTensorData(data=val0, batch_size=[])}, []) + if strategy == "shared": + td.share_memory_() + elif strategy == "memmap": + td.memmap_(tmpdir) + else: + raise RuntimeError + + # Test that the Value is unpacked + assert td.get("val").data == val0 + assert td["val"] == val0 + + # Check shared status + if strategy == "shared": + assert td._is_shared + assert td.get("val")._is_shared + assert td.get("val")._tensordict._is_shared + elif strategy == "memmap": + assert td._is_memmap + assert td.get("val")._is_memmap + assert td.get("val")._tensordict._is_memmap + + # check that the json has been updated + td_load = TensorDict.load_memmap(tmpdir) + assert td["val"] == td_load["val"] + # with open(Path(tmpdir) / "val" / "meta.json") as file: + # print(json.load(file)) + + proc = mp.Process(target=self._run_worker, args=(td, val1, update)) + proc.start() + proc.join() + + # Test that the Value is unpacked + assert _from_shared_nontensor(td.get("val")._non_tensordict["data"]) == val1 + assert td.get("val").data == val1 + assert td["val"] == val1 + + # Check shared status + if strategy == "shared": + assert td._is_shared + assert td.get("val")._is_shared + assert td.get("val")._tensordict._is_shared + elif strategy == "memmap": + assert td._is_memmap + assert td.get("val")._is_memmap + assert td.get("val")._tensordict._is_memmap + + # check that the json has been updated + td_load = TensorDict.load_memmap(tmpdir) + assert td["val"] == td_load["val"] + # with open(Path(tmpdir) / "val" / "meta.json") as file: + # print(json.load(file)) + + def test_shared_limitations(self): + # Sharing a special type works but it's locked for writing + @dataclass + class MyClass: + string: str + + val0 = MyClass(string="a string!") + + td = TensorDict({"val": NonTensorData(data=val0, batch_size=[])}, []) + td.share_memory_() + + # with pytest.raises(RuntimeError) + val1 = MyClass(string="another string!") + with pytest.raises( + NotImplementedError, match="Updating MyClass within a shared/memmaped" + ): + td.update( + TensorDict({"val": NonTensorData(data=val1, batch_size=[])}, []), + inplace=True, + ) + with pytest.raises( + NotImplementedError, match="Updating MyClass within a shared/memmaped" + ): + td.update_(TensorDict({"val": NonTensorData(data=val1, batch_size=[])}, [])) + + # We can update a batched NonTensorData to a NonTensorStack if it's not already shared + td = TensorDict({"val": NonTensorData(data=0, 
batch_size=[10])}, [10]) + td[1::2] = TensorDict({"val": NonTensorData(data=1, batch_size=[5])}, [5]) + assert td.get("val").tolist() == [0, 1] * 5 + td = TensorDict({"val": NonTensorData(data=0, batch_size=[10])}, [10]) + td.share_memory_() + with pytest.raises( + RuntimeError, + match="You're attempting to update a leaf in-place with a shared", + ): + td[1::2] = TensorDict({"val": NonTensorData(data=1, batch_size=[5])}, [5]) + + def _update_stack(self, td): + td[1::2] = TensorDict({"val": NonTensorData(data=3, batch_size=[5])}, [5]) + + @pytest.mark.parametrize("update", ["update_at_", "slice"]) + @pytest.mark.parametrize("strategy", ["shared", "memmap"]) + def test_shared_stack(self, strategy, update, tmpdir): + td = TensorDict({"val": NonTensorData(data=0, batch_size=[10])}, [10]) + newdata = TensorDict({"val": NonTensorData(data=1, batch_size=[5])}, [5]) + if update == "slice": + td[1::2] = newdata + elif update == "update_at_": + td.update_at_(newdata, slice(1, None, 2)) + else: + raise NotImplementedError + if strategy == "shared": + td.share_memory_() + elif strategy == "memmap": + td.memmap_(tmpdir) + else: + raise NotImplementedError + assert td.get("val").tolist() == [0, 1] * 5 + + newdata = TensorDict({"val": NonTensorData(data=2, batch_size=[5])}, [5]) + if update == "slice": + td[1::2] = newdata + elif update == "update_at_": + td.update_at_(newdata, slice(1, None, 2)) + else: + raise NotImplementedError + + assert td.get("val").tolist() == [0, 2] * 5 + if strategy == "memmap": + assert TensorDict.load_memmap(tmpdir).get("val").tolist() == [0, 2] * 5 + + proc = mp.Process(target=self._update_stack, args=(td,)) + proc.start() + proc.join() + assert td.get("val").tolist() == [0, 3] * 5 + if strategy == "memmap": + assert TensorDict.load_memmap(tmpdir).get("val").tolist() == [0, 3] * 5 + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()
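A minimal usage sketch of the in-place-update fast path this patch introduces
(key names and shapes below are illustrative, not taken from the test suite):
`update_` now routes through `_apply_nest(..., filter_empty=True,
is_leaf=_is_leaf_nontensor)` and returns `self` directly instead of assembling
an intermediate result tensordict, and `keys_to_update` restricts the in-place
copy to the selected entries.

    >>> import torch
    >>> from tensordict import TensorDict
    >>> td = TensorDict({"a": torch.zeros(3), "b": {"c": torch.zeros(3)}}, [3])
    >>> src = TensorDict({"a": torch.ones(3), "b": {"c": torch.ones(3)}}, [3])
    >>> td.update_(src, keys_to_update=[("b", "c")])  # only ("b", "c") is copied in place
    >>> assert (td["a"] == 0).all() and (td["b", "c"] == 1).all()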