[FEEDBACK] Transforms V2 API #6753
Comments
Could you post the link to the blog post in this thread when it becomes available? I am not sure where else to look to find it. |
Can you provide a good starting point for getting an overview of the current state of Transforms V2? Or is reading through https://github.com/pytorch/vision/tree/main/torchvision/prototype/transforms and https://github.com/pytorch/vision/projects/5 the best approach for now? |
@rsokl @jangop We have a blogpost in the pipeline that provides an overview; we are waiting for marketing to publish it, and I'll post the link here once we do. Until then, those two references are the best places to look. There is no documentation yet because the API was still being modified, but all transforms accept the exact same parameters as in V1. |
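For illustration, a minimal sketch of that drop-in parity (treat the prototype import path as an assumption, since the namespace may move before release):
from torchvision import transforms as v1
from torchvision.prototype import transforms as v2
# Same constructor arguments in both versions; only the import changes.
t_v1 = v1.RandomResizedCrop(size=224, scale=(0.08, 1.0))
t_v2 = v2.RandomResizedCrop(size=224, scale=(0.08, 1.0))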
Hi all, I just read the blog post and have been exploring the prototype's internals for the last few days. First, let me summarize some of the primary features offered by v2; this will, in part, help me make sure that I have a clear understanding of things. v2 includes:
I had been doing a review of augmentation libraries, including kornia, albumentations, and augly, when I came across v2 for torchvision. To me, your new API is the simplest, most capable, and the easiest to extend. I particularly like how simple So first of all, thank you for all of the hard work that you have been doing on this effort. The PyTorch community is fortunate to benefit from this excellent work ❤️. There are two features that I would like to propose. I would be happy to assist with these if there is interest. Enabling local reproducibility by passing
|
Thanks a ton for the kind words @rsokl! I think your summary is on point. Let me share my thoughts on your proposals. Allow
|
def convert_format_bounding_box( |
because there would only be a single kernel. Even if we added such a dispatcher for consistency, we would get into hot water, since the dispatcher would violate the rule "for BC, plain tensors need to be treated as image or video": there is no image or video kernel to dispatch to. This rule is currently upheld by all our other dispatchers, which allows us to keep them BC even with regard to JIT. Of course there is no BC concern for new functions, but it is a lot easier to say "the functional API is JIT compatible" than "the functional API that was already there in v1 is JIT compatible, but for the remainder it depends".
Internally, a lot of transformations look up the spatial size from the inputs with
height, width = query_spatial_size(flat_inputs) |
This is already pretty flexible and can pull the information from a lot of types, but there is no "official" protocol yet. For now we are using the spatial_size attribute
vision/torchvision/prototype/transforms/functional/_meta.py
Lines 103 to 104 in 4f3a000
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)): | |
return list(inpt.spatial_size) |
but this is only a convention.
Taking 3. one step further, some transformations also query bounding boxes like that:
vision/torchvision/prototype/transforms/_utils.py
Lines 105 to 108 in 4f3a000
def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox: | |
bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)] | |
if not bounding_boxes: | |
raise TypeError("No bounding box was found in the sample") |
Without a more elaborate protocol, any bounding-box-like feature, such as the RotatedBoundingBox proposed in [RFC] Rotated Bounding Boxes #2761, cannot be a free-standing feature but has to subclass features.BoundingBox. Otherwise it won't be picked up, and in turn the transformation will fail.
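For instance, a hypothetical rotated-box type would currently have to look roughly like the following sketch (the class name and extra state are illustrative only; the real RFC may settle on something different):
from torchvision.prototype import features

class RotatedBoundingBox(features.BoundingBox):
    # Illustrative only: by subclassing features.BoundingBox, isinstance-based
    # lookups such as query_bounding_box still find this type.
    pass

box = RotatedBoundingBox(
    [[1, 0, 3, 2]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(4, 4),
)
print(isinstance(box, features.BoundingBox))  # True, so the transforms pick it up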
These are just the limitations off the top of my head. There will likely be more if we take a closer look. Thus, the design for this will be more complex.
That being said, some of the points mentioned above are already in our backlog, but we have focused on performance for the past few weeks. I agree that allowing users to extend our API like that would be very powerful, but we need to smooth out all the rough edges before we can promote this as an official part of the API. Let's hear what @datumbox and @vfdev-5 say about this before I cast my vote.
@rsokl Would you perhaps share a few words on why / how / in which regard Transforms V2 is “the simplest, most capable, and the easiest to extend”? Reading through the blog post, the one section I missed was related work for context. Personally, I only have experience with albumentations. Coming from that, I am inclined to agree with your assessment. |
Is there an estimate in terms of either time, or version number, or both, for when the v2 transforms will be included in pytorch-stable? The blog says "planned for Q1" but I wonder if there's a better estimate than that somewhere. I was about to implement, on a small scale, functions that basically do what v2 does, but if your prototype works well then I will use it for my project. Should I expect any issues if I train a model on the nightly builds with the v2 transforms, but then run it for predictions on pytorch-stable 1.13? Other than the v2 transforms, I do not use any unusual features. |
Test case: image segmentation with SegFormer. My dataset has images and masks organized in tuples. I am following this example: https://huggingface.co/blog/fine-tune-segformer I need to make sure that geometric transforms (horizontal flip, rotation, affine, perspective) are applied randomly, but applied the exact same way to each image and its corresponding mask. I am testing with horizontal flip, and the image and the mask are flipped in an uncoordinated fashion. Code example (admittedly inefficient):
from torchvision.prototype.transforms import (
Compose,
RandomApply,
ColorJitter,
RandomRotation,
RandomCrop,
RandomAffine,
RandomHorizontalFlip,
RandomPerspective,
)
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
feature_extractor = SegformerFeatureExtractor()
augmentations = Compose(
[
# RandomApply([ColorJitter(brightness=0.5, contrast=0.5)], p=0.75),
RandomHorizontalFlip(p=0.5),
# RandomApply([RandomRotation(degrees=45)], p=0.75),
# RandomApply([RandomAffine(degrees=0.0, scale=(0.5, 2.0))], p=0.25),
# RandomPerspective(distortion_scale=0.5, p=0.25),
]
)
def train_transforms(example_batch):
# original
images = [augmentations(x) for x in example_batch["pixel_values"]]
labels = [x for x in example_batch["label"]]
inputs = feature_extractor(images, labels)
return inputs
def train_transforms2(example_batch):
# my version
batch_items = list(
zip(
[x for x in example_batch["pixel_values"]],
[x for x in example_batch["label"]],
)
)
batch_items_aug = [
augmentations(
features.Image(
np.swapaxes(np.array(x[0]), 0, -1), color_space=ColorSpace.RGB
),
features.Mask(np.swapaxes(np.array(x[1]), 0, -1)),
)
for x in batch_items
]
images, labels = map(list, zip(*batch_items_aug))
inputs = feature_extractor(images, labels)
return inputs
train_ds.set_transform(train_transforms2)
This is one entry in the train dataset that the transforms are applied to:
If I repeatedly visualize the image and the mask, they are flipped independently of each other rather than in the coordinated fashion I expect:
|
Hey @FlorinAndrei
We are aiming to publish it with the next release.
No. Apart from JIT, the transforms are BC and thus there should be no complications. You can have a look at #6433 where we do just that.
You are not flipping the masks at all:
images = [augmentations(x) for x in example_batch["pixel_values"]]
labels = [x for x in example_batch["label"]]
The horizontal flip is only applied to the images, while the labels are passed through untouched. Transforms v2 can handle arbitrary input structures, so you don't need to handle images and masks separately. You can just pass them into the transforms together like:
augmented_batch = [augmentations(sample) for sample in example_batch]
images, labels = zip(*augmented_batch)
inputs = feature_extractor(images, labels)
The only caveat is that the transforms need to know which parts of the sample are masks rather than images. In transforms v2 this is done by wrapping the data into custom tensor classes located under torchvision.prototype.features:
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F
class WrapIntoFeatures(transforms.Transform):
def forward(self, sample):
image, label = sample
label = features.Mask(F.pil_to_tensor(label))  # convert the PIL mask into a Mask tensor
return image, label
augmentations = transforms.Compose([WrapIntoFeatures(), ...])
For custom datasets this will have to happen manually. For our builtin ones, we are currently exploring the options; see #6662 for the current favorite. I've put together a small notebook for you. I hope it helps, but let us know if you encounter anything. |
I understand the need to wrap the torch.Tensor class to mimic the behavior of "matching" transformations for a specific input type. (a) Is it fine for "transforms v2" that an operation on two features.Image inputs defaults back to a non-features.Image tensor? I think this behavior might feel inconsistent to some users. It might be that my concerns are irrelevant: I have only read through the code briefly. However, I think that this behavior needs to be specified / discussed upfront.
import torch as th
from torchvision.prototype.features import Image
img1 = Image(th.rand(3, 256, 256))
img2 = Image(th.rand(3, 256, 256))
type(img1), type(img2)
>> torchvision.prototype.features._image.Image, torchvision.prototype.features._image.Image
img = img1 * img2
type(img)
>> torch.Tensor |
Not sure I understand.
meaning all tensors and PIL images will be transformed.
Inside our transformations, we make sure that there is no inconsistent behavior. If you pass in a
vision/torchvision/prototype/transforms/functional/_misc.py Lines 64 to 66 in 7a7ab7e
The behavior is specified and encoded here:
The function has extensive comments explaining what it is doing, but let me give you the TL;DR: except for very few operators, namely
vision/torchvision/prototype/features/_feature.py Lines 55 to 61 in 7a7ab7e
there is no fast and safe way for us to determine if the result of the operation should retain the feature type or not. You already highlighted a few cases above. Thus, any operation except for the ones mentioned above will "unwrap" the result, i.e. give you a plain tensor. Now, there are times where the result should retain the input type. This happens all over our transforms. In that case you will need to wrap again, like
output = ...
return features.Image(output)
As a shorthand, if you don't want to copy the metadata manually, you can use
vision/torchvision/prototype/features/_image.py Lines 113 to 115 in 7a7ab7e
(Note that the |
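As a small sketch of that re-wrapping rule, reusing the Image example from earlier (wrap_like refers to the prototype method linked above; its exact spelling may still change):
import torch
from torchvision.prototype import features

img1 = features.Image(torch.rand(3, 256, 256))
img2 = features.Image(torch.rand(3, 256, 256))

out = img1 * img2                           # plain torch.Tensor: the result is "unwrapped"
img = features.Image.wrap_like(img1, out)   # re-attach the Image type and its metadata
print(type(out).__name__, type(img).__name__)  # Tensor Image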
@pmeier: Yes, I agree, and I think it's great that you can define transformations with an op or no-op depending on the specified input types. As an addition to your provided example, here are two masks for your image above.
import numpy as np
import torch as th
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Loading data
img = features.Image(io.read_image('COCO_val2014_000000418825.jpg'), color_space=features.ColorSpace.RGB)
smask = features.Mask((io.read_image('strawberry.png') != 0).type(th.uint8)) # has a no-op for GaussianBlur
tmask = features.Mask((io.read_image('tomato.png') != 0).type(th.uint8)) # has a no-op for GaussianBlur
sotmask = smask | tmask # is now a simple tensor and _isinstance(..., _is_simple_tensor) --> has op for GaussianBlur
# Defining and applying Transforms V2
trans = T.Compose(
[
T.GaussianBlur(kernel_size=7),
]
)
np.all((trans(img) == img).numpy())
>> False. # Expected
np.all((trans(smask) == smask).numpy())
>> True # Expected
np.all((trans(sotmask) == sotmask).numpy())
>> False # This is what I am talking about...
Edit: Fixed, code based on @pmeier's comment below. |
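A minimal sketch of that fix, reusing the names from the snippet above and following the wrap-again rule discussed below (whether the combined mask should behave as a Mask is of course a choice):
sotmask = features.Mask(smask | tmask)  # re-wrapping keeps the Mask type
np.all((trans(sotmask) == sotmask).numpy())  # now True: GaussianBlur is a no-op for Masks again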
@maxoppelt thanks for your feedback and the discussion! Concerning your point about
and your examples: type(img1), type(img2)
>> torchvision.prototype.features._image.Image, torchvision.prototype.features._image.Image
img = img1 * img2
type(img)
>> torch.Tensor
and
np.all((trans(sotmask) == sotmask).numpy())
>> False # This is what I am talking about...
What do you think would be a consistent/expected behaviour? Should they raise an error instead? |
You are right that the communication of this needs to be clear. Right now documentation is scarce, but this will change before the release. Basically the rule is: "if you perform any operation on tensor subclasses outside of our builtin transforms, you need to wrap again". We will have detailed documentation about how to write a custom transform that will cover this. We thought about this very issue a lot during the design phase. In the end we decided against automatic detection of whether or not the result can retain the type, due to the issues listed above. The behavior you mention is just the logical consequence of this. Unless we missed something obvious, there is no better way to do this if we want to use the type of the input as the dispatch mechanism. The solutions of comparable libraries also require you to be explicit, either by calling the transform like
P.S.:
>>> label = features.Label([0, 1, 1], categories=["strawberry", "tomato"])
>>> label
Label([0, 1, 1])
>>> label.to_categories()
['strawberry', 'tomato', 'tomato'] |
Depends. I could imagine that something like
type(img)
> torchvision.prototype.features._image.Image
type(mask)
> torchvision.prototype.features._mask.Mask
result = image * mask
type(result)
> torchvision.prototype.features._image.Image
would be self-explanatory and result in something like this:
Raising an error is probably too much, as it might be a valid operation... It is just not clear that any operation using a
However, I kind of agree with @pmeier, as these new
An alternative approach could look like this:
type(img1), type(img2), type(mask)
> th.Tensor, th.Tensor, th.Tensor
trans = T.Compose(
[
T.GaussianBlur(kernel_size=7),
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480)
],
input_types = [Image, Image, Mask]
)
trans(img1, img2, mask)
instead of this:
type(img1), type(img2), type(mask)
> torchvision.prototype.features._image.Image, torchvision.prototype.features._image.Image, torchvision.prototype.features._mask.Mask
trans = T.Compose(
[
T.GaussianBlur(kernel_size=7),
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480)
],
)
trans(img1, img2, mask)
However, this approach might have some disadvantages, too: type and value are separated, and why not use https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor (like currently done...)? In the end it is probably an issue that can be resolved/prevented in the future by writing good documentation, like @pmeier pointed out. Maybe one could use Python warnings to inform the user... Btw, why is the new torchvision type system called features? In most cases torchvision transforms will work on the input, not on a latent representation (usually called features). |
@pmeier In the example above, I actually use the feature wrappers. The reason the flips were uncoordinated is that I was invoking the dataset twice: once to get the image, and again to get the mask. Those were two separate applications of the augmentation function, and of course they were not coordinated in terms of the random geometric transforms. All that was fixed when I extracted the single item from the dataset once, and then extracted the image and the mask from that item. Apologies for the non-issue. |
@FlorinAndrei Thanks for taking the time to provide in-depth feedback. We really appreciate it. Please keep it coming; perhaps the confusion caused indicates there are still rough edges on the API that we need to fix. Or perhaps we need to document the gotchas better. If you have other ideas on the overall API (both public and developer) or on the naming conventions, please let us know. :) |
You are absolutely right. This is very confusing. This is a placeholder name until we find something better. We need a concept for the base tensor class that can be reused for Images, Videos, BoundingBoxes, Labels, Masks etc. Naming is NP-hard, any help on that front would be highly appreciated... |
Perhaps |
The naming issue is not new. See for example the thread in #5045. We used "feature" in the beginning since that is what
|
Hi. As a PhD student working on object detection and instance segmentation, I'm very happy to see this added to torchvision. I'm eager to use this new transforms API in my own code base.
Possible inconsistencies between boxes and masks
When doing (box-based) instance segmentation, both target boxes and target masks need to be provided to the model. As written in the blog post, one would use
to apply the transforms to the various data structures. By applying each of the transforms on the boxes and masks individually (as I believe is done in Transforms V2), one might expect the transformed bounding boxes to tightly fit the transformed object masks, as was the case before applying the transforms. However, some operations like the crop operation might result in inconsistencies between the cropped boxes and masks, with the cropped boxes no longer tightly fitting the cropped masks. See https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py#L135 for a toy example. I believe the same might happen with other transforms, such as rotation-based ones. I'm not sure how much these inconsistencies affect performance (I expect not much), but I think this issue needs to be addressed. If not, some users might think there's some kind of bug, as they would expect the transformed bounding boxes to always tightly fit the transformed masks.
Possible solutions
Regardless of the precise solution to this problem, I think a
Solution 1: Change nothing in the API. It's up to the users to decide how they want to transform the boxes. If the user wants the transformed boxes to always tightly fit the transformed masks, then the user can proceed as follows:
Solution 2: Change the API. I'm not sure what the cleanest solution would be. Maybe something like
could be considered where internally something like
would be used. |
Hey @CedricPicron and thanks for your feedback. I agree that we need to touch on this in the documentation to explain what is happening.
There is already an operator for it: masks_to_boxes from torchvision.ops (used below). As for the proposed solutions, with the op from above you can easily add a custom transform that does what you want:
from typing import *
from torchvision.prototype import transforms, features
from torchvision.prototype.transforms import functional as F
from torchvision.ops import masks_to_boxes
# We are currently debating whether we should make this public
from torchvision.prototype.transforms._utils import has_all
# This is modeled after `query_bounding_box` located in `torchvision.prototype.transforms._utils`
def query_mask(flat_inputs: List[Any]) -> features.Mask:
masks = [inpt for inpt in flat_inputs if isinstance(inpt, features.Mask)]
if not masks:
raise TypeError("No mask was found in the sample")
elif len(masks) > 1:
raise ValueError("Found multiple masks in the sample")
return masks.pop()
class TightenBoundingBoxes(transforms.Transform):
_transformed_types = (features.BoundingBox,)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
# Of course, we could also make this a no-op in case we don't find both
if not has_all(flat_inputs, features.Mask, features.BoundingBox):
raise TypeError("TightenBoundingBoxes needs masks and bounding boxes")
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
mask = query_mask(flat_inputs)
tight_bounding_box = masks_to_boxes(mask)
return dict(tight_bounding_box=tight_bounding_box)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return features.BoundingBox.wrap_like(
inpt,
F.convert_format_bounding_box(
params["tight_bounding_box"].to(inpt),
old_format=features.BoundingBoxFormat.XYXY,
new_format=inpt.format,
inplace=True,
),
)
Stealing the toy example from above:
import torch
mask1 = features.Mask(torch.triu(torch.ones(4, 4), diagonal=1).unsqueeze(0))
box1 = features.BoundingBox(
[[1, 0, 3, 2]], format=features.BoundingBoxFormat.XYXY, spatial_size=mask1.shape[-2:]
)
print(mask1)
print(box1)
# Emulating a fixed crop transform
mask2 = F.crop(mask1, top=0, left=0, height=4, width=3)
box2 = F.crop(box1, top=0, left=0, height=4, width=3)
print(mask2)
print(box2)
transform = TightenBoundingBoxes()
mask3, box3 = transform(mask2, box2)
print(mask3)
print(box3)
Now you can drop this in wherever you need it in your pipeline. Given that this has a performance implication and we currently don't have evidence that it actually impacts the performance of the model, IMO it is best to keep this behavior manual. Plus, it is not as easy as the check you proposed in solution 2 above, since these transforms need to work regardless of the task, but object detection may have no masks available. If you have another look at the output from the example above, there is another case where some manual action is required by the user if we didn't have
We already have two builtin transforms that function similarly to the one I proposed above:
|
Hi @pmeier. Thanks a lot for your detailed response! I like the proposed solution based on the TightenBoundingBoxes transform.
Some additional comments:
|
I agree the naming of the modules is not perfect here. On the flip side, I'd wager that regardless of what scheme you choose, you will always have outliers that don't fit anywhere. Plus, you are looking at internal / private namespaces here. I'm fully aware that this is the only thing you can do right now due to the otherwise non-existent documentation, but this will change before release. Meaning, you will be able to discover everything there instead of going through the source. Since this bounding box behavior might indeed be unexpected, I think it would be good to add a small gallery to show the effects. If you do want to look at the source, I suggest having a look at the
I agree with the statement, but I think there is a far easier solution: don't pass the bounding box. Unless you have a transformation in your pipeline that requires a bounding box to be present, e.g.
from torch.utils._pytree import tree_flatten, tree_unflatten
class NoBoundingBoxesContainer(transforms.Compose):
def forward(self, *inputs):
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
indexed_mask = None
indexed_bounding_box = None
everything_else = []
for idx, inpt in enumerate(flat_inputs):
if isinstance(inpt, features.BoundingBox):
indexed_bounding_box = (idx, inpt)
elif isinstance(inpt, features.Mask):
indexed_mask = (idx, inpt)
else:
everything_else.append((idx, inpt))
# do a proper error checking here if a mask and box is available
transformed_indexed_mask, transformed_everything_else = super().forward(
indexed_mask, everything_else
)
transformed_indexed_bounding_box = (
indexed_bounding_box[0],
features.BoundingBox.wrap_like(
indexed_bounding_box[1],
F.convert_format_bounding_box(
masks_to_boxes(transformed_indexed_mask[1]).to(indexed_bounding_box[1]),
old_format=features.BoundingBoxFormat.XYXY,
new_format=indexed_bounding_box[1].format,
inplace=True,
),
),
)
flat_outputs = list(
zip(
*sorted(
[
transformed_indexed_bounding_box,
transformed_indexed_mask,
*transformed_everything_else,
],
key=lambda indexed_output: indexed_output[0],
)
)
)[1]
return tree_unflatten(flat_outputs, spec)
Although this looks quite daunting, it actually doesn't do anything complicated. Basically, we fish out the mask and bounding box from the input, transform the mask as well as everything else, create a new bounding box from the transformed mask, and assemble everything back together. With this you can do
pipeline = transforms.Compose(
[
NoBoundingBoxesContainer(
[
transforms.RandomRotation(...),
transforms.RandomCrop(...),
]
),
transforms.RandomIoUCrop(...),
NoBoundingBoxesContainer(...),
]
)
Still, IMO we are going deep into specialized transforms here. Unless there is significant demand for something like this in the library, I think you are better off defining such a transform yourself. |
Yes. I guess the key for users will be to have good documentation and examples regarding the implementation of custom Transforms. Thanks @pmeier for the quick responses, and good luck finalizing this project. I hope the feedback was (somewhat) useful. |
@eirikeve With the upcoming release the v2 wrapper is now also pickleable and thus will work with a multiprocessing spawning context as is the default on macOS. The patch should be available as nightly release in a few hours. |
I am trying to use the CutMix augmentation following the guide on the web page: https://pytorch.org/vision/main/auto_examples/v2_transforms/plot_cutmix_mixup.html#sphx-glr-auto-examples-v2-transforms-plot-cutmix-mixup-py |
@fvgt you may need to install torchvision from source: https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md#development-installation |
@fvgt you are looking at the documentation for the |
That was my first intuition as well, and I tried using the nightly version with the following command:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
Unfortunately, I still got the same error. Edit: I will follow the guide posted by @vfdev-5 and check again. Thank you very much for the quick replies! |
@fvgt This is most likely an environment issue. Of course it should work as well, but there is no need to build from source. The nightly build is sufficient. Please open a separate issue and post the results of the environment collection script. |
CutMix only supports classification tasks for now; I hope v2 CutMix will support segmentation tasks as well! Thanks. |
If I have a batch of videos and I want to run
self.train_video_pipeline = torch.nn.Sequential(
v2.RandomPerspective(0.5,1),
torchvision.transforms.Normalize(0.15228, 0.0794))
All videos will have the same transformation.
train_video_pipeline = torch.nn.Sequential(
v2.RandomPerspective(0.2,0.6),
torchvision.transforms.Normalize(0.15228, 0.0794))
out = train_video_pipeline(1+torch.zeros((10,123,1,100,100))) # e.g. 10 videos each with 123 time-steps, one channel
import matplotlib.pyplot as plt
axs = plt.subplots(1,3,figsize=(13,4))[1]
axs[0].imshow(out[0,7,0].detach().numpy())
axs[1].imshow(out[1,7,0].detach().numpy())
axs[2].imshow(out[2,7,0].detach().numpy())
Is there any way to have a different transformation for each video? |
Hi @orena1, the main way to do that is to unbatch, call the random transforms individually on all samples (or use .get_params + the functional API), and then re-batch the samples. This is something we'd like to support more transparently, perhaps at least by providing some kind of |
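A minimal sketch of that unbatch / re-batch pattern, assuming the torchvision.transforms.v2 namespace used in the question above:
import torch
from torchvision.transforms import v2

transform = v2.RandomPerspective(distortion_scale=0.2, p=1.0)

videos = 1 + torch.zeros((10, 123, 1, 100, 100))  # 10 videos, 123 frames, 1 channel
# One call per video: each call samples its own perspective parameters, while all
# frames within a single video still share the same warp.
out = torch.stack([transform(video) for video in videos])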
Howdy! The TVTensors + V2 transforms are a pretty cool addition. I'm finding it easy to integrate into one of my current projects, which is great. I found and am using
|
No.
Yup. Right now we hard-assume that bounding boxes are in absolute coordinates. This makes it easier to implement the corresponding kernels:
From your comment I get that normalized bounding boxes are only required for the model. If that is true, I suggest you implement a custom transform:
import torch
from torchvision import tv_tensors
def normalize_bounding_boxes(bounding_boxes: tv_tensors.BoundingBoxes, dtype=torch.float32) -> torch.Tensor:
canvas_height, canvas_width = bounding_boxes.canvas_size
# The .as_subclass(torch.Tensor) is not required, but only a performance improvement
# See https://pytorch.org/vision/stable/auto_examples/transforms/plot_tv_tensors.html#why-is-this-happening
return (
bounding_boxes.as_subclass(torch.Tensor)
.to(dtype)
.div_(
torch.tensor(
[canvas_width, canvas_height, canvas_width, canvas_height],
dtype=dtype,
device=bounding_boxes.device,
)
)
)
class NormalizeBoundingBoxes(torch.nn.Module):
def forward(self, image, target):
target["boxes"] = normalize_bounding_boxes(target["boxes"])
return image, target
image = tv_tensors.Image(torch.rand(3, 100, 100))
bounding_boxes = tv_tensors.BoundingBoxes(
[[50, 50, 25, 25]],
format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=(100, 100),
)
target = {"boxes": bounding_boxes}
transform = NormalizeBoundingBoxes()
transformed_sample = transform(image, target)
torch.testing.assert_close(
transformed_sample[1]["boxes"],
torch.tensor([[0.5, 0.5, 0.25, 0.25]]),
)
This requires you to hardcode the schema of the samples that you want to pass. If you need a version of the transform that works for arbitrary sample schemas, as is the default for all builtin v2 transforms, you can do:
from torchvision.transforms import v2 as transforms
class NormalizeBoundingBoxes(transforms.Transform):
_transformed_types = (tv_tensors.BoundingBoxes,)
def _transform(self, input, params):
return normalize_bounding_boxes(input)
But be aware that we are using private parts of the API here and there are no BC guarantees for them. |
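Hypothetical usage of this second variant, reusing the image and target defined in the first snippet and chaining it with a builtin v2 transform:
pipeline = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=1.0),
        NormalizeBoundingBoxes(),  # run last, so the flip still sees absolute coordinates
    ]
)
flipped_image, flipped_target = pipeline(image, target)
print(flipped_target["boxes"])  # boxes are now normalized to [0, 1]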
@pmeier sounds good! I appreciate the quick and thorough reply. I'll give this a go in my project. |
Thanks for making this new API for transformations, it's great! I was sent here from a link on the ToDtype page, as I'm trying to figure out the intent and consequences of the scale parameter. My understanding was that (for instance for a float32 input) it would rescale the values. I dug around in the implementation a bit, and while there is some checking to see if the data types support scaling, I'm not seeing any actual computational consequences of the parameter:
vision/torchvision/transforms/v2/_misc.py Line 206 in d234307
But there is a good chance I'm just missing something too 😄. At any rate, I can implement it myself easily enough, but I was confused by |
@EricThomson ToDtype calls the functional implementation:
You can see there a quick return when scale is False:
vision/torchvision/transforms/v2/functional/_misc.py Lines 214 to 215 in d234307
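To illustrate what the flag changes on an integer input (a quick sketch against the v2 API):
import torch
from torchvision.transforms import v2

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
print(v2.ToDtype(torch.float32, scale=True)(img).max())   # value range rescaled to [0, 1]
print(v2.ToDtype(torch.float32, scale=False)(img).max())  # plain cast, values stay in [0, 255]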
|
@EricThomson if you pass a torch.float32 tensor to
To convert a float tensor from an arbitrary scale to another, you could use |
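The concrete suggestion above is cut off; as an assumption, one simple manual option is a min/max rescale into [0, 1] (note that an input that is already float32 passes through ToDtype(torch.float32, scale=True) unchanged):
import torch

def rescale_to_unit_interval(x: torch.Tensor) -> torch.Tensor:
    # Map x from [x.min(), x.max()] to [0, 1]; the clamp guards against a constant input.
    lo, hi = x.min(), x.max()
    return (x - lo) / (hi - lo).clamp(min=torch.finfo(x.dtype).eps)

img = torch.randn(3, 32, 32) * 7 + 3  # float tensor with an arbitrary value range
print(rescale_to_unit_interval(img).min(), rescale_to_unit_interval(img).max())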
Thanks @vfdev-5 for pointing out in more detail how kernel dispatching works (I'm embarrassed I didn't go deeply enough 😳 ). The logic becomes clear in @NicolasHug thanks for explaining in more detail and the suggestion. I'm not sure if Clearly I was trying to get |
Normalize just returns (input - mean) / std.
To clarify the feature request: you mean converting from an arbitrary scale into [0, 1], where the arbitrary scale of the input |
@NicolasHug nice! I can piggyback on that. In terms of the feature request, yes, that is what I was suggesting. |
Thank you so much everyone for your input and feedback. The V2 transforms are now stable and part of the latest torchvision release https://github.com/pytorch/vision/releases/tag/v0.17.0. I'll close this issue as it's getting quite big and somewhat outdated now, but we'd still love to hear from you! Please feel free to open new issues with any feedback or feature requests you may have! |
🚀 The feature
This issue is dedicated for collecting community feedback on the Transforms V2 API. Please review the dedicated blogpost where we describe the API in detail and provide an overview of its features.
We would love to get your thoughts, comments and input in order to improve the API and graduate it from prototype in the near future.
Please also check out #7319 where we collect feedback on some specific design decision, and document as well which APIs may change in the future!
Code example using this image: