TensorDict X TransformsV2 #7763

Open
NicolasHug opened this issue Jul 26, 2023 · 6 comments

@NicolasHug
Member

https://github.com/pytorch-labs/tensordict
https://pytorch.org/rl/tensordict/index.html

Some random notes after a chat I had with @vmoens


TensorDicts don't really work with our V2 transforms right now: they don't raise an error, but they get passed through without being transformed:

import torch
from tensordict import TensorDict
from torchvision import datapoints
from torchvision.transforms import v2

img = torch.rand(3, 10, 10)
bbox1 = datapoints.BoundingBox(torch.rand(3, 4), format="XYXY", spatial_size=(10, 10))
bbox2 = datapoints.BoundingBox(torch.rand(12, 4), format="XYXY", spatial_size=(10, 10))

td1 = TensorDict({"img": img, "bbox": bbox1}, batch_size=[])
out = v2.Resize(20)(td1)
assert out is td1  # passed through untransformed :'(

It's because pytree.tree_flatten(TensorDict) returns [TensorDict] and so our transforms just pass it through as per our convention.
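To make the pass-through convention concrete, here is a toy, pure-Python model of "flatten, transform the recognized leaves, unflatten". It is not torch.utils._pytree or torchvision code: `toy_flatten`, `toy_unflatten`, `toy_resize`, `ToyImage`, and `OpaqueContainer` are hypothetical names. Only plain dicts are recursed into, so an unknown container class ends up as a single leaf and comes back out untouched.

```python
# Toy model of "flatten, transform recognized leaves, unflatten".
# All names are hypothetical; this is not torch.utils._pytree or torchvision code.


class ToyImage:
    """Stand-in for an image type that the transform knows how to resize."""

    def __init__(self, size):
        self.size = size


class OpaqueContainer:
    """Stand-in for a container type that the flattener knows nothing about."""

    def __init__(self, data):
        self.data = data


def toy_flatten(obj):
    """Return (leaves, spec). Only plain dicts are recursed into."""
    if isinstance(obj, dict):
        leaves, children = [], []
        for key, value in obj.items():
            sub_leaves, sub_spec = toy_flatten(value)
            leaves += sub_leaves
            children.append((key, sub_spec))
        return leaves, ("dict", children)
    # Everything else, including OpaqueContainer, is a single leaf.
    return [obj], ("leaf",)


def toy_unflatten(leaves, spec):
    """Rebuild the structure described by spec, consuming leaves in order."""
    it = iter(leaves)

    def build(s):
        if s[0] == "dict":
            return {key: build(sub) for key, sub in s[1]}
        return next(it)

    return build(spec)


def toy_resize(sample, new_size):
    """Resize ToyImage leaves; pass every other leaf through untouched."""
    leaves, spec = toy_flatten(sample)
    out = [ToyImage(new_size) if isinstance(x, ToyImage) else x for x in leaves]
    return toy_unflatten(out, spec)


img = ToyImage(10)
plain_out = toy_resize({"img": img, "meta": {"img2": img}}, 20)
opaque_in = OpaqueContainer({"img": img})
opaque_out = toy_resize(opaque_in, 20)

print(plain_out["img"].size, plain_out["meta"]["img2"].size)  # 20 20
print(opaque_out is opaque_in, opaque_out.data["img"].size)  # True 10
```

The nested plain dict gets resized all the way down, while the opaque container is a single unrecognized leaf and is returned as-is, which mirrors what we see with TensorDict above.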


An interesting property of TensorDicts is that they can potentially stack() tensors with different shapes, which is particularly relevant for BBoxes:

td2 = TensorDict({"img": img, "bbox": bbox2}, batch_size=[])
batch = torch.stack([td1, td2])

gives:

LazyStackedTensorDict(
    fields={
        bbox: BoundingBox(shape=torch.Size([2, -1, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        img: Tensor(shape=torch.Size([2, 3, 10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)

Note the -1 in the BBox dim, which replaces 3 and 12.
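One way to read the -1: when stacking along a new leading dim, every dim on which the inputs agree is kept, and every dim on which they disagree is reported as -1. The hypothetical helper below is a pure-Python sketch of that shape rule only, assuming all inputs have the same number of dims; it is not tensordict's implementation.

```python
def lazy_stack_shape(shapes):
    """Shape of a stack along a new dim 0, with -1 marking ragged dims.

    Hypothetical helper; assumes every shape has the same number of dims.
    """
    if not shapes:
        raise ValueError("need at least one shape")
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError("all shapes must have the same number of dims")
    merged = []
    for dim in range(ndim):
        sizes = {s[dim] for s in shapes}
        # Keep the size if every input agrees on it, otherwise mark it as -1.
        merged.append(sizes.pop() if len(sizes) == 1 else -1)
    return (len(shapes), *merged)


print(lazy_stack_shape([(3, 4), (12, 4)]))  # (2, -1, 4)
print(lazy_stack_shape([(3, 10, 10), (3, 10, 10)]))  # (2, 3, 10, 10)
```

This reproduces the bbox and img shapes from the printout above.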


import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchvision import datapoints


class MyDataset:
    def __getitem__(self, idx):
        img = torch.rand(3, 10, 10)
        num_bboxes = idx + 1  # a different number of boxes per sample
        bbox = datapoints.BoundingBox(torch.rand(num_bboxes, 4), format="XYXY", spatial_size=(10, 10))
        return TensorDict({"img": img, "bbox": bbox}, [])

    def __len__(self):
        return 100


ds = MyDataset()

dl = DataLoader(ds, batch_size=4, collate_fn=torch.stack)  # Iterating over this works fine
dl = DataLoader(ds, batch_size=4)  # Iterating over this fails (default collate_fn)

I suppose the default behaviour (i.e. when no custom collate_fn is passed) could be supported by tweaking default_collate_fn_map (https://github.com/pytorch/pytorch/blob/21ede4547aa6873971c990d527c4511bcebf390d/torch/utils/data/_utils/collate.py#L190), but it's private (CC @vmoens).
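The idea of a type-to-collate-function map can be sketched without touching PyTorch internals. Below is a toy, pure-Python collate that first looks up the sample's type in a registry and only then falls back to built-in handling; `toy_collate`, `stack_records`, and `Record` are hypothetical names. The real default_collate_fn_map has its own calling convention, which this sketch does not reproduce.

```python
# Toy, type-dispatched collate; hypothetical names, not torch.utils.data internals.


class Record:
    """Stand-in for a TensorDict-like sample type."""

    def __init__(self, **fields):
        self.fields = fields


def stack_records(batch):
    # Hypothetical "stack": gather each field into a list, so ragged fields are fine.
    return Record(**{key: [r.fields[key] for r in batch] for key in batch[0].fields})


def toy_collate(batch, collate_fn_map):
    elem = batch[0]
    # 1. Registered types win.
    for typ, fn in collate_fn_map.items():
        if isinstance(elem, typ):
            return fn(batch)
    # 2. Built-in fallbacks.
    if isinstance(elem, dict):
        return {k: toy_collate([s[k] for s in batch], collate_fn_map) for k in elem}
    if isinstance(elem, (int, float)):
        return list(batch)
    raise TypeError(f"don't know how to collate {type(elem).__name__}")


batch = [Record(n_boxes=i + 1, boxes=[0.0] * (i + 1)) for i in range(4)]

try:
    toy_collate(batch, {})  # nothing registered for Record
    default_failed = False
except TypeError as err:
    default_failed = True
    print("empty map:", err)

out = toy_collate(batch, {Record: stack_records})
print(out.fields["n_boxes"])  # [1, 2, 3, 4]
```

Registering one entry for the sample type is enough to make the ragged batch collate, which is the kind of tweak suggested above.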

@pmeier
Collaborator

pmeier commented Jul 26, 2023

> It's because pytree.tree_flatten(TensorDict) returns [TensorDict] and so our transforms just pass it through as per our convention.

Does it make sense to open an issue in core about this? This is the first time I've heard about TensorDict, but supporting it in pytree sounds reasonable.

Apart from that, we could always monkeypatch it to support our needs. Registering a new type is straightforward if we are comfortable using private APIs.
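What registering a new type buys us can be illustrated with a toy registry: once a container type has flatten/unflatten functions registered, a transform can reach and rebuild its contents instead of passing the whole object through. This is a hypothetical, pure-Python sketch; `REGISTRY`, `register`, `flatten`, `unflatten`, `Box`, and `double_ints` are illustrative names, and torch.utils._pytree's own private registration function (whose exact name and signature depend on the PyTorch version) is not reproduced here.

```python
# Toy pytree registry; hypothetical names, not torch.utils._pytree.
REGISTRY = {}  # type -> (flatten_fn, unflatten_fn)


def register(typ, flatten_fn, unflatten_fn):
    REGISTRY[typ] = (flatten_fn, unflatten_fn)


def flatten(obj):
    """Return (leaves, spec); unregistered types become a single leaf."""
    entry = REGISTRY.get(type(obj))
    if entry is None:
        return [obj], None
    children, context = entry[0](obj)
    leaves, child_specs = [], []
    for child in children:
        sub_leaves, sub_spec = flatten(child)
        leaves += sub_leaves
        child_specs.append((len(sub_leaves), sub_spec))
    return leaves, (type(obj), context, child_specs)


def unflatten(leaves, spec):
    if spec is None:
        return leaves[0]
    typ, context, child_specs = spec
    children, start = [], 0
    for count, sub_spec in child_specs:
        children.append(unflatten(leaves[start:start + count], sub_spec))
        start += count
    return REGISTRY[typ][1](children, context)


class Box:
    """Stand-in for a TensorDict-like mapping container."""

    def __init__(self, data):
        self.data = data


def double_ints(sample):
    # A toy "transform" that only touches int leaves.
    leaves, spec = flatten(sample)
    return unflatten([x * 2 if isinstance(x, int) else x for x in leaves], spec)


before = double_ints(Box({"a": 1, "b": 2}))  # Box is unregistered: passed through
register(
    Box,
    lambda box: (list(box.data.values()), list(box.data.keys())),
    lambda children, keys: Box(dict(zip(keys, children))),
)
after = double_ints(Box({"a": 1, "b": 2}))
print(before.data, after.data)  # {'a': 1, 'b': 2} {'a': 2, 'b': 4}
```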

@NicolasHug
Member Author

pytree support should be added soon: pytorch/tensordict#501

@vmoens
Contributor

vmoens commented Nov 20, 2023

This works fine with tensordict v0.2.1:

import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

image = Image(torch.randint(255, (3, 64, 64), dtype=torch.uint8))
box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
label = torch.randint(10, ())

td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [])

t = Compose([Resize((32, 32)), Grayscale()])

t(td)

which gives

TensorDict(
    fields={
        image: Image(shape=torch.Size([1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: BoundingBoxes(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Happy to write a tutorial in tensordict (or here) to show how that can be used!

@vmoens
Contributor

vmoens commented Nov 20, 2023

If I can brag a bit, check out this multiprocessed transform:

import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
    label = torch.randint(10, ())
    
    td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [], device="cpu")
    
    t = Compose([Resize((32, 32)), Grayscale()])
    
    tdt = t(td)
    print(tdt)
    # Makes a lazy stack of the tensordicts
    td = torch.stack([td] * 100)
    # Map the transform over all items on 2 separate procs
    tdt = td.map(t, dim=0, num_workers=2, chunksize=1)
    print(tdt)

This prints the first td (like in the previous comment), but also this:

TensorDict(
    fields={
        image: Tensor(shape=torch.Size([100, 5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: Tensor(shape=torch.Size([100, 5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([100]),
    device=cpu,
    is_shared=False)

The print shows that all items are plain Tensors, which means that the type is lost somewhere; let me check where. But it's pretty cool to see that this works almost out of the box!

@vmoens
Contributor

vmoens commented Dec 4, 2023

PR pytorch/tensordict#589 will allow you to keep the tensor type after a call to TensorDict.map, provided that you work with a lazy stack:

import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(
        torch.randint(0, 64, size=(5, 4)),
        format="XYXY",
        canvas_size=(64, 64)
        )
    label = torch.randint(10, ())

    td = TensorDict(
        {"image": image, "label": label, "meta": {"box": box}},
        [],
        device="cpu"
        )

    t = Compose([Resize((32, 32)), Grayscale()])

    tdt = t(td)
    # Makes a lazy stack of the tensordicts
    td = torch.stack([td.clone() for _ in range(100)])
    # Map the transform over all items on 2 separate procs
    print('calling map on', td)
    tdt = td.map(t, dim=0, num_workers=2, chunksize=0)
    print(tdt[0]) # the first tensordict of the lazy stack contains the original types!

This prints a TD with the original types:

TensorDict(
    fields={
        image: Image(shape=torch.Size([5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=True),
        label: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=True),
        meta: TensorDict(
            fields={
                box: BoundingBoxes(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
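A rough intuition for why the lazy stack preserves the types while the dense result earlier did not: a dense stack copies every item's values into one new, plain container, while a lazy stack keeps references to the original per-item objects and hands them back on indexing. The classes below are hypothetical pure-Python stand-ins (lists instead of tensors); this is an assumption-based illustration, not tensordict's code.

```python
# Toy contrast between a dense stack and a lazy stack; hypothetical names.


class ToyImage(list):
    """Stand-in for a tensor subclass such as tv_tensors.Image."""


def dense_stack(items):
    # Copy every item's values into a new plain list: the subclass is dropped.
    return [list(item) for item in items]


class LazyStack:
    """Keep references to the original items and hand them back on indexing."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, idx):
        return self.items[idx]

    def __len__(self):
        return len(self.items)


items = [ToyImage([i, i + 1]) for i in range(3)]
dense = dense_stack(items)
lazy = LazyStack(items)

print(type(dense[0]).__name__, type(lazy[0]).__name__)  # list ToyImage
```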

@Mxbonn

Mxbonn commented Mar 8, 2024

+1:
The following does not work:

from torchvision.io import read_image
from torchvision.transforms.v2.functional import to_pil_image, to_image
from torchvision.transforms.v2 import RandomAffine
from tensordict import TensorDict

img = to_image(read_image("./astronaut.jpg"))
transform = RandomAffine(degrees=45)
out = transform(img) # This does work on a torchvision.tv_tensors.Image

td = TensorDict({"image1": img, "image2": img}, [])
out = transform(td)

which raises

TypeError: No image, video, mask or bounding box was found in the sample

The following, however, does work correctly:

out = TensorDict.from_dict(transform(td.to_dict()))
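The error suggests the transform could not find any image-like entry among the flattened inputs; RandomAffine presumably needs an image size to sample its parameters. One plausible explanation (an assumption, not confirmed in this thread) is that the TensorDict was flattened into a single opaque leaf here, while the plain dict from td.to_dict() was flattened into its individual images. The toy check below models only that requirement; `ToyImage`, `OpaqueMapping`, `flatten_dict_only`, and `query_size` are hypothetical names, not torchvision internals.

```python
# Toy "find the image size" check; hypothetical names, not torchvision code.


class ToyImage:
    """Stand-in for tv_tensors.Image."""

    def __init__(self, height, width):
        self.height, self.width = height, width


class OpaqueMapping:
    """Stand-in for a mapping type the flattener does not know about."""

    def __init__(self, data):
        self.data = data


def flatten_dict_only(obj):
    # Recurse into plain dicts; every other object is a single leaf.
    if isinstance(obj, dict):
        return [leaf for value in obj.values() for leaf in flatten_dict_only(value)]
    return [obj]


def query_size(sample):
    """Return (h, w) of the first image-like leaf, else raise TypeError."""
    for leaf in flatten_dict_only(sample):
        if isinstance(leaf, ToyImage):
            return leaf.height, leaf.width
    raise TypeError("No image-like leaf was found in the sample")


img = ToyImage(64, 64)
print(query_size({"image1": img, "image2": img}))  # (64, 64)

try:
    query_size(OpaqueMapping({"image1": img, "image2": img}))
    opaque_error = None
except TypeError as err:
    opaque_error = str(err)
print("opaque container:", opaque_error)
```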
