Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 7, 2024
1 parent 5b4693b commit b201188
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5193,30 +5193,29 @@ def inplace_update(name, dest, source):
if key == name[: len(key)]:
dest.copy_(source, non_blocking=non_blocking)

self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
filter_empty=True,
named=named,
is_leaf=_is_leaf_nontensor,
)
return self
else:
named = False

def inplace_update(dest, source):
if source is None:
return None
dest.copy_(source, non_blocking=non_blocking)
if not _is_tensor_collection(type(input_dict_or_td)):
from tensordict import TensorDict

if not _is_tensor_collection(type(input_dict_or_td)):
from tensordict import TensorDict
input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)

input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)
self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
filter_empty=True,
named=named,
is_leaf=_is_leaf_nontensor,
)
return self
# Fastest route using _foreach_copy_
keys, vals = self._items_list(True, True)
other_val = input_dict_or_td._values_list(True, True, sorting_keys=keys)
torch._foreach_copy_(vals, other_val)
return self

def update_at_(
self,
Expand Down Expand Up @@ -8157,7 +8156,6 @@ def add_(self, other: TensorDictBase | float, *, alpha: float | None = None):
.. note::
In-place ``add`` does not support ``default`` keyword argument.
"""
torch.Tensor.add_
if _is_tensor_collection(type(other)):
keys, vals = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
Expand Down

0 comments on commit b201188

Please sign in to comment.