Describe the bug
If the module of a `TensorDictModule` returns a single `NonTensorStack`, the `TensorDictModule` writes `NonTensorData(data=None)` into the output tensordict instead of the returned stack.
To Reproduce
Expected behavior
The `NonTensorStack` gets written into the tensordict under the specified key.
System info
0.4.0 1.26.4 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6 ] darwin 2.3.0
Reason and Possible fixes
It looks like the problematic code is within `TensorDictModule::forward()`, in lines 1199 to 1204:
1. `if isinstance(tensors, (dict, TensorDictBase)):`
   a. `NonTensorStack` inherits from `TensorDictBase` via `LazyStackedTensorDict`.
   b. However, `NonTensorStack` does not store its data under keys, so the same logic as for regular tensordicts cannot be applied.
2. `tensors = tuple(tensors.get(key, None) for key in self.out_keys)`
   a. Because the code assumes it is dealing with a regular tensordict, it next tries to extract `self.out_keys` from `tensors` (the `NonTensorStack`). The lookup fails, so `get` returns its default value, `None`.
   b. (It would be great to have a strict mode in which missing keys raise an error.)
3. `tensordict_out = self._write_to_tensordict(tensordict, tensors, tensordict_out)`
   a. `self._write_to_tensordict` wraps `None` in a `NonTensorData` object.
   b. This is expected behavior.
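The lookup failure described above can be reproduced with a small pure-Python mock. This is a sketch under stated assumptions: `MockTensorDictBase`, `MockNonTensorStack`, `collect_outputs` and `collect_outputs_strict` are hypothetical stand-ins, not tensordict's API. They only mimic the two facts the analysis relies on: the stack passes the `isinstance` check because it subclasses the keyed base class, and `get(key, None)` silently returns `None` for a missing key. `collect_outputs_strict` sketches the error-raising mode suggested above; a sentinel keeps a legitimately stored `None` distinct from a missing key.

```python
from typing import Any, Dict, Optional, Sequence, Tuple

_MISSING = object()  # sentinel distinguishing "key absent" from a stored None


class MockTensorDictBase:
    """Hypothetical stand-in for TensorDictBase: a keyed container."""

    def __init__(self, data: Optional[Dict[str, Any]] = None) -> None:
        self._data = dict(data or {})

    def get(self, key: str, default: Any = None) -> Any:
        # Like dict.get: a missing key silently returns `default`.
        return self._data.get(key, default)


class MockNonTensorStack(MockTensorDictBase):
    """Hypothetical stand-in for NonTensorStack: it subclasses the keyed
    base class but holds a list of items and no keyed entries."""

    def __init__(self, items: Sequence[Any]) -> None:
        super().__init__()  # no keys are ever stored
        self.items = list(items)


def collect_outputs(tensors: Any, out_keys: Sequence[str]) -> Tuple[Any, ...]:
    """Simplified mock of the quoted branch (not the library code)."""
    # The isinstance check is True for the stack, since it subclasses the base.
    if isinstance(tensors, (dict, MockTensorDictBase)):
        # Keyed lookup: every out_key is missing, so each entry becomes None.
        return tuple(tensors.get(key, None) for key in out_keys)
    return tensors if isinstance(tensors, tuple) else (tensors,)


def collect_outputs_strict(tensors: Any, out_keys: Sequence[str]) -> Tuple[Any, ...]:
    """Hypothetical strict mode: a missing out_key raises KeyError."""
    if isinstance(tensors, (dict, MockTensorDictBase)):
        values = []
        for key in out_keys:
            value = tensors.get(key, _MISSING)
            if value is _MISSING:
                raise KeyError(f"output key {key!r} not found in module output")
            values.append(value)
        return tuple(values)
    return tensors if isinstance(tensors, tuple) else (tensors,)


stack = MockNonTensorStack(["a", "b"])
print(collect_outputs(stack, ["out"]))  # (None,): this None is what gets wrapped
try:
    collect_outputs_strict(stack, ["out"])
except KeyError as err:
    print("strict mode:", err)
```

In the mock, the `None` returned for the stack is exactly the value that, per the analysis, `_write_to_tensordict` then wraps in `NonTensorData`.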
Interestingly, this issue only occurs if the callable wrapped by the `TensorDictModule` returns a single value. If it returns a tuple instead, the `isinstance(...)` check fails and the output is written directly to the output tensordict.
Therefore, a minimal workaround for me currently is to return a tuple with a single value.
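The tuple workaround and one possible guard can be illustrated with the same kind of mock. As before, every class and function here is a hypothetical simplification, not tensordict code. `collect_outputs_fixed` is only one conceivable fix (skip the keyed-lookup branch when the output is a non-tensor stack and treat it as a single value), not the library's actual patch.

```python
from typing import Any, Dict, Optional, Sequence, Tuple


class MockTensorDictBase:
    """Hypothetical stand-in for TensorDictBase: a keyed container."""

    def __init__(self, data: Optional[Dict[str, Any]] = None) -> None:
        self._data = dict(data or {})

    def get(self, key: str, default: Any = None) -> Any:
        return self._data.get(key, default)


class MockNonTensorStack(MockTensorDictBase):
    """Hypothetical stand-in for NonTensorStack: a subclass with no keys."""

    def __init__(self, items: Sequence[Any]) -> None:
        super().__init__()
        self.items = list(items)


def collect_outputs(tensors: Any, out_keys: Sequence[str]) -> Tuple[Any, ...]:
    """Simplified mock of the current branch (loses a single stack)."""
    if isinstance(tensors, (dict, MockTensorDictBase)):
        return tuple(tensors.get(key, None) for key in out_keys)
    return tensors if isinstance(tensors, tuple) else (tensors,)


def collect_outputs_fixed(tensors: Any, out_keys: Sequence[str]) -> Tuple[Any, ...]:
    """Hypothetical guard: a non-tensor stack is treated as one output value
    instead of a keyed container."""
    is_keyed = isinstance(tensors, (dict, MockTensorDictBase))
    if is_keyed and not isinstance(tensors, MockNonTensorStack):
        return tuple(tensors.get(key, None) for key in out_keys)
    return tensors if isinstance(tensors, tuple) else (tensors,)


stack = MockNonTensorStack(["a", "b"])

# Returning the bare stack loses it: the keyed lookup yields None.
print(collect_outputs(stack, ["out"]))  # (None,)
# Workaround from the report: wrap the single value in a 1-tuple.
print(collect_outputs((stack,), ["out"])[0] is stack)  # True
# With the guard, the bare stack is kept as the output value.
print(collect_outputs_fixed(stack, ["out"])[0] is stack)  # True
```

The 1-tuple works because a tuple never matches the `(dict, TensorDictBase)` check, so the value skips the keyed lookup entirely; the guard achieves the same without requiring callers to change their return type.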
Checklist
I have checked that there is no similar issue in the repo