Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoAugment Policies and main Transform #3142

Merged
Merged
10 changes: 10 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,16 @@ def test_random_equalize(self):
[{}]
)

def test_autoaugment(self):
for policy in transforms.AutoAugmentPolicy:
for fill in [None, 85, (128, 128, 128)]:
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.AutoAugment(policy=policy, fill=fill)
for _ in range(100):
img = transform(img)
transform.__repr__()


if __name__ == '__main__':
unittest.main()
16 changes: 16 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,22 @@ def test_convert_image_dtype(self):
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))

def test_autoaugment(self):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

for policy in T.AutoAugmentPolicy:
for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
for _ in range(100):
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)

self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .transforms import *
from .autoaugment import *
232 changes: 232 additions & 0 deletions torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import math
import torch

from enum import Enum
from torch import Tensor
from torch.jit.annotations import List, Tuple
from typing import Optional

from . import functional as F


class AutoAugmentPolicy(Enum):
"""AutoAugment policies learned on different datasets.
"""
IMAGENET = "imagenet"
CIFAR10 = "cifar10"
SVHN = "svhn"


class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
image. If int or float, the value is used for all bands respectively.
This option is supported for PIL image and Tensor inputs.
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.

Example:
>>> t = transforms.AutoAugment()
>>> transformed = t(image)

>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> transforms.AutoAugment(),
>>> transforms.ToTensor()])
"""

def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None):
super().__init__()
self.policy = policy
self.fill = fill
if policy == AutoAugmentPolicy.IMAGENET:
self.policies = [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
self.policies = [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
self.policies = [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError("The provided policy {} is not recognized.".format(policy))

_BINS = 10
self._op_meta = {
# name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, _BINS), True),
"ShearY": (torch.linspace(0.0, 0.3, _BINS), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"Rotate": (torch.linspace(0.0, 30.0, _BINS), True),
"Brightness": (torch.linspace(0.0, 0.9, _BINS), True),
"Color": (torch.linspace(0.0, 0.9, _BINS), True),
"Contrast": (torch.linspace(0.0, 0.9, _BINS), True),
"Sharpness": (torch.linspace(0.0, 0.9, _BINS), True),
"Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
"Solarize": (torch.linspace(256.0, 0.0, _BINS), False),
"AutoContrast": (None, None),
"Equalize": (None, None),
"Invert": (None, None),
}

@staticmethod
def get_params(policy_num: int) -> Tuple[int, Tensor, Tensor]:
"""Get parameters for autoaugment transformation

Returns:
params required by the autoaugment transformation
"""
policy_id = torch.randint(policy_num, (1,)).item()
probs = torch.rand((2,))
signs = torch.randint(2, (2,))

return policy_id, probs, signs

def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]

def forward(self, img: Tensor):
"""
img (PIL Image or Tensor): Image to be transformed.

Returns:
PIL Image or Tensor: AutoAugmented image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

policy_id, probs, signs = self.get_params(len(self.policies))

for i, (op_name, p, magnitude_id) in enumerate(self.policies[policy_id]):
if probs[i] <= p:
magnitudes, signed = self._get_op_meta(op_name)
magnitude = float(magnitudes[magnitude_id].item()) \
if magnitudes is not None and magnitude_id is not None else 0.0
if signed is not None and signed and signs[i] == 0:
magnitude *= -1.0

if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
fill=fill)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0,
shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0,
shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))

return img

def __repr__(self):
return self.__class__.__name__ + '(policy={},fill={})'.format(self.policy, self.fill)