Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 22, 2023
1 parent 51ac5ec commit ccc1d70
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
else:
swap = swap_dest
memo[id(module)] = swap
_swap = {}

for key, value in self.items():
if isinstance(value, (Tensor, ftdim.Tensor)):
Expand Down Expand Up @@ -320,7 +321,12 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
swap = swap.to(device=local_out.device)

if return_swap:
swap._set_str(key, local_out, inplace=False, validated=True)
_swap[key] = local_out
if return_swap:
if isinstance(swap, TensorDict):
swap._tensordict.update(_swap)
else:
swap.update(_swap)
return swap

def __ne__(self, other: object) -> T | bool:
Expand Down

0 comments on commit ccc1d70

Please sign in to comment.