
[Feature] Stacking tensors of different shape #135

Merged: 4 commits merged into main from stack_heter_shape on Dec 31, 2022
Conversation

@vmoens (Contributor) commented Dec 31, 2022

Description

Allows constructing LazyStackedTensorDict instances containing entries that share a key but have different shapes.
Such entries will raise an exception when queried with get, or with functions that (indirectly) call get, but the error message will point to the solution (i.e. using get_nestedtensor).
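
A minimal sketch of the intended usage (key names and shapes are made up for the example):

import torch
from tensordict import TensorDict

td1 = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
td2 = TensorDict({"obs": torch.zeros(3, 5)}, batch_size=[3])
stacked = torch.stack([td1, td2], dim=0)  # LazyStackedTensorDict with mismatched "obs" shapes

# stacked.get("obs") would raise, with a message pointing to get_nestedtensor
nested = stacked.get_nestedtensor("obs")  # returns the heterogeneous entries as a nested tensor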

Partially solves
pytorch/rl#766

cc @matteobettini

@facebook-github-bot added the CLA Signed label on Dec 31, 2022
@vmoens added the enhancement label on Dec 31, 2022
@vmoens merged commit cfc5aff into main on Dec 31, 2022
@matteobettini (Contributor) commented Jan 1, 2023

This is marvelous, thanks.

Just one thing: when stacking two heterogeneous tensordicts, is there a way we could keep track of which dimensions are heterogeneous?

For example:

td1 = TensorDict(
    fields={
        action: Tensor(torch.Size([10, 32, 5, 100, 2]), dtype=torch.float32),
        next: TensorDict(
            fields={
                observation: Tensor(torch.Size([10, 32, 5, 100, 18]), dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100]),
    device=None,
    is_shared=False)

td2 = TensorDict(
    fields={
        action: Tensor(torch.Size([10, 32, 5, 100, 3]), dtype=torch.float32),
        next: TensorDict(
            fields={
                observation: Tensor(torch.Size([10, 32, 5, 100, 21]), dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100]),
    device=None,
    is_shared=False)

het_td = torch.stack([td1, td2], dim=-1)

which gives:

het_td = LazyStackedTensorDict(
    fields={
        action: Tensor(*, dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                observation: Tensor(*, dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100, 2]),
    device=None,
    is_shared=False)

We would need to keep a flag like the following:

het_td = LazyStackedTensorDict(
    fields={
        action: Tensor(*, dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                observation: Tensor(*, dtype=torch.float32)},
            batch_size=torch.Size([10, 32, 5, 100, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 32, 5, 100, 2]),
    het_dims=torch.Tensor([False,False,False,False,True]),
    device=None,
    is_shared=False)

Also, the line observation: Tensor(*, dtype=torch.float32) does not tell me which dims diverge, so observation: Tensor(torch.Size([10, 32, 5, 100, 2, *]), dtype=torch.float32) might be more informative. (In reality this does not add more info, because it is always going to be [*batch_dims, *], but maybe repetition helps?)
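
For concreteness, a hypothetical sketch of how such a het_dims flag could be populated when stacking (het_dims and everything below are just this proposal, not an existing tensordict attribute):

import torch

batch_size = torch.Size([10, 32, 5, 100])  # batch size of td1 and td2 above
stack_dim = -1                             # as in torch.stack([td1, td2], dim=-1)
heterogeneous = True                       # td1 and td2 share keys but with different shapes

stack_dim = stack_dim % (len(batch_size) + 1)             # normalize -1 to 4
het_dims = torch.zeros(len(batch_size) + 1, dtype=torch.bool)
het_dims[stack_dim] = heterogeneous
# tensor([False, False, False, False,  True]): dim 4 is the heterogeneous stack dim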

@vmoens (Contributor, Author) commented Jan 1, 2023

@matteobettini I love the idea of starring the dims whose shapes do not match. And those won't necessarily be the non-batch dims:

td1 = TensorDict({"a": torch.randn(3, 4, 3, 255, 256)}, [3, 4])
td2 = TensorDict({"a": torch.randn(3, 4, 3, 254, 256)}, [3, 4])
td = torch.stack([td1, td2], 0)

will give a shape for "a" of torch.Size([2, 3, 4, 3, *, 256]), I believe.
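
As an illustration, a hypothetical helper (not part of tensordict) that derives such a starred shape from the component shapes:

import torch

def starred_shape(shapes, stack_dim=0):
    # dims on which the stacked tensors disagree are shown as "*"
    dims = ["*" if len({s[i] for s in shapes}) > 1 else str(shapes[0][i])
            for i in range(len(shapes[0]))]
    dims.insert(stack_dim, str(len(shapes)))  # the new stacking dimension
    return "torch.Size([" + ", ".join(dims) + "])"

starred_shape([torch.Size([3, 4, 3, 255, 256]), torch.Size([3, 4, 3, 254, 256])])
# 'torch.Size([2, 3, 4, 3, *, 256])'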

@vmoens deleted the stack_heter_shape branch on January 1, 2023, 13:27
@matteobettini (Contributor) commented:

> @matteobettini I love the idea of starring the dims whose shapes do not match. And those won't necessarily be the non-batch dims:
>
> td1 = TensorDict({"a": torch.randn(3, 4, 3, 255, 256)}, [3, 4])
> td2 = TensorDict({"a": torch.randn(3, 4, 3, 254, 256)}, [3, 4])
> td = torch.stack([td1, td2], 0)
>
> will give a shape for "a" of torch.Size([2, 3, 4, 3, *, 256]), I believe.

Yes, this is perfect! What do you think about the other problem? We need to keep track of that 2 at dim 0 in the stacked tensor, to let the user know that they have to use that dim to get the heterogeneous underlying tensors.

@matteobettini (Contributor) commented:

I.e., to resolve the asterisks we need to know that dim 0 is the one we need to index.

@vmoens (Contributor, Author) commented Jan 1, 2023

Yes, this unfortunately goes beyond what nested tensors can do, as we can only stack them along the first dim.
We can stack tensordicts along any dim, but nested tensors can only be stacked along the first, hence we have to choose whether or not to allow stacking along other dims.
Option A: leave it as it is. We can stack along any dim; nested tensors will be stacked along dim 0. We just let the user know about this as a "bug" or similar.
Option B: we can stack along any dim, but can't get a nested tensor if the stack dim is not 0. We can still get other keys or get back the original tensordicts.
Option C: we prevent stacking tensordicts when they have heterogeneous shapes and the dim is not 0.

I have a personal preference for Option B.
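
For context, a minimal sketch of the constraint behind these options, using torch.nested (exact API surface depends on the PyTorch version):

import torch

a = torch.randn(3, 255, 256)
b = torch.randn(3, 254, 256)
nt = torch.nested.nested_tensor([a, b])  # components are always laid out along dim 0
print(nt.dim())  # 4: the implicit stacking dim plus the 3 dims of each component

As discussed above, there is no way to pick another stacking dim at the nested-tensor level, which is why stacking a TensorDict along dim != 0 cannot be mirrored directly.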

It would be useful to list the features that are missing for (y)our use cases with nestedtensor; here are a few:

  • stacking along any dim
  • shape (not only size)
  • indexing along any dim that is compatible
  • stacking nested tensors together (currently we can't combine two nested tensors containing tensors of shape [[a, b], [a, c]] into a single one of shape [[[a, b], [a, c]], [[a, b], [a, c]]])

@matteobettini (Contributor) commented Jan 1, 2023

@vmoens

Yes, I will add a list of the features that are missing for my use cases with nestedtensor to my main issue in torchrl, to have it all in one place.

Yes, Option B sounds good. The problem I see is this: imagine we have a simulator spitting out a tensordict with a key of shape torch.Size([2, 2, 2, *, 2]). We know that there is one dim which is heterogeneous, but we do not know which dimension the tensordicts were stacked along. I.e., if we want to get the two heterogeneous underlying tensors, we have to know what the user who created them used as x in torch.stack([...], dim=x).

With torch.stack([td1,td2], dim=0) we get:

het_td[0].shape = (2,2,3,2) # Eureka, it was dim 0
het_td[:,0].shape = (2,2,*,2) # Still heterogeneous
het_td[...,0].shape = (2,2,2,*) # Still heterogeneous

With torch.stack([td1,td2], dim=1) we get:

het_td[0].shape = (2,2,*,2)  # Still heterogeneous
het_td[:,0].shape = (2,2,3,2) # Eureka, it was dim 1
het_td[...,0].shape = (2,2,2,*) # Still heterogeneous

So I know that the underlying nested tensor is always stacked along dim 0, but if we let users choose another dim in the stack wrapper, we have to carry that info, imo. If we do not want to carry that info, then maybe forcing dim 0 may be better, as in Option C.
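
A small sketch of the indexing experiment above, with made-up shapes and key name, assuming the indexing behaviour described in this thread:

import torch
from tensordict import TensorDict

td1 = TensorDict({"x": torch.zeros(2, 2, 3, 2)}, batch_size=[2, 2])
td2 = TensorDict({"x": torch.zeros(2, 2, 4, 2)}, batch_size=[2, 2])

het_td = torch.stack([td1, td2], dim=1)  # key "x" now conceptually has shape [2, 2, 2, *, 2]
td1_back = het_td[:, 0]                  # indexing the stack dim recovers td1
td1_back.get("x").shape                  # torch.Size([2, 2, 3, 2]): no asterisk left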
