From 6e1d852747b6e65849112f003e6acb4b6ce180a8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 6 Dec 2020 18:17:46 +0000 Subject: [PATCH 1/7] Separate the tests of Adjust Sharpness from ColorJitter. --- test/test_transforms.py | 10 +++- test/test_transforms_tensor.py | 14 ++--- torchvision/transforms/transforms.py | 77 +++++++++++++++++++--------- 3 files changed, 70 insertions(+), 31 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fbfab98be45..f2ca0dc9d1e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1316,7 +1316,7 @@ def test_adjusts_L_mode(self): self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') def test_color_jitter(self): - color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2) + color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -1840,6 +1840,14 @@ def test_random_solarize(self): [{"threshold": 192}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_adjust_sharpness(self): + self._test_randomness( + F.adjust_sharpness, + transforms.RandomAdjustSharpness, + [{"sharpness_factor": 2.0}] + ) + @unittest.skipIf(stats is None, 'scipy.stats not available') def test_random_autocontrast(self): self._test_randomness( diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7af1f1d4c46..ea3f818ad45 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -104,6 +104,12 @@ def test_random_solarize(self): 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_adjust_sharpness(self): + fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0} + self._test_op( + 'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + def test_random_autocontrast(self): self._test_op('autocontrast', 'RandomAutocontrast') @@ -138,14 +144,8 @@ def test_color_jitter(self): 
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max" ) - for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]: - meth_kwargs = {"sharpness": f} - self._test_class_op( - "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" - ) - # All 4 parameters together - meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2} + meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} self._test_class_op( "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max" ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f4416b36acd..c6322ef71d4 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAutocontrast", "RandomEqualize"] + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] class Compose: @@ -1039,7 +1039,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast, saturation, hue and sharpness of an image. + """Randomly change the brightness, contrast, saturation and hue of an image. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. @@ -1054,19 +1054,15 @@ class ColorJitter(torch.nn.Module): hue (float or tuple of float (min, max)): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
- sharpness (float or tuple of float (min, max)): How much to jitter sharpness. - sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness] - or the given [min, max]. Should be non negative numbers. """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - self.sharpness = self._check_input(sharpness, 'sharpness') @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): @@ -1082,7 +1078,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs else: raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) - # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness + # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: value = None @@ -1092,10 +1088,8 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs def get_params(brightness: Optional[List[float]], contrast: Optional[List[float]], saturation: Optional[List[float]], - hue: Optional[List[float]], - sharpness: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float], - Optional[float]]: + hue: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1107,22 +1101,19 @@ def get_params(brightness: Optional[List[float]], uniformly. Pass None to turn off the transformation. 
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation. - sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen - uniformly. Pass None to turn off the transformation. Returns: tuple: The parameters used to apply the randomized transform along with their random order. """ - fn_idx = torch.randperm(5) + fn_idx = torch.randperm(4) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1])) - return fn_idx, b, c, s, h, sp + return fn_idx, b, c, s, h def forward(self, img): """ @@ -1132,8 +1123,8 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. 
""" - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1144,8 +1135,6 @@ def forward(self, img): img = F.adjust_saturation(img, saturation_factor) elif fn_id == 3 and hue_factor is not None: img = F.adjust_hue(img, hue_factor) - elif fn_id == 4 and sharpness_factor is not None: - img = F.adjust_sharpness(img, sharpness_factor) return img @@ -1154,8 +1143,7 @@ def __repr__(self): format_string += 'brightness={0}'.format(self.brightness) format_string += ', contrast={0}'.format(self.contrast) format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0}'.format(self.hue) - format_string += ', sharpness={0})'.format(self.sharpness) + format_string += ', hue={0})'.format(self.hue) return format_string @@ -1838,6 +1826,49 @@ def __repr__(self): return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) +class RandomAdjustSharpness(torch.nn.Module): + """Adjust the sharpness of the image randomly with a given probability. The image + can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] + shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being sharpened. 
Default value is 0.5 + """ + + def __init__(self, sharpness_factor, p=0.5): + super().__init__() + self.sharpness_factor = sharpness_factor + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be sharpened. + + Returns: + PIL Image or Tensor: Randomly sharpened image. + """ + if self.get_params() < self.p: + return F.adjust_sharpness(img, self.sharpness_factor) + return img + + def __repr__(self): + return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + + class RandomAutocontrast(torch.nn.Module): """Autocontrast the pixels of the given image randomly with a given probability. The image can be a PIL Image or a torch Tensor, in which case it is expected From 326b6dcf33a07a4e49ac315e5868d608887d7867 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 12:36:26 +0000 Subject: [PATCH 2/7] Initial implementation, not-jitable. 
--- test/test_transforms_tensor.py | 14 ++ torchvision/transforms/autoaugment.py | 228 ++++++++++++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 torchvision/transforms/autoaugment.py diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index ea3f818ad45..1849c34f5df 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -3,6 +3,7 @@ from torchvision import transforms as T from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode +from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy import numpy as np @@ -626,6 +627,19 @@ def test_convert_image_dtype(self): with get_tmp_dir() as tmp_dir: scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) + def test_autoaugment(self): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) + + for policy in AutoAugmentPolicy: + for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]: + for _ in range(100): + transform = AutoAugment(policy=policy, fill=fill) + s_transform = torch.jit.script(transform) + + self._test_transform_vs_scripted(transform, s_transform, tensor) + self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py new file mode 100644 index 00000000000..606f15afb2b --- /dev/null +++ b/torchvision/transforms/autoaugment.py @@ -0,0 +1,228 @@ +import math +import torch + +from enum import Enum +from torch import Tensor +from torch.jit.annotations import List, Tuple +from typing import Optional + +from . 
import functional as F, InterpolationMode + + +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + """ + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" + + +def _shearX(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: + v = math.degrees(magnitude) + return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[v, 0.0], + interpolation=InterpolationMode.BICUBIC, fill=fill) + + +def _shearY(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: + v = math.degrees(magnitude) + return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, v], + interpolation=InterpolationMode.BICUBIC, fill=fill) + + +def _translateX(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: + v = int(F._get_image_size(img)[0] * magnitude) + return F.affine(img, angle=0.0, translate=[v, 0], scale=1.0, shear=[0.0, 0.0], + interpolation=InterpolationMode.BICUBIC, fill=fill) + + +def _translateY(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: + v = int(F._get_image_size(img)[1] * magnitude) + return F.affine(img, angle=0.0, translate=[0, v], scale=1.0, shear=[0.0, 0.0], + interpolation=InterpolationMode.BICUBIC, fill=fill) + + +def _rotate(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: + return F.rotate(img, magnitude, interpolation=InterpolationMode.BICUBIC, fill=fill) + + +def _brightness(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: + v = 1.0 + magnitude + return F.adjust_brightness(img, v) + + +def _color(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: + v = 1.0 + magnitude + return F.adjust_saturation(img, v) + + +def _contrast(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: + v = 1.0 + magnitude + return F.adjust_contrast(img, v) + + +def _sharpness(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: + v = 1.0 + magnitude + return 
F.adjust_sharpness(img, v) + + +def _posterize(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: + v = int(magnitude) + return F.posterize(img, v) + + +def _solarize(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: + return F.solarize(img, magnitude) + + +def _autocontrast(img: Tensor, _: float, __: Optional[List[float]]) -> Tensor: + return F.autocontrast(img) + + +def _equalize(img: Tensor, _: float, __: Optional[List[float]]) -> Tensor: + return F.equalize(img) + + +def _invert(img: Tensor, _: float, __: Optional[List[float]]) -> Tensor: + return F.invert(img) + + +_BINS = 10 + +_OPERATIONS = { + # name: (method, magnitudes, signed) + "ShearX": (_shearX, torch.linspace(0.0, 0.3, _BINS), True), + "ShearY": (_shearY, torch.linspace(0.0, 0.3, _BINS), True), + "TranslateX": (_translateX, torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "TranslateY": (_translateY, torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "Rotate": (_rotate, torch.linspace(0.0, 30.0, _BINS), True), + "Brightness": (_brightness, torch.linspace(0.0, 0.9, _BINS), True), + "Color": (_color, torch.linspace(0.0, 0.9, _BINS), True), + "Contrast": (_contrast, torch.linspace(0.0, 0.9, _BINS), True), + "Sharpness": (_sharpness, torch.linspace(0.0, 0.9, _BINS), True), + "Posterize": (_posterize, torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), + "Solarize": (_solarize, torch.linspace(256.0, 0.0, _BINS), False), + "AutoContrast": (_autocontrast, None, None), + "Equalize": (_equalize, None, None), + "Invert": (_invert, None, None), +} + +_POLICIES = { + AutoAugmentPolicy.IMAGENET: [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), 
("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ], + AutoAugmentPolicy.CIFAR10: [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), 
("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ], + AutoAugmentPolicy.SVHN: [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ], +} + + +class AutoAugment(torch.nn.Module): + r"""AutoAugment method, based on + `"AutoAugment: Learning Augmentation Strategies from Data" `_. 
+ """ + def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None): + super().__init__() + self.policy = policy + self.fill = fill + if policy not in _POLICIES: + raise ValueError("The provided policy {} is not recognized.".format(policy)) + self.policies = _POLICIES[policy] + + @staticmethod + def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]: + policy_id = torch.randint(policy_num, (1,)).item() + probs = torch.rand((2,)) + signs = torch.randint(2, (2,)) + + return policy_id, probs, signs + + def forward(self, img): + policy_id, probs, signs = self.get_params(len(self.policies)) + + for i, (name, p, magnitude_id) in enumerate(self.policy[policy_id]): + if probs[i] <= p: + method, magnitudes, signed = _OPERATIONS[name] + magnitude = magnitudes[magnitude_id] if magnitudes is not None else None + if signed and signs[i] == 0: + magnitude *= -1 + img = method(img, magnitude, self.fill) + + return img + + def __repr__(self): + return self.__class__.__name__ + '(policy={},fill={})'.format(self.policy, self.fill) From 23b492d6ae1326f2fa442da30e5fa6c5c8006e74 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 14:03:00 +0000 Subject: [PATCH 3/7] AutoAugment passing JIT. 
--- torchvision/transforms/autoaugment.py | 339 ++++++++++++-------------- 1 file changed, 156 insertions(+), 183 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 606f15afb2b..96d501181ab 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -17,191 +17,120 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" -def _shearX(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: - v = math.degrees(magnitude) - return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[v, 0.0], - interpolation=InterpolationMode.BICUBIC, fill=fill) - - -def _shearY(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: - v = math.degrees(magnitude) - return F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, v], - interpolation=InterpolationMode.BICUBIC, fill=fill) - - -def _translateX(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: - v = int(F._get_image_size(img)[0] * magnitude) - return F.affine(img, angle=0.0, translate=[v, 0], scale=1.0, shear=[0.0, 0.0], - interpolation=InterpolationMode.BICUBIC, fill=fill) - - -def _translateY(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: - v = int(F._get_image_size(img)[1] * magnitude) - return F.affine(img, angle=0.0, translate=[0, v], scale=1.0, shear=[0.0, 0.0], - interpolation=InterpolationMode.BICUBIC, fill=fill) - - -def _rotate(img: Tensor, magnitude: float, fill: Optional[List[float]]) -> Tensor: - return F.rotate(img, magnitude, interpolation=InterpolationMode.BICUBIC, fill=fill) - - -def _brightness(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: - v = 1.0 + magnitude - return F.adjust_brightness(img, v) - - -def _color(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: - v = 1.0 + magnitude - return F.adjust_saturation(img, v) - - -def _contrast(img: Tensor, magnitude: float, _: 
Optional[List[float]]) -> Tensor: - v = 1.0 + magnitude - return F.adjust_contrast(img, v) - - -def _sharpness(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: - v = 1.0 + magnitude - return F.adjust_sharpness(img, v) - - -def _posterize(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: - v = int(magnitude) - return F.posterize(img, v) - - -def _solarize(img: Tensor, magnitude: float, _: Optional[List[float]]) -> Tensor: - return F.solarize(img, magnitude) - - -def _autocontrast(img: Tensor, _: float, __: Optional[List[float]]) -> Tensor: - return F.autocontrast(img) - - -def _equalize(img: Tensor, _: float, __: Optional[List[float]]) -> Tensor: - return F.equalize(img) - - -def _invert(img: Tensor, _: float, __: Optional[List[float]]) -> Tensor: - return F.invert(img) - - -_BINS = 10 - -_OPERATIONS = { - # name: (method, magnitudes, signed) - "ShearX": (_shearX, torch.linspace(0.0, 0.3, _BINS), True), - "ShearY": (_shearY, torch.linspace(0.0, 0.3, _BINS), True), - "TranslateX": (_translateX, torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "TranslateY": (_translateY, torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "Rotate": (_rotate, torch.linspace(0.0, 30.0, _BINS), True), - "Brightness": (_brightness, torch.linspace(0.0, 0.9, _BINS), True), - "Color": (_color, torch.linspace(0.0, 0.9, _BINS), True), - "Contrast": (_contrast, torch.linspace(0.0, 0.9, _BINS), True), - "Sharpness": (_sharpness, torch.linspace(0.0, 0.9, _BINS), True), - "Posterize": (_posterize, torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), - "Solarize": (_solarize, torch.linspace(256.0, 0.0, _BINS), False), - "AutoContrast": (_autocontrast, None, None), - "Equalize": (_equalize, None, None), - "Invert": (_invert, None, None), -} - -_POLICIES = { - AutoAugmentPolicy.IMAGENET: [ - (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - (("Posterize", 
0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), - (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), - (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), - (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), - (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - ], - AutoAugmentPolicy.CIFAR10: [ - (("Invert", 0.1, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), - (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), - (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), - (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), 
("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - ], - AutoAugmentPolicy.SVHN: [ - (("ShearX", 0.9, 4), ("Invert", 0.2, None)), - (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), - (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), - (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), - (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), - (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), - (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), - (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None)), - ], -} - - class AutoAugment(torch.nn.Module): r"""AutoAugment method, based on `"AutoAugment: Learning Augmentation Strategies from Data" `_. 
""" + def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None): super().__init__() self.policy = policy self.fill = fill - if policy not in _POLICIES: + if policy == AutoAugmentPolicy.IMAGENET: + self.policies = [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + self.policies = [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, 
None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + self.policies = [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 
0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: raise ValueError("The provided policy {} is not recognized.".format(policy)) - self.policies = _POLICIES[policy] + + _BINS = 10 + self._op_meta = { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), + "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), + "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), + "Color": (torch.linspace(0.0, 0.9, _BINS), True), + "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), + "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), + "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), + "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), + "AutoContrast": (None, None), + "Equalize": (None, None), + "Invert": (None, None), + } @staticmethod def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]: @@ -211,16 +140,60 @@ def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]: return policy_id, probs, signs - def forward(self, img): + def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: + return self._op_meta[name] + + def forward(self, img: Tensor): + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F._get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + policy_id, probs, signs = self.get_params(len(self.policies)) - for i, (name, p, magnitude_id) in enumerate(self.policy[policy_id]): + for i, (op_name, p, magnitude_id) in enumerate(self.policies[policy_id]): if probs[i] <= p: - method, magnitudes, signed = _OPERATIONS[name] - magnitude = magnitudes[magnitude_id] if magnitudes is not None else 
None - if signed and signs[i] == 0: - magnitude *= -1 - img = method(img, magnitude, self.fill) + magnitudes, signed = self._get_op_meta(op_name) + magnitude = float(magnitudes[magnitude_id].item()) if magnitudes is not None and magnitude_id is not None else 0.0 + if signed is not None and signed and signs[i] == 0: + magnitude *= -1.0 + + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + interpolation=InterpolationMode.BICUBIC, fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + interpolation=InterpolationMode.BICUBIC, fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, + shear=[0.0, 0.0], interpolation=InterpolationMode.BICUBIC, fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, + shear=[0.0, 0.0], interpolation=InterpolationMode.BICUBIC, fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=InterpolationMode.BICUBIC, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) return img From f7fa15e5193ebfd42a89e02586788272f74d6a13 Mon 
Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 15:22:12 +0000 Subject: [PATCH 4/7] Adding tests/docs, changing formatting. --- test/test_transforms.py | 11 ++++++ test/test_transforms_tensor.py | 2 ++ torchvision/transforms/autoaugment.py | 52 +++++++++++++++++++++++---- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index f2ca0dc9d1e..30db91d5aae 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,6 +4,7 @@ import torchvision.transforms.functional as F import torchvision.transforms.functional_tensor as F_t from torch._utils_internal import get_file_path_2 +from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy from numpy.testing import assert_array_almost_equal import unittest import math @@ -1864,6 +1865,16 @@ def test_random_equalize(self): [{}] ) + def test_autoaugment(self): + for policy in AutoAugmentPolicy: + for fill in [None, 85, (128, 128, 128)]: + random.seed(42) + img = transforms.ToPILImage()(torch.rand(3, 44, 56)) + transform = AutoAugment(policy=policy, fill=fill) + for _ in range(1000): + img = transform(img) + transform.__repr__() + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 1849c34f5df..f0680be7c07 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -640,6 +640,8 @@ def test_autoaugment(self): self._test_transform_vs_scripted(transform, s_transform, tensor) self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + with get_tmp_dir() as tmp_dir: + s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 96d501181ab..74ce11a8917 100644 --- a/torchvision/transforms/autoaugment.py 
+++ b/torchvision/transforms/autoaugment.py @@ -18,8 +18,34 @@ class AutoAugmentPolicy(Enum): class AutoAugment(torch.nn.Module): - r"""AutoAugment method, based on + r"""AutoAugment data augmentation method based on `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + policy (AutoAugmentPolicy): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or float or int, optional): Range of degrees to select from. + If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) + will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the + range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, + a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. + Will not apply shear by default. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + For backward compatibility integer values (e.g.
``PIL.Image.NEAREST``) are still acceptable. + fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed + image. If int or float, the value is used for all bands respectively. + This option is supported for PIL image and Tensor inputs. + If input is PIL Image, the options is only available for ``Pillow>=5.0.0``. """ def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None): @@ -134,6 +160,11 @@ def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: @staticmethod def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]: + """Get parameters for autoaugment transformation + + Returns: + params required by the autoaugment transformation + """ policy_id = torch.randint(policy_num, (1,)).item() probs = torch.rand((2,)) signs = torch.randint(2, (2,)) @@ -144,6 +175,12 @@ def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: return self._op_meta[name] def forward(self, img: Tensor): + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: AutoAugmented image. 
+ """ fill = self.fill if isinstance(img, Tensor): if isinstance(fill, (int, float)): @@ -156,24 +193,25 @@ def forward(self, img: Tensor): for i, (op_name, p, magnitude_id) in enumerate(self.policies[policy_id]): if probs[i] <= p: magnitudes, signed = self._get_op_meta(op_name) - magnitude = float(magnitudes[magnitude_id].item()) if magnitudes is not None and magnitude_id is not None else 0.0 + magnitude = float(magnitudes[magnitude_id].item()) if magnitudes is not None and \ + magnitude_id is not None else 0.0 if signed is not None and signed and signs[i] == 0: magnitude *= -1.0 if op_name == "ShearX": img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=InterpolationMode.BICUBIC, fill=fill) + fill=fill) elif op_name == "ShearY": img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=InterpolationMode.BICUBIC, fill=fill) + fill=fill) elif op_name == "TranslateX": img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, - shear=[0.0, 0.0], interpolation=InterpolationMode.BICUBIC, fill=fill) + shear=[0.0, 0.0], fill=fill) elif op_name == "TranslateY": img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, - shear=[0.0, 0.0], interpolation=InterpolationMode.BICUBIC, fill=fill) + shear=[0.0, 0.0], fill=fill) elif op_name == "Rotate": - img = F.rotate(img, magnitude, interpolation=InterpolationMode.BICUBIC, fill=fill) + img = F.rotate(img, magnitude, fill=fill) elif op_name == "Brightness": img = F.adjust_brightness(img, 1.0 + magnitude) elif op_name == "Color": From 495e3e963472e694630ff12f3ec3c41f04de3f16 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 15:43:21 +0000 Subject: [PATCH 5/7] Update test. 
--- test/test_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 30db91d5aae..a25bfd7ce8c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1869,9 +1869,9 @@ def test_autoaugment(self): for policy in AutoAugmentPolicy: for fill in [None, 85, (128, 128, 128)]: random.seed(42) - img = transforms.ToPILImage()(torch.rand(3, 44, 56)) + img = Image.open(GRACE_HOPPER) transform = AutoAugment(policy=policy, fill=fill) - for _ in range(1000): + for _ in range(100): img = transform(img) transform.__repr__() From bbe6590c78d0f0d6ad0d841a144841ae63513717 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 15:57:49 +0000 Subject: [PATCH 6/7] Fix formats --- test/test_transforms_tensor.py | 1 + torchvision/transforms/autoaugment.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 1f22dd83158..4aabdf9e2c7 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -643,6 +643,7 @@ def test_autoaugment(self): with get_tmp_dir() as tmp_dir: s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 74ce11a8917..182391ecd29 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -193,8 +193,8 @@ def forward(self, img: Tensor): for i, (op_name, p, magnitude_id) in enumerate(self.policies[policy_id]): if probs[i] <= p: magnitudes, signed = self._get_op_meta(op_name) - magnitude = float(magnitudes[magnitude_id].item()) if magnitudes is not None and \ - magnitude_id is not None else 0.0 + magnitude = float(magnitudes[magnitude_id].item()) \ + if magnitudes is not None and magnitude_id is not None else 
0.0 if signed is not None and signed and signs[i] == 0: magnitude *= -1.0 From 5b2001b7135ed9aa3b9383039d710dd4d783d174 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 8 Dec 2020 16:27:23 +0000 Subject: [PATCH 7/7] Fix documentation and imports. --- test/test_transforms.py | 5 ++--- test/test_transforms_tensor.py | 5 ++--- torchvision/transforms/__init__.py | 1 + torchvision/transforms/autoaugment.py | 27 ++++++++++----------------- 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index a25bfd7ce8c..b3c82334d14 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,7 +4,6 @@ import torchvision.transforms.functional as F import torchvision.transforms.functional_tensor as F_t from torch._utils_internal import get_file_path_2 -from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy from numpy.testing import assert_array_almost_equal import unittest import math @@ -1866,11 +1865,11 @@ def test_random_equalize(self): ) def test_autoaugment(self): - for policy in AutoAugmentPolicy: + for policy in transforms.AutoAugmentPolicy: for fill in [None, 85, (128, 128, 128)]: random.seed(42) img = Image.open(GRACE_HOPPER) - transform = AutoAugment(policy=policy, fill=fill) + transform = transforms.AutoAugment(policy=policy, fill=fill) for _ in range(100): img = transform(img) transform.__repr__() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 4aabdf9e2c7..22a7c065122 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -3,7 +3,6 @@ from torchvision import transforms as T from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode -from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy import numpy as np @@ -631,10 +630,10 @@ def test_autoaugment(self): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, 
device=self.device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) - for policy in AutoAugmentPolicy: + for policy in T.AutoAugmentPolicy: for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]: for _ in range(100): - transform = AutoAugment(policy=policy, fill=fill) + transform = T.AutoAugment(policy=policy, fill=fill) s_transform = torch.jit.script(transform) self._test_transform_vs_scripted(transform, s_transform, tensor) diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 7986cdd6429..77680a14f0d 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1 +1,2 @@ from .transforms import * +from .autoaugment import * diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 182391ecd29..4cdf219c22c 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -6,7 +6,7 @@ from torch.jit.annotations import List, Tuple from typing import Optional -from . import functional as F, InterpolationMode +from . import functional as F class AutoAugmentPolicy(Enum): @@ -26,26 +26,19 @@ class AutoAugment(torch.nn.Module): Args: policy (AutoAugmentPolicy): Desired policy enum defined by :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. - translate (tuple, optional): tuple of maximum absolute fraction for horizontal - and vertical translations. For example translate=(a, b), then horizontal shift - is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is - randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. - scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is - randomly sampled from the range a <= scale <= b. Will keep original scale by default. 
- shear (sequence or float or int, optional): Range of degrees to select from. - If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) - will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the - range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, - a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. - Will not apply shear by default. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed image. If int or float, the value is used for all bands respectively. This option is supported for PIL image and Tensor inputs. If input is PIL Image, the options is only available for ``Pillow>=5.0.0``. + + Example: + >>> t = transforms.AutoAugment() + >>> transformed = t(image) + + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> transforms.AutoAugment(), + >>> transforms.ToTensor()]) """ def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None):