From ccc1d706fe8c006c19ac6a88d22dcf69ef8b058c Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 14:03:35 +0000 Subject: [PATCH] amend --- tensordict/_td.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index b644b17a7..8091b1d7d 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -276,6 +276,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) else: swap = swap_dest memo[id(module)] = swap + _swap = {} for key, value in self.items(): if isinstance(value, (Tensor, ftdim.Tensor)): @@ -320,7 +321,12 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) swap = swap.to(device=local_out.device) if return_swap: - swap._set_str(key, local_out, inplace=False, validated=True) + _swap[key] = local_out + if return_swap: + if isinstance(swap, TensorDict): + swap._tensordict.update(_swap) + else: + swap.update(_swap) return swap def __ne__(self, other: object) -> T | bool: