Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature, Test] Add tests for partial update #578

Merged
merged 7 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from functorch import dim as ftdim
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
from tensordict.base import (
_ACCEPTED_CLASSES,
_is_tensor_collection,
Expand All @@ -36,6 +36,7 @@
_getitem_batch_size,
_is_number,
_parse_to,
_prune_selected_keys,
_renamed_inplace_method,
_shape,
_td_fields,
Expand Down Expand Up @@ -1576,10 +1577,21 @@ def expand(self, *args: int, inplace: bool = False) -> T:
return self
return torch.stack(tensordicts, stack_dim)

def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
def update(
self,
input_dict_or_td: T,
clone: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
**kwargs: Any,
) -> T:
if input_dict_or_td is self:
# no op
return self
if keys_to_update is not None:
keys_to_update = unravel_key_list(keys_to_update)
if len(keys_to_update) == 0:
return self

if (
isinstance(input_dict_or_td, LazyStackedTensorDict)
Expand All @@ -1592,7 +1604,9 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
for td_dest, td_source in zip(
self.tensordicts, input_dict_or_td.tensordicts
):
td_dest.update(td_source, clone=clone, **kwargs)
td_dest.update(
td_source, clone=clone, keys_to_update=keys_to_update, **kwargs
)
return self

inplace = kwargs.get("inplace", False)
Expand All @@ -1601,26 +1615,45 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
value = value.clone()
elif clone:
value = tree_map(torch.clone, value)
key = unravel_key(key)
if isinstance(key, tuple):
key = _unravel_key_to_tuple(key)
firstkey, subkey = key[0], key[1:]
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue

if subkey:
# we must check that the target is not a leaf
target = self._get_str(key[0], default=None)
target = self._get_str(firstkey, default=None)
if is_tensor_collection(target):
target.update({key[1:]: value}, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey)
target.update(
{subkey: value},
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
elif target is None:
self._set_tuple(key, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection."
)
else:
target = self._get_str(key, default=None)
target = self._get_str(firstkey, default=None)
if is_tensor_collection(target) and (
is_tensor_collection(value) or isinstance(value, dict)
):
target.update(value, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey)
target.update(
value,
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
elif target is None or not is_tensor_collection(value):
self._set_str(key, value, inplace=inplace, validated=False)
self._set_str(firstkey, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}."
Expand Down
45 changes: 28 additions & 17 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
_prune_selected_keys,
_set_item,
_set_max_batch_size,
_shape,
Expand Down Expand Up @@ -2105,18 +2106,18 @@ def update(
# no op
return self
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
keys = set(self.keys(False))
for key, value in input_dict_or_td.items():
key = _unravel_key_to_tuple(key)
firstkey, subkey = key[0], key[1:]
if keys_to_update:
if (subkey and key in keys_to_update) or (
not subkey and firstkey in keys_to_update
):
continue
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if clone and hasattr(value, "clone"):
value = value.clone()
elif clone:
Expand All @@ -2127,12 +2128,22 @@ def update(
if _is_tensor_collection(target_class):
target = self._source.get(firstkey)._get_sub_tensordict(self.idx)
if len(subkey):
target._set_tuple(subkey, value, inplace=False, validated=False)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(
{subkey: value},
inplace=False,
keys_to_update=sub_keys_to_update,
)
continue
elif isinstance(value, dict) or _is_tensor_collection(
value.__class__
):
target.update(value)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(value, keys_to_update=sub_keys_to_update)
continue
raise ValueError(
f"Tried to replace a tensordict with an incompatible object of type {type(value)}"
Expand Down Expand Up @@ -2173,17 +2184,17 @@ def update_at_(
keys_to_update: Sequence[NestedKey] | None = None,
) -> _SubTensorDict:
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict.items():
key = _unravel_key_to_tuple(key)
firstkey, *keys = key
if keys_to_update:
if (keys and key in keys_to_update) or (
not keys and firstkey in keys_to_update
):
continue
firstkey, _ = key[0], key[1:]
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if not isinstance(value, tuple(_ACCEPTED_CLASSES)):
raise TypeError(
f"Expected value to be one of types {_ACCEPTED_CLASSES} "
Expand Down
72 changes: 51 additions & 21 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_is_tensorclass,
_KEY_ERROR,
_proc_init,
_prune_selected_keys,
_shape,
_split_tensordict,
_td_fields,
Expand All @@ -50,7 +51,6 @@
lock_blocked,
NestedKey,
prod,
unravel_key,
unravel_key_list,
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
Expand Down Expand Up @@ -1871,17 +1871,17 @@ def update(
# no op
return self
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict_or_td.items():
key = _unravel_key_to_tuple(key)
firstkey, subkey = key[0], key[1:]
if keys_to_update:
if (subkey and key in keys_to_update) or (
not subkey and firstkey in keys_to_update
):
continue
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
target = self._get_str(firstkey, None)
if clone and hasattr(value, "clone"):
value = value.clone()
Expand All @@ -1891,25 +1891,49 @@ def update(
if target is not None:
if _is_tensor_collection(type(target)):
if subkey:
target.update({subkey: value}, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(
{subkey: value},
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
continue
elif isinstance(value, (dict,)) or _is_tensor_collection(
value.__class__
):
if isinstance(value, LazyStackedTensorDict) and not isinstance(
target, LazyStackedTensorDict
):
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
self._set_tuple(
key,
LazyStackedTensorDict(
*target.unbind(value.stack_dim),
stack_dim=value.stack_dim,
).update(value, inplace=inplace, clone=clone),
).update(
value,
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
),
validated=True,
inplace=False,
)
else:
target.update(value, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(
value,
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
continue
self._set_tuple(
key,
Expand Down Expand Up @@ -1960,12 +1984,15 @@ def update_(
# no op
return self
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict_or_td.items():
key = unravel_key(key)
if key in keys_to_update:
firstkey, *nextkeys = _unravel_key_to_tuple(key)
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
# if not isinstance(value, _accepted_classes):
# raise TypeError(
Expand All @@ -1974,7 +2001,7 @@ def update_(
# )
if clone:
value = value.clone()
self.set_(key, value)
self.set_((firstkey, *nextkeys), value)
return self

def update_at_(
Expand Down Expand Up @@ -2025,12 +2052,15 @@ def update_at_(

"""
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict_or_td.items():
key = unravel_key(key)
if key in keys_to_update:
firstkey, *nextkeys = _unravel_key_to_tuple(key)
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if not isinstance(value, tuple(_ACCEPTED_CLASSES)):
raise TypeError(
Expand All @@ -2039,7 +2069,7 @@ def update_at_(
)
if clone:
value = value.clone()
self.set_at_(key, value, idx)
self.set_at_((firstkey, *nextkeys), value, idx)
return self

@lock_blocked
Expand Down
10 changes: 9 additions & 1 deletion tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ def update(
input_dict_or_td: dict[str, CompatibleType] | TensorDictBase,
clone: bool = False,
inplace: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
) -> TensorDictBase:
if not self.no_convert:
func = _maybe_make_param
Expand All @@ -397,7 +399,13 @@ def update(
else:
input_dict_or_td = tree_map(func, input_dict_or_td)
with self._param_td.unlock_():
TensorDictBase.update(self, input_dict_or_td, clone=clone, inplace=inplace)
TensorDictBase.update(
self,
input_dict_or_td,
clone=clone,
inplace=inplace,
keys_to_update=keys_to_update,
)
self._reset_params()
return self

Expand Down
3 changes: 2 additions & 1 deletion tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __iter__(self):
yield from self.tensordict._valid_keys()

def __contains__(self, key):
if isinstance(key, tuple) and len(key) == 1:
key = _unravel_key_to_tuple(key)
if len(key) == 1:
key = key[0]
for a_key in self:
if isinstance(a_key, tuple) and len(a_key) == 1:
Expand Down
8 changes: 8 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,3 +1740,11 @@ def _proc_init(base_seed, queue):
torch.manual_seed(seed)
np_seed = _generate_state(base_seed, worker_id)
np.random.seed(np_seed)


def _prune_selected_keys(keys_to_update, prefix):
if keys_to_update is None:
return None
return tuple(
key[1:] for key in keys_to_update if isinstance(key, tuple) and key[0] == prefix
)
Loading
Loading