From 17f3c251e245d1ca3c30b19beb5f9fde3fe396a9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 3 Dec 2020 21:49:22 +0000 Subject: [PATCH] Implement the solarize transform. --- test/test_functional_tensor.py | 20 +++++++++ test/test_transforms.py | 8 ++++ test/test_transforms_tensor.py | 6 +++ torchvision/transforms/functional.py | 24 ++++++++-- torchvision/transforms/functional_pil.py | 7 +++ torchvision/transforms/functional_tensor.py | 21 ++++++++- torchvision/transforms/transforms.py | 49 +++++++++++++++++++-- 7 files changed, 128 insertions(+), 7 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 4df930d4517..63e8271a858 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -884,6 +884,26 @@ def test_posterize(self): dts=(None,) ) + def test_solarize(self): + self._test_adjust_fn( + F.solarize, + F_pil.solarize, + F_t.solarize, + [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]], + tol=1.0, + agg_method="max", + dts=(None,) + ) + self._test_adjust_fn( + F.solarize, + lambda img, threshold: F_pil.solarize(img, 255 * threshold), + F_t.solarize, + [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]], + tol=1.0, + agg_method="max", + dts=(torch.float32, torch.float64) + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 81757510302..fc52fc66686 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1787,6 +1787,14 @@ def test_random_posterize(self): [{"bits": 4}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_solarize(self): + self._test_randomness( + F.solarize, + transforms.RandomSolarize, + [{"threshold": 192}] + ) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index eba782a75cb..331f8a2eb4f 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -98,6 +98,12 @@ def test_random_posterize(self): 'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_solarize(self): + fn_kwargs = meth_kwargs = {"threshold": 192.0} + self._test_op( + 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + def test_color_jitter(self): tol = 1.0 + 1e-10 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index dc0c2a2f2bd..e3b0a9bd98a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1203,9 +1203,9 @@ def posterize(img: Tensor, bits: int) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors inverted. - If img is a Tensor, it is expected to be in [..., H, W] format, - where ... means it can have an arbitrary number of trailing - dimensions. + If img is a Tensor, it should be of type torch.uint8 and + it is expected to be in [..., H, W] format, where ... means + it can have an arbitrary number of trailing dimensions. bits (int): The number of bits to keep for each channel (0-8). Returns: PIL Image: Posterized image. @@ -1217,3 +1217,21 @@ def posterize(img: Tensor, bits: int) -> Tensor: return F_pil.posterize(img, bits) return F_t.posterize(img, bits) + + +def solarize(img: Tensor, threshold: float) -> Tensor: + """Solarize a PIL Image or torch Tensor by inverting all pixel values above a threshold. + + Args: + img (PIL Image or Tensor): Image to have its colors inverted. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. + threshold (float): All pixels equal or above this value are inverted. + Returns: + PIL Image: Solarized image. + """ + if not isinstance(img, torch.Tensor): + return F_pil.solarize(img, threshold) + + return F_t.solarize(img, threshold) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 2e1b16f26b6..d60588fd138 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -620,3 +620,10 @@ def posterize(img, bits): if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return ImageOps.posterize(img, bits) + + +@torch.jit.unused +def solarize(img, threshold): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.solarize(img, threshold) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 003f0138b0c..5eb70988f90 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1192,7 +1192,7 @@ def invert(img: Tensor) -> Tensor: bound = 1.0 if img.is_floating_point() else 255.0 dtype = img.dtype if torch.is_floating_point(img) else torch.float32 - return (bound - img.to(dtype)).to(img.dtype) + return (bound - img.to(dtype)).clamp(0, bound).to(img.dtype) def posterize(img: Tensor, bits: int) -> Tensor: @@ -1207,3 +1207,22 @@ def posterize(img: Tensor, bits: int) -> Tensor: _assert_channels(img, [1, 3]) mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) return img & mask + + +def solarize(img: Tensor, threshold: float) -> Tensor: + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = 1.0 if img.is_floating_point() else 255.0 + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + result = img.clone().view(-1) + invert_idx = torch.where(result >= threshold)[0] + result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype) + + return result.view(img.shape) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index fbe7a23fc61..66ccb42e525 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -21,7 +21,8 @@ "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize"] + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", + "RandomSolarize"] class Compose: @@ -1705,7 +1706,7 @@ class RandomInvert(torch.nn.Module): """Inverts the colors of the given image randomly with a given probability. The image can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading - dimensions + dimensions. Args: p (float): probability of the image being color inverted. Default value is 0.5 @@ -1745,7 +1746,7 @@ class RandomPosterize(torch.nn.Module): """Posterize the image randomly with a given probability by reducing the number of bits for each color channel. The image can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] shape, where ... means - an arbitrary number of leading dimensions + an arbitrary number of leading dimensions. Args: bits (int): number of bits to keep for each channel (0-8) @@ -1781,3 +1782,45 @@ def forward(self, img): def __repr__(self): return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) + + +class RandomSolarize(torch.nn.Module): + """Solarize the image randomly with a given probability by inverting all pixel + values above a threshold. The image can be a PIL Image or a torch Tensor, in + which case it is expected to have [..., H, W] shape, where ... means an arbitrary + number of leading dimensions. + + Args: + threshold (float): all pixels equal or above this value are inverted. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, threshold, p=0.5): + super().__init__() + self.threshold = threshold + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be solarized. + + Returns: + PIL Image or Tensor: Randomly solarized image. + """ + if self.get_params() < self.p: + return F.solarize(img, self.threshold) + return img + + def __repr__(self): + return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)