diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 79f99a224..cc53aaef7 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1196,7 +1196,9 @@ def forward( ) from err else: raise err - if isinstance(tensors, (dict, TensorDictBase)): + if isinstance(tensors, (dict, TensorDictBase)) and all( + key in tensors for key in self.out_keys + ): if isinstance(tensors, dict): keys = unravel_key_list(list(tensors.keys())) values = tensors.values() diff --git a/test/test_nn.py b/test/test_nn.py index 8ae650a77..0bd4ac187 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -13,7 +13,13 @@ import pytest import torch -from tensordict import LazyStackedTensorDict, tensorclass, TensorDict +from tensordict import ( + LazyStackedTensorDict, + NonTensorData, + NonTensorStack, + tensorclass, + TensorDict, +) from tensordict._tensordict import unravel_key_list from tensordict.nn import ( dispatch, @@ -365,6 +371,20 @@ def test_stateful_probabilistic_kwargs( assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) + def test_nontensor(self): + tdm = TensorDictModule( + lambda: NonTensorStack(NonTensorData(1), NonTensorData(2)), + in_keys=[], + out_keys=["out"], + ) + assert tdm(TensorDict({}))["out"] == [1, 2] + tdm = TensorDictModule( + lambda: "a string!", + in_keys=[], + out_keys=["out"], + ) + assert tdm(TensorDict({}))["out"] == "a string!" + @pytest.mark.parametrize( "out_keys", [