[BugFix, Feature] tensorclass.to_dict and from_dict #707

Merged
merged 12 commits into from
Mar 13, 2024
105 changes: 93 additions & 12 deletions docs/source/reference/tensorclass.rst
@@ -6,10 +6,13 @@ tensorclass
The ``@tensorclass`` decorator helps you build custom classes that inherit the
behaviour from :class:`~tensordict.TensorDict` while being able to restrict
the possible entries to a predefined set or implement custom methods for your class.

Like :class:`~tensordict.TensorDict`, ``@tensorclass`` supports nesting,
indexing, reshaping, item assignment. It also supports tensor operations like
``clone``, ``squeeze``, ``torch.cat``, ``split`` and many more.
``@tensorclass`` allows non-tensor entries; however, all tensor operations
are strictly restricted to tensor attributes.

One needs to implement custom methods to handle non-tensor data.
It is important to note that ``@tensorclass`` does not enforce strict type matching

@@ -69,9 +72,12 @@ It is important to note that ``@tensorclass`` does not enforce strict type match
device=None,
is_shared=False)

As with :class:`~tensordict.TensorDict`, from v0.4 onwards an omitted batch
size is considered empty.
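
For instance (a minimal sketch, assuming ``torch`` and ``tensorclass`` are
imported and using a hypothetical ``MyData`` class):

>>> @tensorclass
... class MyData:
...     values: torch.Tensor
...
>>> data = MyData(values=torch.zeros(3, 4))  # batch_size omitted
>>> assert data.batch_size == torch.Size([])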

If a non-empty batch size is provided, ``@tensorclass`` supports indexing.
Internally, the tensor objects get indexed, while the non-tensor data
remains the same:

.. code-block::

@@ -150,7 +156,7 @@ Here is an example:
is_shared=False)

Serialization
-------------

Saving a tensorclass instance can be achieved with the `memmap` method.
The saving strategy is as follows: tensor data will be saved using memory-mapped
@@ -168,7 +174,7 @@ the `tensorclass` is available in the working environment:
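
A minimal sketch of the save/load round trip, assuming a hypothetical
``MyData`` tensorclass and the ``memmap`` / ``load_memmap`` pair:

>>> data = MyData(values=torch.zeros(3, 4), batch_size=[3])
>>> data.memmap("./saved_data")  # tensor data is written as memory-mapped files
>>> loaded = TensorDict.load_memmap("./saved_data")  # MyData must be importable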


Edge cases
----------

``@tensorclass`` supports equality and inequality operators, even for
nested objects. Note that the non-tensor / meta data is not validated.
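
A minimal sketch of these operators, assuming a hypothetical ``MyData`` class
with a tensor ``values`` entry and a non-tensor ``name`` entry (the comparison
is assumed to return element-wise boolean tensors wrapped in the same class):

>>> a = MyData(values=torch.zeros(3), name="first", batch_size=[3])
>>> b = MyData(values=torch.zeros(3), name="second", batch_size=[3])
>>> result = a == b
>>> assert result.values.all()  # tensor entries are compared element-wise
>>> # the mismatched ``name`` entries are not validated
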
@@ -212,11 +218,12 @@ thrown
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

Even though ``@tensorclass`` supports torch functions like :func:`~torch.cat`
and :func:`~torch.stack`, the non-tensor / meta data is not validated.
The torch operation is performed on the tensor data and, in the returned
output, the non-tensor / meta data of the first tensorclass object is kept.
Users need to make sure that all the tensorclass objects in the list
have the same non-tensor data to avoid discrepancies.

Here is an example:
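
The sketch below assumes a hypothetical ``MyData`` class with a tensor
``values`` entry and a non-tensor ``name`` entry:

>>> a = MyData(values=torch.zeros(3), name="first", batch_size=[3])
>>> b = MyData(values=torch.ones(3), name="second", batch_size=[3])
>>> c = torch.cat([a, b], dim=0)
>>> c.batch_size
torch.Size([6])
>>> c.name  # the meta data of the first object is kept, without validation
'first'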

@@ -274,3 +281,77 @@ Here is an example:
tensorclass
NonTensorData
NonTensorStack

Auto-casting
------------

.. warning:: Auto-casting is an experimental feature and subject to changes in
the future. Compatibility with python<=3.9 is limited.

``@tensorclass`` partially supports auto-casting as an experimental feature.
Methods such as ``__setattr__``, ``update``, ``update_`` and ``from_dict`` will
attempt to cast type-annotated entries to the desired TensorDict / tensorclass
instance (except in the cases detailed below). For instance, the following
code will cast the ``td`` dictionary to a :class:`~tensordict.TensorDict` and
the ``tc`` entry to a :class:`MyClass` instance:

>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... td: TensorDict
... tc: MyClass
...
>>> obj = MyClass(
... tensor=torch.randn(()),
... td={"a": torch.randn(())},
... tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.td, TensorDict)
>>> assert isinstance(obj.tc, MyClass)

.. note:: Type-annotated items that include a ``typing.Optional`` or
   ``typing.Union`` will not be compatible with auto-casting, but other items
in the tensorclass will:

>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... tc_autocast: MyClass = None
... tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
... tensor=torch.randn(()),
... tc_autocast={"tensor": torch.randn(())},
... tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

If at least one item in the class is annotated using the ``type0 | type1``
semantic, auto-casting is deactivated for the whole class.
Because ``tensorclass`` supports non-tensor leaves, setting a dictionary in
these cases will set it as a plain dictionary instead of a
tensor collection subclass (``TensorDict`` or ``tensorclass``):

>>> @tensorclass
... class MyClass:
... tensor: torch.Tensor
... td: TensorDict
... tc: MyClass | None
...
>>> obj = MyClass(
... tensor=torch.randn(()),
... td={"a": torch.randn(())},
... tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

.. note:: Auto-casting isn't enabled for leaves (tensors).
The reason for this is that this feature isn't compatible with type
annotations that contain the ``type0 | type1`` type hinting semantic, which
   is widespread. Allowing auto-casting would result in very similar code
   having drastically different behaviours if the type annotations differ only
   slightly.
7 changes: 6 additions & 1 deletion tensordict/_lazy.py
@@ -115,7 +115,10 @@ def newfunc(self, *args, **kwargs):
raise RuntimeError(
f"the method {func.__name__} cannot complete when there are exclusive keys."
)
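# not every method is defined on TensorDictBase; fall back to the TensorDict implementation when needed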
parent_func = getattr(TensorDictBase, func.__name__, None)
if parent_func is None:
parent_func = getattr(TensorDict, func.__name__)
return parent_func(self, *args, **kwargs)

return newfunc

@@ -2538,6 +2541,7 @@ def _unsqueeze(self, dim):
reshape = TensorDict.reshape
split = TensorDict.split
_to_module = TensorDict._to_module
from_dict_instance = TensorDict.from_dict_instance


class _CustomOpTensorDict(TensorDictBase):
@@ -3070,6 +3074,7 @@ def _unsqueeze(self, dim):
expand = TensorDict.expand
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx
from_dict_instance = TensorDict.from_dict_instance


class _UnsqueezedTensorDict(_CustomOpTensorDict):
54 changes: 49 additions & 5 deletions tensordict/_td.py
@@ -448,7 +448,9 @@ def _quick_set(swap_dict, swap_td):
def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other != self
if isinstance(other, (dict,)):
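# cast the plain dict into a tensor collection mirroring self's structure before comparing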
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
@@ -470,7 +472,9 @@ def __ne__(self, other: object) -> T | bool:
def __xor__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other ^ self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
@@ -492,7 +496,9 @@ def __xor__(self, other: object) -> T | bool:
def __or__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other | self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
@@ -515,7 +521,7 @@ def __eq__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other == self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
@@ -570,7 +576,7 @@ def __setitem__(
if isinstance(value, (TensorDictBase, dict)):
indexed_bs = _getitem_batch_size(self.batch_size, index)
if isinstance(value, dict):
value = self.from_dict_instance(value, batch_size=indexed_bs)
# value = self.empty(recurse=True)[index].update(value)
if value.batch_size != indexed_bs:
if value.shape == indexed_bs[-len(value.shape) :]:
@@ -1299,6 +1305,7 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
)

batch_size_set = torch.Size(()) if batch_size is None else batch_size
input_dict = copy(input_dict)
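# `input_dict` was shallow-copied above so the loop below can replace nested values without mutating the caller's dict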
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
# we don't know if another tensor of smaller size is coming
@@ -1318,6 +1325,41 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
out.batch_size = batch_size
return out

def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
):
if batch_dims is not None and batch_size is not None:
raise ValueError(
"Cannot pass both batch_size and batch_dims to `from_dict`."
)
from tensordict import TensorDict

batch_size_set = torch.Size(()) if batch_size is None else batch_size
input_dict = copy(input_dict)
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
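# if the entry already exists, recurse via from_dict_instance so its type (TensorDict or tensorclass) is preserved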
cur_value = self.get(key, None)
if cur_value is not None:
input_dict[key] = cur_value.from_dict_instance(
value, batch_size=[], device=device, batch_dims=None
)
continue
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_dict(
value, batch_size=[], device=device, batch_dims=None
)
out = TensorDict.from_dict(
input_dict,
batch_size=batch_size_set,
device=device,
)
if batch_size is None:
_set_max_batch_size(out, batch_dims)
else:
out.batch_size = batch_size
return out

@staticmethod
def _parse_batch_size(
source: T | dict,
@@ -2339,6 +2381,8 @@ def _convert_inplace(self, inplace, key):
raise RuntimeError(_LOCK_ERROR)
return inplace

from_dict_instance = TensorDict.from_dict_instance

def _set_str(
self,
key: NestedKey,
62 changes: 56 additions & 6 deletions tensordict/base.py
@@ -369,6 +369,57 @@ def auto_batch_size_(self, batch_dims: int | None = None) -> T:
_set_max_batch_size(self, batch_dims)
return self

@abc.abstractmethod
def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
):
"""Instance method version of :meth:`~tensordict.TensorDict.from_dict`.

Unlike :meth:`~tensordict.TensorDict.from_dict`, this method will
attempt to keep the tensordict types within the existing tree (for
any existing leaf).

Examples:
>>> from tensordict import TensorDict, tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyClass:
... x: torch.Tensor
... y: int
>>>
>>> td = TensorDict({"a": torch.randn(()), "b": MyClass(x=torch.zeros(()), y=1)})
>>> print(td.from_dict_instance(td.to_dict()))
TensorDict(
fields={
a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
b: MyClass(
x=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
y=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(td.from_dict(td.to_dict()))
TensorDict(
fields={
a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)

"""
...

# Module interaction
@classmethod
def from_module(
@@ -672,13 +723,12 @@ def _batch_size_setter(self, new_batch_size: torch.Size) -> None:
)
if not isinstance(new_batch_size, torch.Size):
new_batch_size = torch.Size(new_batch_size)
for key, value in self.items():
if _is_tensor_collection(type(value)):
if len(value.batch_size) < len(new_batch_size):
# document as edge case
value.batch_size = new_batch_size
self._set_str(key, value, inplace=True, validated=True)
self._check_new_batch_size(new_batch_size)
self._change_batch_size(new_batch_size)
if self._has_names():
6 changes: 6 additions & 0 deletions tensordict/nn/params.py
@@ -820,6 +820,12 @@ def _exclude(
) -> TensorDictBase:
...

@_carry_over
def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
):
...

@_carry_over
def _legacy_transpose(self, dim0, dim1):
...
1 change: 1 addition & 0 deletions tensordict/persistent.py
@@ -1199,6 +1199,7 @@ def _unsqueeze(self, dim):
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx
from_dict_instance = TensorDict.from_dict_instance


def _set_max_batch_size(source: PersistentTensorDict):