diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index c070c5c1d61..a651c0e9f38 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -289,13 +289,14 @@ def test_pad(self): self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs) - def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"): + def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max", + dts=(None, torch.float32, torch.float64)): script_fn = torch.jit.script(fn) torch.manual_seed(15) tensor, pil_img = self._create_data(26, 34, device=self.device) batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) - for dt in [None, torch.float32, torch.float64]: + for dt in dts: if dt is not None: tensor = F.convert_image_dtype(tensor, dt) @@ -862,6 +863,77 @@ def test_gaussian_blur(self): msg="{}, {}".format(ksize, sigma) ) + def test_invert(self): + self._test_adjust_fn( + F.invert, + F_pil.invert, + F_t.invert, + [{}], + tol=1.0, + agg_method="max" + ) + + def test_posterize(self): + self._test_adjust_fn( + F.posterize, + F_pil.posterize, + F_t.posterize, + [{"bits": bits} for bits in range(0, 8)], + tol=1.0, + agg_method="max", + dts=(None,) + ) + + def test_solarize(self): + self._test_adjust_fn( + F.solarize, + F_pil.solarize, + F_t.solarize, + [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]], + tol=1.0, + agg_method="max", + dts=(None,) + ) + self._test_adjust_fn( + F.solarize, + lambda img, threshold: F_pil.solarize(img, 255 * threshold), + F_t.solarize, + [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]], + tol=1.0, + agg_method="max", + dts=(torch.float32, torch.float64) + ) + + def test_adjust_sharpness(self): + self._test_adjust_fn( + F.adjust_sharpness, + F_pil.adjust_sharpness, + F_t.adjust_sharpness, + [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]] + ) + + def test_autocontrast(self): + self._test_adjust_fn( + F.autocontrast, + F_pil.autocontrast, + F_t.autocontrast, + [{}], + tol=1.0, + agg_method="max" + ) + + def test_equalize(self): + torch.set_deterministic(False) + self._test_adjust_fn( + F.equalize, + F_pil.equalize, + F_t.equalize, + [{}], + tol=1.0, + agg_method="max", + dts=(None,) + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 8a0762327f9..b3c82334d14 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1234,6 +1234,48 @@ def test_adjust_hue(self): y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) + def test_adjust_sharpness(self): + x_shape = [4, 4, 3] + x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, + 0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105, + 111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = F.adjust_sharpness(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = F.adjust_sharpness(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30, + 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, + 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = F.adjust_sharpness(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, + 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, + 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 3 + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_th = torch.tensor(x_np.transpose(2, 0, 1)) + y_pil = F.adjust_sharpness(x_pil, 2) + y_np = np.array(y_pil).transpose(2, 0, 1) + y_th = F.adjust_sharpness(x_th, 2) + self.assertTrue(np.allclose(y_np, y_th.numpy())) + def test_adjust_gamma(self): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -1270,6 +1312,7 @@ def test_adjusts_L_mode(self): self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L') self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L') self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L') + self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L') self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') def test_color_jitter(self): @@ -1751,6 +1794,86 @@ def test_gaussian_blur_asserts(self): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): transforms.GaussianBlur(3, "sigma_string") + def _test_randomness(self, fn, trans, configs): + random_state = random.getstate() + random.seed(42) + img = transforms.ToPILImage()(torch.rand(3, 16, 18)) + + for p in [0.5, 0.7]: + for config in configs: + inv_img = fn(img, **config) + + num_samples = 250 + counts = 0 + for _ in range(num_samples): + tranformation = trans(p=p, **config) + tranformation.__repr__() + out = tranformation(img) + if out == inv_img: + counts += 1 + + p_value = stats.binom_test(counts, num_samples, p=p) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_invert(self): + self._test_randomness( + F.invert, + transforms.RandomInvert, + [{}] + ) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_posterize(self): + self._test_randomness( + F.posterize, + transforms.RandomPosterize, + [{"bits": 4}] + ) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_solarize(self): + self._test_randomness( + F.solarize, + transforms.RandomSolarize, + [{"threshold": 192}] + ) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_adjust_sharpness(self): + self._test_randomness( + F.adjust_sharpness, + transforms.RandomAdjustSharpness, + [{"sharpness_factor": 2.0}] + ) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_autocontrast(self): + self._test_randomness( + F.autocontrast, + transforms.RandomAutocontrast, + [{}] + ) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_equalize(self): + self._test_randomness( + F.equalize, + transforms.RandomEqualize, + [{}] + ) + + def test_autoaugment(self): + for policy in transforms.AutoAugmentPolicy: + for fill in [None, 85, (128, 128, 128)]: + random.seed(42) + img = Image.open(GRACE_HOPPER) + transform = transforms.AutoAugment(policy=policy, fill=fill) + for _ in range(100): + img = transform(img) + transform.__repr__() + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 562057deabd..22a7c065122 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -89,6 +89,34 @@ def test_random_horizontal_flip(self): def test_random_vertical_flip(self): self._test_op('vflip', 'RandomVerticalFlip') + def test_random_invert(self): + self._test_op('invert', 'RandomInvert') + + def test_random_posterize(self): + fn_kwargs = meth_kwargs = {"bits": 4} + self._test_op( + 'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + + def test_random_solarize(self): + fn_kwargs = meth_kwargs = {"threshold": 192.0} + self._test_op( + 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + + def test_random_adjust_sharpness(self): + fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0} + self._test_op( + 'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + + def test_random_autocontrast(self): + self._test_op('autocontrast', 'RandomAutocontrast') + + def test_random_equalize(self): + torch.set_deterministic(False) + self._test_op('equalize', 'RandomEqualize') + def test_color_jitter(self): tol = 1.0 + 1e-10 @@ -598,6 +626,22 @@ def test_convert_image_dtype(self): with get_tmp_dir() as tmp_dir: scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) + def test_autoaugment(self): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) + + for policy in T.AutoAugmentPolicy: + for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]: + for _ in range(100): + transform = T.AutoAugment(policy=policy, fill=fill) + s_transform = torch.jit.script(transform) + + self._test_transform_vs_scripted(transform, s_transform, tensor) + self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + with get_tmp_dir() as tmp_dir: + s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt")) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 7986cdd6429..77680a14f0d 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1 +1,2 @@ from .transforms import * +from .autoaugment import * diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py new file mode 100644 index 00000000000..26847521998 --- /dev/null +++ b/torchvision/transforms/autoaugment.py @@ -0,0 +1,245 @@ +import math +import torch + +from enum import Enum +from torch import Tensor +from torch.jit.annotations import List, Tuple +from typing import Optional + +from . import functional as F, InterpolationMode + + +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + """ + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" + + +def _get_transforms(policy: AutoAugmentPolicy): + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + + +def _get_magnitudes(): + _BINS = 10 + return { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), + "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), + "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), + "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), + "Color": (torch.linspace(0.0, 0.9, _BINS), True), + "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), + "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), + "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), + "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), + "AutoContrast": (None, None), + "Equalize": (None, None), + "Invert": (None, None), + } + + +class AutoAugment(torch.nn.Module): + r"""AutoAugment data augmentation method based on + `"AutoAugment: Learning Augmentation Strategies from Data" `_. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + policy (AutoAugmentPolicy): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed + image. If int or float, the value is used for all bands respectively. + This option is supported for PIL image and Tensor inputs. + If input is PIL Image, the options is only available for ``Pillow>=5.0.0``. + + Example: + >>> t = transforms.AutoAugment() + >>> transformed = t(image) + + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> transforms.AutoAugment(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): + super().__init__() + self.policy = policy + self.interpolation = interpolation + self.fill = fill + + self.transforms = _get_transforms(policy) + if self.transforms is None: + raise ValueError("The provided policy {} is not recognized.".format(policy)) + self._op_meta = _get_magnitudes() + + @staticmethod + def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: + """Get parameters for autoaugment transformation + + Returns: + params required by the autoaugment transformation + """ + policy_id = torch.randint(transform_num, (1,)).item() + probs = torch.rand((2,)) + signs = torch.randint(2, (2,)) + + return policy_id, probs, signs + + def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: + return self._op_meta[name] + + def forward(self, img: Tensor): + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: AutoAugmented image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F._get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + transform_id, probs, signs = self.get_params(len(self.transforms)) + + for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): + if probs[i] <= p: + magnitudes, signed = self._get_op_meta(op_name) + magnitude = float(magnitudes[magnitude_id].item()) \ + if magnitudes is not None and magnitude_id is not None else 0.0 + if signed is not None and signed and signs[i] == 0: + magnitude *= -1.0 + + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + interpolation=self.interpolation, fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + interpolation=self.interpolation, fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0, + interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0, + interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) + + return img + + def __repr__(self): + return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 6be63765529..ec7e511989a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1173,3 +1173,118 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa if not isinstance(img, torch.Tensor): output = to_pil_image(output) return output + + +def invert(img: Tensor) -> Tensor: + """Invert the colors of an RGB/grayscale PIL Image or torch Tensor. + + Args: + img (PIL Image or Tensor): Image to have its colors inverted. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. + + Returns: + PIL Image or Tensor: Color inverted image. + """ + if not isinstance(img, torch.Tensor): + return F_pil.invert(img) + + return F_t.invert(img) + + +def posterize(img: Tensor, bits: int) -> Tensor: + """Posterize a PIL Image or torch Tensor by reducing the number of bits for each color channel. + + Args: + img (PIL Image or Tensor): Image to have its colors posterized. + If img is a Tensor, it should be of type torch.uint8 and + it is expected to be in [..., H, W] format, where ... means + it can have an arbitrary number of trailing dimensions. + bits (int): The number of bits to keep for each channel (0-8). + Returns: + PIL Image or Tensor: Posterized image. + """ + if not (0 <= bits <= 8): + raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits)) + + if not isinstance(img, torch.Tensor): + return F_pil.posterize(img, bits) + + return F_t.posterize(img, bits) + + +def solarize(img: Tensor, threshold: float) -> Tensor: + """Solarize a PIL Image or torch Tensor by inverting all pixel values above a threshold. + + Args: + img (PIL Image or Tensor): Image to have its colors inverted. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. + threshold (float): All pixels equal or above this value are inverted. + Returns: + PIL Image or Tensor: Solarized image. + """ + if not isinstance(img, torch.Tensor): + return F_pil.solarize(img, threshold) + + return F_t.solarize(img, threshold) + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: + """Adjust the sharpness of an Image. + + Args: + img (PIL Image or Tensor): Image to be adjusted. + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + + Returns: + PIL Image or Tensor: Sharpness adjusted image. + """ + if not isinstance(img, torch.Tensor): + return F_pil.adjust_sharpness(img, sharpness_factor) + + return F_t.adjust_sharpness(img, sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + """Maximize contrast of a PIL Image or torch Tensor by remapping its + pixels per channel so that the lowest becomes black and the lightest + becomes white. + + Args: + img (PIL Image or Tensor): Image on which autocontrast is applied. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. + + Returns: + PIL Image or Tensor: An image that was autocontrasted. + """ + if not isinstance(img, torch.Tensor): + return F_pil.autocontrast(img) + + return F_t.autocontrast(img) + + +def equalize(img: Tensor) -> Tensor: + """Equalize the histogram of a PIL Image or torch Tensor by applying + a non-linear mapping to the input in order to create a uniform + distribution of grayscale values in the output. + + Args: + img (PIL Image or Tensor): Image on which equalize is applied. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. + + Returns: + PIL Image or Tensor: An image that was equalized. + """ + if not isinstance(img, torch.Tensor): + return F_pil.equalize(img) + + return F_t.equalize(img) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 51d83f0fd63..26f3b504d99 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -606,3 +606,48 @@ def to_grayscale(img, num_output_channels): raise ValueError('num_output_channels should be either 1 or 3') return img + + +@torch.jit.unused +def invert(img): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.invert(img) + + +@torch.jit.unused +def posterize(img, bits): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.posterize(img, bits) + + +@torch.jit.unused +def solarize(img, threshold): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.solarize(img, threshold) + + +@torch.jit.unused +def adjust_sharpness(img, sharpness_factor): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Sharpness(img) + img = enhancer.enhance(sharpness_factor) + return img + + +@torch.jit.unused +def autocontrast(img): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.autocontrast(img) + + +@torch.jit.unused +def equalize(img): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.equalize(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ebc178ac561..a72cc41f5cd 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -570,6 +570,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: + ratio = float(ratio) bound = 1.0 if img1.is_floating_point() else 255.0 return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) @@ -1180,3 +1181,133 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) return img + + +def invert(img: Tensor) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device) + return bound - img + + +def posterize(img: Tensor, bits: int) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + if img.dtype != torch.uint8: + raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype)) + + _assert_channels(img, [1, 3]) + mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) + return img & mask + + +def solarize(img: Tensor, threshold: float) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + inverted_img = invert(img) + return torch.where(img >= threshold, inverted_img, img) + + +def _blurred_degenerate_image(img: Tensor) -> Tensor: + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + kernel = torch.ones((3, 3), dtype=dtype, device=img.device) + kernel[1, 1] = 5.0 + kernel /= kernel.sum() + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) + result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) + + result = img.clone() + result[..., 1:-1, 1:-1] = result_tmp + + return result + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: + if sharpness_factor < 0: + raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor)) + + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + _assert_channels(img, [1, 3]) + + if img.size(-1) <= 2 or img.size(-2) <= 2: + return img + + return _blend(img, _blurred_degenerate_image(img), sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = 1.0 if img.is_floating_point() else 255.0 + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) + eq_idxs = torch.where(minimum == maximum)[0] + minimum[eq_idxs] = 0 + maximum[eq_idxs] = bound + scale = bound / (maximum - minimum) + + return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) + + +def _scale_channel(img_chan): + hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) + + nonzero_hist = hist[hist != 0] + step = nonzero_hist[:-1].sum() // 255 + if step == 0: + return img_chan + + lut = (torch.cumsum(hist, 0) + (step // 2)) // step + lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) + + return lut[img_chan.to(torch.int64)].to(torch.uint8) + + +def _equalize_single_image(img: Tensor) -> Tensor: + return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))]) + + +def equalize(img: Tensor) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if not (3 <= img.ndim <= 4): + raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim)) + if img.dtype != torch.uint8: + raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype)) + + _assert_channels(img, [1, 3]) + + if img.ndim == 3: + return _equalize_single_image(img) + + return torch.stack([_equalize_single_image(x) for x in img]) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 3b159fd3f22..117ba74b83a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -21,7 +21,8 @@ "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"] + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] class Compose: @@ -1038,7 +1039,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast and saturation of an image. + """Randomly change the brightness, contrast, saturation and hue of an image. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. @@ -1699,3 +1700,190 @@ def _setup_angle(x, name, req_sizes=(2, )): _check_sequence_input(x, name, req_sizes) return [float(d) for d in x] + + +class RandomInvert(torch.nn.Module): + """Inverts the colors of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be inverted. + + Returns: + PIL Image or Tensor: Randomly color inverted image. + """ + if torch.rand(1).item() < self.p: + return F.invert(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomPosterize(torch.nn.Module): + """Posterize the image randomly with a given probability by reducing the + number of bits for each color channel. The image can be a PIL Image or a torch + Tensor, in which case it is expected to have [..., H, W] shape, where ... means + an arbitrary number of leading dimensions. + + Args: + bits (int): number of bits to keep for each channel (0-8) + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, bits, p=0.5): + super().__init__() + self.bits = bits + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be posterized. + + Returns: + PIL Image or Tensor: Randomly posterized image. + """ + if torch.rand(1).item() < self.p: + return F.posterize(img, self.bits) + return img + + def __repr__(self): + return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) + + +class RandomSolarize(torch.nn.Module): + """Solarize the image randomly with a given probability by inverting all pixel + values above a threshold. The image can be a PIL Image or a torch Tensor, in + which case it is expected to have [..., H, W] shape, where ... means an arbitrary + number of leading dimensions. + + Args: + threshold (float): all pixels equal or above this value are inverted. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, threshold, p=0.5): + super().__init__() + self.threshold = threshold + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be solarized. + + Returns: + PIL Image or Tensor: Randomly solarized image. + """ + if torch.rand(1).item() < self.p: + return F.solarize(img, self.threshold) + return img + + def __repr__(self): + return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + + +class RandomAdjustSharpness(torch.nn.Module): + """Adjust the sharpness of the image randomly with a given probability. The image + can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] + shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, sharpness_factor, p=0.5): + super().__init__() + self.sharpness_factor = sharpness_factor + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be sharpened. + + Returns: + PIL Image or Tensor: Randomly sharpened image. + """ + if torch.rand(1).item() < self.p: + return F.adjust_sharpness(img, self.sharpness_factor) + return img + + def __repr__(self): + return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + + +class RandomAutocontrast(torch.nn.Module): + """Autocontrast the pixels of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being autocontrasted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be autocontrasted. + + Returns: + PIL Image or Tensor: Randomly autocontrasted image. + """ + if torch.rand(1).item() < self.p: + return F.autocontrast(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomEqualize(torch.nn.Module): + """Equalize the histogram of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being equalized. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be equalized. + + Returns: + PIL Image or Tensor: Randomly equalized image. + """ + if torch.rand(1).item() < self.p: + return F.equalize(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p)