TensorDict X TransformsV2 #7763
Comments
Does it make sense to open an issue in core about this? First time I hear about it. Apart from that, we could always monkeypatch it to support our needs. Registering a new type is straightforward if we are comfortable using private APIs.
pytree support should be added soon: pytorch/tensordict#501
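As a rough illustration of the monkeypatching route, here is a minimal sketch of registering TensorDict as a pytree node. It relies on torch's private pytree module (whose registration helper has been renamed across releases), and the flatten/unflatten functions are illustrative rather than an agreed-upon design:

```python
import torch.utils._pytree as pytree
from tensordict import TensorDict

def _flatten_td(td):
    # Children are the top-level values; the context keeps the keys and batch size
    # needed to rebuild the TensorDict. Nested TensorDicts flatten recursively once
    # the type is registered.
    keys = list(td.keys())
    return [td.get(k) for k in keys], (keys, td.batch_size)

def _unflatten_td(values, context):
    keys, batch_size = context
    return TensorDict(dict(zip(keys, values)), batch_size)

# Private API: older torch exposes _register_pytree_node, newer torch register_pytree_node.
register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
register(TensorDict, _flatten_td, _unflatten_td)
```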
This works fine with v0.2.1:

```python
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

image = Image(torch.randint(255, (3, 64, 64), dtype=torch.uint8))
box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
label = torch.randint(10, ())

td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [])

t = Compose([Resize((32, 32)), Grayscale()])
t(td)
```

which gives the transformed TensorDict.
Happy to write a tutorial in tensordict (or here) to show how that can be used!
If I can brag a bit:

```python
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(torch.randint(0, 64, size=(5, 4)), format="XYXY", canvas_size=(64, 64))
    label = torch.randint(10, ())

    td = TensorDict({"image": image, "label": label, "meta": {"box": box}}, [], device="cpu")

    t = Compose([Resize((32, 32)), Grayscale()])
    tdt = t(td)
    print(tdt)

    # Makes a lazy stack of the tensordicts
    td = torch.stack([td] * 100)
    # Map the transform over all items on 2 separate procs
    tdt = td.map(t, dim=0, num_workers=2, chunksize=1)
    print(tdt)
```

This prints the first td (like in the previous comment) but also this:

```
TensorDict(
    fields={
        image: Tensor(shape=torch.Size([100, 5, 1, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False),
        label: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
        meta: TensorDict(
            fields={
                box: Tensor(shape=torch.Size([100, 5, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([100]),
    device=cpu,
    is_shared=False)
```

The print shows that all items are plain tensors, which means that the type is lost somewhere; let me check where. But it's pretty cool to see that this works almost out of the box!
This PR pytorch/tensordict#589 will allow you to keep the tensor type after a call to `map`:

```python
import torchvision
from torchvision.tv_tensors import BoundingBoxes, Image
from torchvision.transforms.v2 import Compose, Resize, Grayscale
from tensordict import TensorDict
import torch

if __name__ == "__main__":
    image = Image(torch.randint(255, (5, 3, 64, 64), dtype=torch.uint8))
    box = BoundingBoxes(
        torch.randint(0, 64, size=(5, 4)),
        format="XYXY",
        canvas_size=(64, 64)
    )
    label = torch.randint(10, ())

    td = TensorDict(
        {"image": image, "label": label, "meta": {"box": box}},
        [],
        device="cpu"
    )

    t = Compose([Resize((32, 32)), Grayscale()])
    tdt = t(td)

    # Makes a lazy stack of the tensordicts
    td = torch.stack([td.clone() for _ in range(100)])
    # Map the transform over all items on 2 separate procs
    print('calling map on', td)
    tdt = td.map(t, dim=0, num_workers=2, chunksize=0)
    print(tdt[0])  # the first tensordict of the lazy stack contains the original types!
```

This prints a TD with the original types.
+1:

```python
from torchvision.io import read_image
from torchvision.transforms.v2.functional import to_pil_image, to_image
from torchvision.transforms.v2 import RandomAffine
from tensordict import TensorDict

img = to_image(read_image("./astronaut.jpg"))
transform = RandomAffine(degrees=45)
out = transform(img)  # This does work on a torchvision.tv_tensors.Image

td = TensorDict({"image1": img, "image2": img}, [])
out = transform(td)  # The TensorDict is returned unchanged (passed through as a single leaf)
out = TensorDict.from_dict(transform(td.to_dict()))  # Workaround: transform a plain dict, then rebuild the TensorDict
```
https://github.com/pytorch-labs/tensordict
https://pytorch.org/rl/tensordict/index.html
Some random notes after a chat I had with @vmoens.

TensorDicts don't really work with our V2 transforms right now: they don't error, but they get passed through without being transformed. It's because `pytree.tree_flatten(TensorDict)` returns `[TensorDict]`, and so our transforms just pass it through as per our convention.
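For concreteness, a minimal sketch of that pass-through behaviour (the key names and shapes are arbitrary placeholders):

```python
import torch
import torch.utils._pytree as pytree
from tensordict import TensorDict

td = TensorDict({"image": torch.rand(3, 8, 8), "label": torch.tensor(1)}, [])

# TensorDict is not a registered pytree node, so the whole container is a single leaf.
leaves, spec = pytree.tree_flatten(td)
print(len(leaves))  # 1: the TensorDict itself is the only leaf
# V2 transforms only act on leaves they recognise (tensors, tv_tensors, PIL images),
# so an unrecognised leaf like this one is returned unchanged.
```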
An interesting property of TensorDicts is that they could potentially `stack()` tensors with different shapes, which is particularly relevant for BBoxes: in the resulting stack, note the -1 in the BBox dim, which replaces 3 and 12.
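A minimal sketch of that stacking behaviour, with 3 and 12 boxes to match the numbers above (on newer tensordict releases an explicit lazy stack may be needed instead of `torch.stack`):

```python
import torch
from tensordict import TensorDict

# Two samples with a different number of boxes (3 vs 12).
td_a = TensorDict({"boxes": torch.rand(3, 4)}, [])
td_b = TensorDict({"boxes": torch.rand(12, 4)}, [])

# Stacking TensorDicts builds a lazy stack: entries stay per-item instead of being
# materialised into one contiguous tensor, so mismatched shapes are allowed.
stacked = torch.stack([td_a, td_b])
print(stacked)  # the boxes entry is reported with shape [2, -1, 4]; -1 marks the
                # dimension where the stacked items disagree (3 vs 12)
```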
I suppose the default behaviour (i.e. not passing a custom collate_fn) could be supported by tweaking `default_collate_fn_map` (https://github.com/pytorch/pytorch/blob/21ede4547aa6873971c990d527c4511bcebf390d/torch/utils/data/_utils/collate.py#L190), but it's private (CC @vmoens).
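For reference, a minimal sketch of what that tweak could look like, assuming we are willing to mutate torch's private collate registry (the function name `collate_tensordict_fn` is made up here; its signature mirrors the other entries in that file):

```python
import torch
from tensordict import TensorDict
from torch.utils.data._utils.collate import default_collate_fn_map  # private module

def collate_tensordict_fn(batch, *, collate_fn_map=None):
    # Stack a batch of TensorDicts along a new leading batch dimension.
    return torch.stack(batch, 0)

# default_collate dispatches on the element type through this registry, so adding an
# entry for TensorDict makes the default DataLoader collation handle TensorDict samples.
default_collate_fn_map[TensorDict] = collate_tensordict_fn
```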