Skip to content

Commit

Permalink
[BugFix, Feature] tensorclass.to_dict and from_dict (#707)
Browse files Browse the repository at this point in the history
(cherry picked from commit d6b6a4b)
  • Loading branch information
vmoens committed Mar 24, 2024
1 parent e88e53d commit 5cdddf5
Show file tree
Hide file tree
Showing 10 changed files with 849 additions and 60 deletions.
106 changes: 94 additions & 12 deletions docs/source/reference/tensorclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ tensorclass
The ``@tensorclass`` decorator helps you build custom classes that inherit the
behaviour from :class:`~tensordict.TensorDict` while being able to restrict
the possible entries to a predefined set or implement custom methods for your class.
Like :class:`~tensordict.TensorDict`, ``@tensorclass`` supports nesting, indexing, reshaping,
item assignment. It also supports tensor operations like clone, squeeze, cat,
split and many more. ``@tensorclass`` allows non-tensor entries,

Like :class:`~tensordict.TensorDict`, ``@tensorclass`` supports nesting,
indexing, reshaping, item assignment. It also supports tensor operations like
``clone``, ``squeeze``, ``torch.cat``, ``split`` and many more.
``@tensorclass`` allows non-tensor entries,
however all the tensor operations are strictly restricted to tensor attributes.

One needs to implement their custom methods for non-tensor data.
It is important to note that ``@tensorclass`` does not enforce strict type matching

Expand Down Expand Up @@ -69,9 +72,12 @@ It is important to note that ``@tensorclass`` does not enforce strict type match
device=None,
is_shared=False)
As it is the case with :class:`~tensordict.TensorDict`, from v0.4 if the batch size
is omitted it is considered as empty.

``@tensorclass`` supports indexing. Internally the tensor objects gets indexed,
however the non-tensor data remains the same
If a non-empty batch-size is provided, ``@tensorclass`` supports indexing.
Internally the tensor objects gets indexed, however the non-tensor data
remains the same

.. code-block::
Expand Down Expand Up @@ -150,7 +156,7 @@ Here is an example:
is_shared=False)
Serialization
~~~~~~~~~~~~~
-------------

Saving a tensorclass instance can be achieved with the `memmap` method.
The saving strategy is as follows: tensor data will be saved using memory-mapped
Expand All @@ -168,7 +174,7 @@ the `tensorclass` is available in the working environment:


Edge cases
~~~~~~~~~~
----------

``@tensorclass`` supports equality and inequality operators, even for
nested objects. Note that the non-tensor/ meta data is not validated.
Expand Down Expand Up @@ -212,11 +218,12 @@ thrown
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours
Even though ``@tensorclass`` supports torch functions like cat and stack, the
non-tensor / meta data is not validated. The torch operation is performed on the
tensor data and while returning the output, the non-tensor / meta data of the first
tensor class object is considered. User needs to make sure that all the
list of tensor class objects have the same non-tensor data to avoid discrepancies
Even though ``@tensorclass`` supports torch functions like :func:`~torch.cat`
and :func:`~torch.stack`, the non-tensor / meta data is not validated.
The torch operation is performed on the tensor data and while returning the
output, the non-tensor / meta data of the first tensor class object is
considered. User needs to make sure that all the list of tensor class objects
have the same non-tensor data to avoid discrepancies

Here is an example:

Expand Down Expand Up @@ -273,3 +280,78 @@ Here is an example:

tensorclass
NonTensorData
NonTensorStack

Auto-casting
------------

.. warning:: Auto-casting is an experimental feature and subject to changes in
the future. Compatibility with python<=3.9 is limited.

``@tensorclass`` partially supports auto-casting as an experimental feature.
Methods such as ``__setattr__``, ``update``, ``update_`` and ``from_dict`` will
attempt to cast type-annotated entries to the desired TensorDict / tensorclass
instance (except in cases detailed below). For instance, following code will
cast the `td` dictionary to a :class:`~tensordict.TensorDict` and the `tc`
entry to a :class:`MyClass` instance:

>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... td: TensorDict
... tc: MyClass
...
>>> obj = MyClass(
... tensor=torch.randn(()),
... td={"a": torch.randn(())},
... tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

.. note:: Type annotated items that include an ``typing.Optional`` or
``typing.Union`` will not be compatible with auto-casting, but other items
in the tensorclass will:

>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... tc_autocast: MyClass = None
... tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
... tensor=torch.randn(()),
... tc_autocast={"tensor": torch.randn(())},
... tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

If at least one item in the class is annotated using the ``type0 | type1``
semantic, the whole class auto-casting capabilities are deactivated.
Because ``tensorclass`` supports non-tensor leaves, setting a dictionary in
these cases will lead to setting it as a plain dictionary instead of a
tensor collection subclass (``TensorDict`` or ``tensorclass``):

>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... td: TensorDict
... tc: MyClass | None
...
>>> obj = MyClass(
... tensor=torch.randn(()),
... td={"a": torch.randn(())},
... tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

.. note:: Auto-casting isn't enabled for leaves (tensors).
The reason for this is that this feature isn't compatible with type
annotations that contain the ``type0 | type1`` type hinting semantic, which
is widespread. Allowing auto-casting would result in very similar codes to
have drastically different behaviours if the type annotation differs only
slightly.
7 changes: 6 additions & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def newfunc(self, *args, **kwargs):
raise RuntimeError(
f"the method {func.__name__} cannot complete when there are exclusive keys."
)
return getattr(TensorDictBase, func.__name__)(self, *args, **kwargs)
parent_func = getattr(TensorDictBase, func.__name__, None)
if parent_func is None:
parent_func = getattr(TensorDict, func.__name__)
return parent_func(self, *args, **kwargs)

return newfunc

Expand Down Expand Up @@ -2539,6 +2542,7 @@ def _unsqueeze(self, dim):
reshape = TensorDict.reshape
split = TensorDict.split
_to_module = TensorDict._to_module
from_dict_instance = TensorDict.from_dict_instance


class _CustomOpTensorDict(TensorDictBase):
Expand Down Expand Up @@ -3071,6 +3075,7 @@ def _unsqueeze(self, dim):
expand = TensorDict.expand
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx
from_dict_instance = TensorDict.from_dict_instance


class _UnsqueezedTensorDict(_CustomOpTensorDict):
Expand Down
54 changes: 49 additions & 5 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ def _quick_set(swap_dict, swap_td):
def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other != self
if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
Expand All @@ -470,7 +472,9 @@ def __ne__(self, other: object) -> T | bool:
def __xor__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other ^ self
if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
Expand All @@ -492,7 +496,9 @@ def __xor__(self, other: object) -> T | bool:
def __or__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other | self
if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
Expand All @@ -515,7 +521,7 @@ def __eq__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other == self
if isinstance(other, (dict,)):
other = self.empty(recurse=True).update(other)
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand Down Expand Up @@ -570,7 +576,7 @@ def __setitem__(
if isinstance(value, (TensorDictBase, dict)):
indexed_bs = _getitem_batch_size(self.batch_size, index)
if isinstance(value, dict):
value = TensorDict.from_dict(value, batch_size=indexed_bs)
value = self.from_dict_instance(value, batch_size=indexed_bs)
# value = self.empty(recurse=True)[index].update(value)
if value.batch_size != indexed_bs:
if value.shape == indexed_bs[-len(value.shape) :]:
Expand Down Expand Up @@ -1299,6 +1305,7 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
)

batch_size_set = torch.Size(()) if batch_size is None else batch_size
input_dict = copy(input_dict)
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
# we don't know if another tensor of smaller size is coming
Expand All @@ -1318,6 +1325,41 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
out.batch_size = batch_size
return out

def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
):
if batch_dims is not None and batch_size is not None:
raise ValueError(
"Cannot pass both batch_size and batch_dims to `from_dict`."
)
from tensordict import TensorDict

batch_size_set = torch.Size(()) if batch_size is None else batch_size
input_dict = copy(input_dict)
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
cur_value = self.get(key, None)
if cur_value is not None:
input_dict[key] = cur_value.from_dict_instance(
value, batch_size=[], device=device, batch_dims=None
)
continue
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_dict(
value, batch_size=[], device=device, batch_dims=None
)
out = TensorDict.from_dict(
input_dict,
batch_size=batch_size_set,
device=device,
)
if batch_size is None:
_set_max_batch_size(out, batch_dims)
else:
out.batch_size = batch_size
return out

@staticmethod
def _parse_batch_size(
source: T | dict,
Expand Down Expand Up @@ -2330,6 +2372,8 @@ def _convert_inplace(self, inplace, key):
raise RuntimeError(_LOCK_ERROR)
return inplace

from_dict_instance = TensorDict.from_dict_instance

def _set_str(
self,
key: NestedKey,
Expand Down
62 changes: 56 additions & 6 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,57 @@ def auto_batch_size_(self, batch_dims: int | None = None) -> T:
_set_max_batch_size(self, batch_dims)
return self

@abc.abstractmethod
def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
):
"""Instance method version of :meth:`~tensordict.TensorDict.from_dict`.
Unlike :meth:`~tensordict.TensorDict.from_dict`, this method will
attempt to keep the tensordict types within the existing tree (for
any existing leaf).
Examples:
>>> from tensordict import TensorDict, tensorclass
>>> import torch
>>>
>>> @tensorclass
>>> class MyClass:
... x: torch.Tensor
... y: int
>>>
>>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)})
>>> print(td.from_dict_instance(td.to_dict()))
TensorDict(
fields={
a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
b: MyClass(
x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(td.from_dict(td.to_dict()))
TensorDict(
fields={
a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""
...

# Module interaction
@classmethod
def from_module(
Expand Down Expand Up @@ -672,13 +723,12 @@ def _batch_size_setter(self, new_batch_size: torch.Size) -> None:
)
if not isinstance(new_batch_size, torch.Size):
new_batch_size = torch.Size(new_batch_size)
for key in self.keys():
if _is_tensor_collection(self.entry_class(key)):
tensordict = self.get(key)
if len(tensordict.batch_size) < len(new_batch_size):
for key, value in self.items():
if _is_tensor_collection(type(value)):
if len(value.batch_size) < len(new_batch_size):
# document as edge case
tensordict.batch_size = new_batch_size
self._set_str(key, tensordict, inplace=True, validated=True)
value.batch_size = new_batch_size
self._set_str(key, value, inplace=True, validated=True)
self._check_new_batch_size(new_batch_size)
self._change_batch_size(new_batch_size)
if self._has_names():
Expand Down
6 changes: 6 additions & 0 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,12 @@ def _exclude(
) -> TensorDictBase:
...

@_carry_over
def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
):
...

@_carry_over
def _legacy_transpose(self, dim0, dim1):
...
Expand Down
1 change: 1 addition & 0 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,7 @@ def _unsqueeze(self, dim):
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx
from_dict_instance = TensorDict.from_dict_instance


def _set_max_batch_size(source: PersistentTensorDict):
Expand Down
Loading

0 comments on commit 5cdddf5

Please sign in to comment.