[BUG] TensorDictModule with single NonTensorStack fails #821

Closed
3 tasks done
jkrude opened this issue Jun 19, 2024 · 0 comments · Fixed by #822
Labels
bug Something isn't working

jkrude commented Jun 19, 2024

Describe the bug

If the module wrapped by a TensorDictModule returns a single NonTensorStack, the TensorDictModule writes NonTensorData(data=None) into the output tensordict instead of the returned stack.

To Reproduce

from tensordict import TensorDict, NonTensorData, NonTensorStack
from tensordict.nn import TensorDictModule

# The wrapped callable returns a single NonTensorStack.
tdm = TensorDictModule(
    lambda: NonTensorStack(NonTensorData(1), NonTensorData(2)),
    in_keys=[],
    out_keys=['out'],
)
tdm(TensorDict({}))
TensorDict(
    fields={
        out: NonTensorData(data=None, batch_size=torch.Size([]), device=None)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Expected behavior

The NonTensorStack is written into the tensordict under the specified key.

TensorDict(
    fields={
        out: NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None)},
    batch_size=torch.Size([]),
    device=None,

System info

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.4.0 1.26.4 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6 ] darwin 2.3.0

Reason and Possible fixes

It looks like the problematic code is within TensorDictModule::forward(), in Line 1199 to Line 1204.

  1. if isinstance(tensors, (dict, TensorDictBase)):
    a. NonTensorStack inherits from TensorDictBase via LazyStackedTensorDict.
    b. However, as NonTensorStack does not store its data under keys, the same logic as for regular tensordicts cannot be applied.
  2. tensors = tuple(tensors.get(key, None) for key in self.out_keys)
    a. As the code assumes it is dealing with a regular tensordict, this step tries to extract self.out_keys from tensors (the NonTensorStack), which fails and hence returns the default value of None (see the sketch after this list).
    b. (It would perhaps be great to have a strict mode in which missing keys raise an error.)
  3. tensordict_out = self._write_to_tensordict(tensordict, tensors, tensordict_out)
    a. self._write_to_tensordict wraps None in a NonTensorData object.
    b. This is expected behavior.
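
A minimal sketch illustrating steps 1 and 2 above (the isinstance check and the failing key lookup); the printed results reflect the behavior described in this report for tensordict 0.4.0:

from tensordict import NonTensorData, NonTensorStack, TensorDictBase

stack = NonTensorStack(NonTensorData(1), NonTensorData(2))
# Step 1: the stack passes the isinstance check, since it inherits
# from TensorDictBase via LazyStackedTensorDict.
print(isinstance(stack, TensorDictBase))  # True
# Step 2: the stack has no 'out' key, so the lookup falls back to the
# default value instead of raising (per the report above).
print(stack.get("out", None))  # None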

Interestingly, this issue only occurs if the callable of the TensorDictModule returns a single value. If one returns a tuple instead, the isinstance(...) check fails and the output is written directly to the output tensordict.
Therefore, a minimal workaround for me currently is to return a tuple with a single value, as sketched below.
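
A hedged sketch of that workaround, reusing the reproduction above; wrapping the stack in a one-element tuple skips the isinstance branch, so the stack should be written as-is:

from tensordict import TensorDict, NonTensorData, NonTensorStack
from tensordict.nn import TensorDictModule

tdm = TensorDictModule(
    # Return a 1-tuple instead of the bare NonTensorStack.
    lambda: (NonTensorStack(NonTensorData(1), NonTensorData(2)),),
    in_keys=[],
    out_keys=['out'],
)
print(tdm(TensorDict({})))  # 'out' should now hold the NonTensorStack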

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug
jkrude added the bug Something isn't working label Jun 19, 2024