Skip to content

Commit

Permalink
[Refactor] Refactor keys, items and values
Browse files Browse the repository at this point in the history
ghstack-source-id: 9d5436c6bbc743e3c754d5fe5f6d87b005dde014
Pull Request resolved: #1058
  • Loading branch information
vmoens committed Oct 25, 2024
1 parent 3d3ea24 commit 9232c46
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 172 deletions.
4 changes: 2 additions & 2 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _has_exclusive_keys(self):
return False

@_fails_exclusive_keys
def to_dict(self) -> dict[str, Any]: ...
def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: ...

def _reduce_get_metadata(self):
metadata = {}
Expand Down Expand Up @@ -3418,7 +3418,7 @@ def _select(
) -> _CustomOpTensorDict:
if inplace:
raise RuntimeError("Cannot call select inplace on a lazy tensordict.")
return self.to_tensordict()._select(
return self.to_tensordict(retain_none=True)._select(
*keys, inplace=False, strict=strict, set_shared=set_shared
)

Expand Down
8 changes: 5 additions & 3 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -4226,7 +4226,7 @@ def _iter():
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
if not self.is_leaf(target_class):
continue
yield key
else:
Expand Down Expand Up @@ -4255,9 +4255,11 @@ def _iter_helper(
# For lazy stacks
value = value[0]
cls = type(value)
is_tc = _is_tensor_collection(cls)
if self.include_nested and is_tc:
if not is_non_tensor(cls):
yield from self._iter_helper(value, prefix=full_key)
is_leaf = self.is_leaf(cls)
if self.include_nested and not is_leaf:
yield from self._iter_helper(value, prefix=full_key)
if not self.leaves_only or is_leaf:
yield full_key

Expand Down
153 changes: 80 additions & 73 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5815,38 +5815,39 @@ def items(
Defaults to ``False``.
"""
if is_leaf is None:
is_leaf = _default_is_leaf
if sort:
yield from sorted(
self.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:

def _items():
if include_nested and leaves_only:
if is_leaf is None:
is_leaf = _default_is_leaf

if include_nested:
# check the conditions once only
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
else:
cls = type(val)
if not leaves_only or is_leaf(cls):
yield k, val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield k, val
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
if _is_tensor_collection(cls):
if not is_non_tensor(cls):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
)
elif leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
Expand All @@ -5856,16 +5857,6 @@ def _items():
for k in self.keys():
yield k, self._get_str(k, NO_DEFAULT)

if sort:
yield from sorted(
_items(),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:
yield from _items()

def non_tensor_items(self, include_nested: bool = False):
"""Returns all non-tensor leaves, maybe recursively."""
return tuple(
Expand Down Expand Up @@ -5902,32 +5893,28 @@ def values(
Defaults to ``False``.
"""
if is_leaf is None:
is_leaf = _default_is_leaf
if sort:
for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value
else:

if is_leaf is None:
is_leaf = _default_is_leaf

def _values():
# check the conditions once only
if include_nested and leaves_only:
if include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
else:
cls = type(val)
if not leaves_only or is_leaf(cls):
yield val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield val
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
if include_nested and _is_tensor_collection(cls):
if not is_non_tensor(cls):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
elif leaves_only:
for k in self.keys(sort=sort):
val = self._get_str(k, NO_DEFAULT)
Expand All @@ -5937,12 +5924,6 @@ def _values():
for k in self.keys(sort=sort):
yield self._get_str(k, NO_DEFAULT)

if not sort or not include_nested:
yield from _values()
else:
for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value

@cache # noqa: B019
def _values_list(
self,
Expand Down Expand Up @@ -9595,9 +9576,16 @@ def _maybe_set_shared_attributes(self, result, lock=False):
if lock:
result.lock_()

def to_tensordict(self) -> T:
def to_tensordict(self, *, retain_none: bool | None = None) -> T:
"""Returns a regular TensorDict instance from the TensorDictBase.
Args:
retain_none (bool): if ``True``, the ``None`` values from tensorclass instances
will be written in the tensordict.
Otherwise they will be discarded. Default: ``True``.
.. note:: from v0.8, the default value will be switched to ``False``.
Returns:
a new TensorDict object containing the same values.
Expand All @@ -9609,7 +9597,11 @@ def to_tensordict(self) -> T:
key: (
value.clone()
if not _is_tensor_collection(type(value))
else value if is_non_tensor(value) else value.to_tensordict()
else (
value
if is_non_tensor(value)
else value.to_tensordict(retain_none=retain_none)
)
)
for key, value in self.items(is_leaf=_is_leaf_nontensor)
},
Expand Down Expand Up @@ -9712,12 +9704,27 @@ def as_tensor(tensor):

return self._fast_apply(as_tensor, propagate_lock=True)

def to_dict(self) -> dict[str, Any]:
"""Returns a dictionary with key-value pairs matching those of the tensordict."""
return {
key: value.to_dict() if _is_tensor_collection(type(value)) else value
for key, value in self.items()
}
def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]:
"""Returns a dictionary with key-value pairs matching those of the tensordict.
Args:
retain_none (bool): if ``True``, the ``None`` values from tensorclass instances
will be written in the dictionary.
Otherwise, they will be discarded. Default: ``True``.
"""
result = {}
for key, value in self.items():
if _is_tensor_collection(type(value)):
if (
not retain_none
and _is_non_tensor(type(value))
and value.data is None
):
continue
value = value.to_dict(retain_none=retain_none)
result[key] = value
return result

def numpy(self):
"""Converts a tensordict to a (possibly nested) dictionary of numpy arrays.
Expand Down Expand Up @@ -9745,7 +9752,7 @@ def numpy(self):
{'a': {'b': array(0., dtype=float32), 'c': 'a string!'}}
"""
as_dict = self.to_dict()
as_dict = self.to_dict(retain_none=False)

def to_numpy(x):
if isinstance(x, torch.Tensor):
Expand Down Expand Up @@ -9786,7 +9793,7 @@ def dict_to_namedtuple(dictionary):
)
return cls(**dictionary)

return dict_to_namedtuple(self.to_dict())
return dict_to_namedtuple(self.to_dict(retain_none=False))

@classmethod
def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def from_dict(cls, *args, **kwargs):
return TensorDictParams(td)

@_fallback
def to_tensordict(self): ...
def to_tensordict(self, *, retain_none: bool | None = None): ...

@_fallback
def to_h5(
Expand Down
Loading

0 comments on commit 9232c46

Please sign in to comment.