[Feature] Stacking tensors of different shape #135
Conversation
This is marvelous, thanks. Just one thing: when stacking two heterogeneous tensordicts, is there a way we could keep track of which are the heterogeneous dimensions? For example:

```
td1 = TensorDict(
    fields={
        action: Tensor(torch.Size([10, 32, 5, 100, 2]), dtype=torch.float32),
        next: TensorDict(
            fields={
                observation: Tensor(torch.Size([10, 32, 5, 100, 18]), dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100]),
    device=None,
    is_shared=False)

td2 = TensorDict(
    fields={
        action: Tensor(torch.Size([10, 32, 5, 100, 3]), dtype=torch.float32),
        next: TensorDict(
            fields={
                observation: Tensor(torch.Size([10, 32, 5, 100, 21]), dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100]),
    device=None,
    is_shared=False)

het_td = torch.stack([td1, td2], dim=-1)

het_td = LazyStackedTensorDict(
    fields={
        action: Tensor(*, dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                observation: Tensor(*, dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100, 2]),
    device=None,
    is_shared=False)
```

We would need to keep a flag like the following:

```
het_td = LazyStackedTensorDict(
    fields={
        action: Tensor(*, dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                observation: Tensor(*, dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100, 2]),
    het_dims=torch.Tensor([False, False, False, False, True]),
    device=None,
    is_shared=False)
```

Also, the line |
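A rough sketch of how such a per-dimension flag could be derived from the shapes being stacked (the `het_dims_mask` helper below is hypothetical and not part of this PR):

```python
import torch
from tensordict import TensorDict

# Hypothetical helper (not part of this PR): given the shapes of one key across the
# tensordicts being stacked, mark which dimensions disagree and would need to be
# flagged as heterogeneous, in the spirit of the het_dims idea above.
def het_dims_mask(shapes):
    ref = shapes[0]
    return torch.tensor([any(s[d] != ref[d] for s in shapes) for d in range(len(ref))])

td1 = TensorDict({"action": torch.randn(10, 32, 5, 100, 2)}, [10, 32, 5, 100])
td2 = TensorDict({"action": torch.randn(10, 32, 5, 100, 3)}, [10, 32, 5, 100])
print(het_dims_mask([td1.get("action").shape, td2.get("action").shape]))
# tensor([False, False, False, False,  True])
```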
@matteobettini I love the idea of having the shape starred on the shapes that do not match. And those won't necessarily be the non-batch-size dims:

```
td1 = TensorDict({"a": torch.randn(3, 4, 3, 255, 256)}, [3, 4])
td2 = TensorDict({"a": torch.randn(3, 4, 3, 254, 256)}, [3, 4])
td = torch.stack([td1, td2], 0)
```

will give a shape for |
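For concreteness, with the shapes used in the snippet above the disagreeing dimension of `"a"` is dim 3 (255 vs 254), i.e. neither a batch dim nor the last dim:

```python
import torch

# With the shapes used above, only dim 3 of "a" differs (255 vs 254).
s1 = torch.Size([3, 4, 3, 255, 256])
s2 = torch.Size([3, 4, 3, 254, 256])
print([d for d, (a, b) in enumerate(zip(s1, s2)) if a != b])  # [3]
```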
Yes, this is perfect! What do you think about the other problem? We need to keep track of that 2 at dim 0 in the stacked tensor to let the user know that they have to use that dim to get the heterogeneous underlying tensors.

I.e., to resolve the asterisks we need to know that dim 0 is what we need to index.
Yes, this unfortunately goes beyond what nested tensors can do, as we can only stack them along the first dim. I have a personal preference for option B. It would be useful to list the stuff that is missing for (y)our use cases with nestedtensor; here are a few:
|
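For reference, a minimal illustration of the first-dim constraint mentioned above, using the `torch.nested` prototype API: mismatching tensors can only be gathered along a new leading dimension.

```python
import torch

# Sketch of the constraint mentioned above: there is no dim argument here, the
# "stack" of a nested tensor is always the new leading dimension.
t1 = torch.randn(100, 2)
t2 = torch.randn(100, 3)
nt = torch.nested.nested_tensor([t1, t2])
print(nt.is_nested)  # True
```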
Yes, I will add a list of the stuff that is missing for my use cases with nestedtensor to my main issue in torchrl, to have it all in one place. Yes, option B sounds good.

The problem I see is this: imagine we have a simulator spitting out a tensordict with a key of heterogeneous shape.

With the stack dim at 0:

```
het_td[0].shape = (2,2,3,2)      # Eureka, it was dim 0
het_td[:,0].shape = (2,2,*,2)    # Still heterogeneous
het_td[...,0].shape = (2,2,2,*)  # Still heterogeneous
```

With the stack dim at 1:

```
het_td[0].shape = (2,2,*,2)      # Still heterogeneous
het_td[:,0].shape = (2,2,3,2)    # Eureka, it was dim 1
het_td[...,0].shape = (2,2,2,*)  # Still heterogeneous
```

So I know that the underlying nested tensor is always at dim 0, but if we let users choose another dim in the stack wrapper, we have to carry that info imo. If we do not want to carry that info, then maybe forcing the 0 one may be better, as in option C.
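A small usage sketch of the indexing behaviour discussed above, assuming the semantics described in this thread (the shapes are illustrative, not taken from a real simulator):

```python
import torch
from tensordict import TensorDict

# Sketch (assuming the semantics discussed above): indexing the stack dim of a
# heterogeneous lazy stack recovers one of the original, homogeneous tensordicts,
# so its shape is fully resolved; indexing any other dim keeps the lazy stack,
# and with it the heterogeneous key.
td1 = TensorDict({"obs": torch.randn(4, 3)}, [4])
td2 = TensorDict({"obs": torch.randn(4, 5)}, [4])
het_td = torch.stack([td1, td2], 0)   # stack dim (the dim to index) is dim 0
print(het_td[0].get("obs").shape)     # torch.Size([4, 3]) -> heterogeneity resolved
print(type(het_td[:, 0]).__name__)    # still a lazy stack; "obs" remains heterogeneous
```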
Description
Allows constructing `LazyStackedTensorDict` instances containing entries with the same key but different shapes.
These entries will raise an exception when queried with `get` or functions that (indirectly) call `get`, but a proper error message will point to the solution (i.e. using `get_nestedtensor`).

Partially solves pytorch/rl#766
cc @matteobettini
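A rough usage sketch of the behaviour described above (the exact error type and message are whatever this PR implements):

```python
import torch
from tensordict import TensorDict

# Stacking tensordicts whose "a" entries differ in shape yields a LazyStackedTensorDict.
td1 = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
td2 = TensorDict({"a": torch.randn(3, 4, 6)}, [3, 4])
het_td = torch.stack([td1, td2], 0)

try:
    het_td.get("a")  # raises: the entries cannot be stacked into a single tensor
except Exception as err:  # exact exception type/message defined by this PR
    print(err)

nt = het_td.get_nestedtensor("a")  # works: packs both entries into a nested tensor
print(nt.is_nested)  # True
```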