Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable one-hot-encoded labels in MixUp and CutMix #8427

Merged
merged 6 commits into from
May 28, 2024
Merged

Enable one-hot-encoded labels in MixUp and CutMix #8427

merged 6 commits into from
May 28, 2024

Conversation

mahdilamb
Copy link
Contributor

@mahdilamb mahdilamb commented May 18, 2024

  • Enable using CutMix/MixUp with pre-encoded labels

Todo:

  • update test
  • check for already encoded inputs

cc @vfdev-5

Copy link

pytorch-bot bot commented May 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8427

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures

As of commit 218fc58 with merge base 778ce48 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@NicolasHug
Copy link
Member

Thanks for the PR @mahdilamb .

Supporting labels that are already one-hot-encoded sounds OK to me, but instead of adding a new parameter, it seems that we could instead just check the shape of the labels and only call one_hot if the ndim != 2?

We would also need to add a few tests here

class TestCutMixMixUp:

@mahdilamb
Copy link
Contributor Author

Hi @NicolasHug makes sense to me. I'll get that moving

@mahdilamb
Copy link
Contributor Author

Hi @NicolasHug , that's updates as requested... it breaks Compose though!

@NicolasHug
Copy link
Member

Hi @mahdilamb - I've made a few changes to the PR locally but when I cannot push to update the PR, because you created from your main branch, so I don't have the permissions.

Would you mind closing this one and re-opening a new PR from a dev branch (i.e. do git checkout -b my_branch before committing)?

Alternatively you could also apply this diff to the current PR:

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 190b590c89..07235333af 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -2169,29 +2169,30 @@ class TestAdjustBrightness:
 
 class TestCutMixMixUp:
     class DummyDataset:
-        def __init__(self, size, num_classes, encode_labels:bool):
+        def __init__(self, size, num_classes, one_hot_labels):
             self.size = size
             self.num_classes = num_classes
-            self.encode_labels = encode_labels
+            self.one_hot_labels = one_hot_labels
             assert size < num_classes
 
         def __getitem__(self, idx):
             img = torch.rand(3, 100, 100)
-            label = torch.tensor(idx)  # This ensures all labels in a batch are unique and makes testing easier
-            if self.encode_labels:
-                label = torch.nn.functional.one_hot(label, num_classes=self.num_classes)
+            label = idx  # This ensures all labels in a batch are unique and makes testing easier
+            if self.one_hot_labels:
+                label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
             return img, label
 
         def __len__(self):
             return self.size
 
-    @pytest.mark.parametrize(["T", "encode_labels"], [[transforms.CutMix, False], [transforms.MixUp, False], [transforms.CutMix, True], [transforms.MixUp, True]])
-    def test_supported_input_structure(self, T, encode_labels: bool):
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
+    @pytest.mark.parametrize("one_hot_labels", (True, False))
+    def test_supported_input_structure(self, T, one_hot_labels):
 
         batch_size = 32
         num_classes = 100
 
-        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes,encode_labels=encode_labels)
+        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)
 
         cutmix_mixup = T(num_classes=num_classes)
 
@@ -2201,10 +2202,7 @@ class TestCutMixMixUp:
         img, target = next(iter(dl))
         input_img_size = img.shape[-3:]
         assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
-        if encode_labels:
-            assert target.shape == (batch_size, num_classes)
-        else:
-            assert target.shape == (batch_size,)
+        assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
 
         def check_output(img, target):
             assert img.shape == (batch_size, *input_img_size)
@@ -2215,10 +2213,7 @@ class TestCutMixMixUp:
 
         # After Dataloader, as unpacked input
         img, target = next(iter(dl))
-        if encode_labels:
-            assert target.shape == (batch_size, num_classes)
-        else:
-            assert target.shape == (batch_size,)
+        assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
         img, target = cutmix_mixup(img, target)
         check_output(img, target)
 
@@ -2273,7 +2268,7 @@ class TestCutMixMixUp:
         with pytest.raises(ValueError, match="Could not infer where the labels are"):
             cutmix_mixup({"img": imgs, "Nothing_else": 3})
 
-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
+        with pytest.raises(ValueError, match="labels should be index based"):
             # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
             # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
             cutmix_mixup(imgs)
@@ -2281,22 +2276,21 @@ class TestCutMixMixUp:
         with pytest.raises(ValueError, match="When using the default labels_getter"):
             cutmix_mixup(imgs, "not_a_tensor")
 
-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
-            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))
-
         with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
             cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))
 
         with pytest.raises(ValueError, match="does not match the batch size of the labels"):
             cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))
 
-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
-            # The purpose of this check is more about documenting the current
-            # behaviour of what happens on a Compose(), rather than actually
-            # asserting the expected behaviour. We may support Compose() in the
-            # future, e.g. for 2 consecutive CutMix?
-            labels = torch.randint(0, num_classes, size=(batch_size,))
-            transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
+        with pytest.raises(ValueError, match="When passing 2D labels"):
+            wrong_num_classes = num_classes + 1
+            T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))
+
+        with pytest.raises(ValueError, match="but got a tensor of shape"):
+            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))
+
+        with pytest.raises(ValueError, match="num_classes must be passed"):
+            T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))
 
 
 @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index 48daa271ea..1d01012654 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -1,7 +1,7 @@
 import math
 import numbers
 import warnings
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import PIL.Image
 import torch
@@ -142,7 +142,7 @@ class RandomErasing(_RandomApplyTransform):
 
 
 class _BaseMixUpCutMix(Transform):
-    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default", labels_encoded: bool = False) -> None:
+    def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
         super().__init__()
         self.alpha = float(alpha)
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
@@ -150,7 +150,6 @@ class _BaseMixUpCutMix(Transform):
         self.num_classes = num_classes
 
         self._labels_getter = _parse_labels_getter(labels_getter)
-        self._labels_encoded = labels_encoded
 
     def forward(self, *inputs):
         inputs = inputs if len(inputs) > 1 else inputs[0]
@@ -163,10 +162,21 @@ class _BaseMixUpCutMix(Transform):
         labels = self._labels_getter(inputs)
         if not isinstance(labels, torch.Tensor):
             raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
-        elif not 0 < labels.ndim <= 2 or (labels.ndim == 2 and labels.shape[1] != self.num_classes):
+        if labels.ndim not in (1, 2):
             raise ValueError(
-                f"labels tensor should be of shape (batch_size,) or (batch_size,num_classes) " f"but got shape {labels.shape} instead."
+                f"labels should be index based with shape (batch_size,) "
+                f"or probability based with shape (batch_size, num_classes), "
+                f"but got a tensor of shape {labels.shape} instead."
             )
+        if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
+            raise ValueError(
+                f"When passing 2D labels, "
+                f"the number of elements in last dimension must match num_classes: "
+                f"{labels.shape[-1]} != {self.num_classes}. "
+                f"You can Leave num_classes to None."
+            )
+        if labels.ndim == 1 and self.num_classes is None:
+            raise ValueError("num_classes must be passed if the labels are index-based (1D)")
 
         params = {
             "labels": labels,
@@ -225,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
 
     Args:
         alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
-        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+            Can be None only if the labels are already one-hot-encoded.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
             common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
@@ -273,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
 
     Args:
         alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
-        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+            Can be None only if the labels are already one-hot-encoded.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
             common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.

@mahdilamb
Copy link
Contributor Author

Hi @NicolasHug, that's diff applied!

Hope you have a great weekend.

Mahdi

@NicolasHug
Copy link
Member

Thank you @mahdilamb

Before I can merge, do you mind fixing this one linting issue:

  torchvision/transforms/v2/_augment.py:213: error: Argument "num_classes" to
  "one_hot" has incompatible type "Optional[int]"; expected "int"  [arg-type]
                  label = one_hot(label, num_classes=self.num_classes)
                                                     ^~~~~~~~~~~~~~~~
  Found 1 error in 1 file (checked 235 source files)

I think adding a simple # type: ignore[arg-type] comment will be enough - mypy is just not undersanding that self.num_classes can't be None at that point, so we should just silence it.

Thanks!

@mahdilamb
Copy link
Contributor Author

@NicolasHug, made the change, but if it fails will look into it properly. Also added you as a collaborator on the fork so you can mess about!

@NicolasHug NicolasHug changed the title Enable pre-encoded mixup Enable one-hot-encoded labels in MixUp and CutMix May 28, 2024
@NicolasHug NicolasHug merged commit c585a51 into pytorch:main May 28, 2024
57 of 69 checks passed
@NicolasHug
Copy link
Member

Thank you @mahdilamb !

facebook-github-bot pushed a commit that referenced this pull request Jun 7, 2024
Summary: Co-authored-by: Nicolas Hug <[email protected]>

Reviewed By: vmoens

Differential Revision: D58283866

fbshipit-source-id: 32b0b2ade02b3a81d167f64a3743c2bf62049308
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants