Skip to content

Commit

Permalink
[BugFix] Fix (keys, values) in sub (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 22, 2024
1 parent 0fb5d83 commit 43faf04
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def zero_grad(self, set_to_none: bool = True) -> T:
if set_to_none:
for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
val.grad = None
return
return self
for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
val.grad.zero_()
return self
Expand Down Expand Up @@ -7539,16 +7539,15 @@ def addcmul_(self, other1, other2, value: float | None = 1):
return self

def sub(self, other: TensorDictBase | float, alpha: float | None = None):
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
if alpha is not None:
vals = torch._foreach_sub(val, other_val, alpha=alpha)
vals = torch._foreach_sub(vals, other_val, alpha=alpha)
else:
vals = torch._foreach_sub(val, other_val)
vals = torch._foreach_sub(vals, other_val)
items = dict(zip(keys, vals))
return self._fast_apply(
lambda name, val: items.get(name, val),
Expand Down

0 comments on commit 43faf04

Please sign in to comment.