Skip to content

Commit

Permalink
Implement the solarize transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Dec 3, 2020
1 parent 4b800b9 commit 17f3c25
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 7 deletions.
20 changes: 20 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,26 @@ def test_posterize(self):
dts=(None,)
)

def test_solarize(self):
    # First pass: integer thresholds on the 0..255 scale, run with dts=(None,).
    int_thresholds = [0, 64, 128, 192, 255]
    self._test_adjust_fn(
        F.solarize,
        F_pil.solarize,
        F_t.solarize,
        [{"threshold": value} for value in int_thresholds],
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )

    # Second pass: fractional thresholds in [0, 1] for float32/float64. The PIL
    # reference is called with the threshold rescaled by 255.
    def pil_solarize_scaled(img, threshold):
        return F_pil.solarize(img, 255 * threshold)

    float_thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
    self._test_adjust_fn(
        F.solarize,
        pil_solarize_scaled,
        F_t.solarize,
        [{"threshold": value} for value in float_thresholds],
        tol=1.0,
        agg_method="max",
        dts=(torch.float32, torch.float64),
    )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
8 changes: 8 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,14 @@ def test_random_posterize(self):
[{"bits": 4}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
    # Check RandomSolarize against the functional solarize op through the shared
    # randomness helper, using a single threshold configuration.
    solarize_configs = [{"threshold": 192}]
    self._test_randomness(F.solarize, transforms.RandomSolarize, solarize_configs)


# Run the unittest runner when this test module is executed directly as a script.
if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def test_random_posterize(self):
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_solarize(self):
    # The functional op and the transform class receive the very same kwargs dict.
    shared_kwargs = {"threshold": 192.0}
    self._test_op(
        'solarize',
        'RandomSolarize',
        fn_kwargs=shared_kwargs,
        meth_kwargs=shared_kwargs,
    )

def test_color_jitter(self):

tol = 1.0 + 1e-10
Expand Down
24 changes: 21 additions & 3 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,9 +1203,9 @@ def posterize(img: Tensor, bits: int) -> Tensor:
Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
If img is a Tensor, it should be of type torch.uint8 and
it is expected to be in [..., H, W] format, where ... means
it can have an arbitrary number of trailing dimensions.
bits (int): The number of bits to keep for each channel (0-8).
Returns:
PIL Image: Posterized image.
Expand All @@ -1217,3 +1217,21 @@ def posterize(img: Tensor, bits: int) -> Tensor:
return F_pil.posterize(img, bits)

return F_t.posterize(img, bits)


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize a PIL Image or torch Tensor by inverting all pixel values at or above a threshold.

    Args:
        img (PIL Image or Tensor): Image to be solarized.
            If img is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.
        threshold (float): All pixels equal or above this value are inverted.

    Returns:
        PIL Image or Tensor: Solarized image.
    """
    # Anything that is not a Tensor is handed to the PIL backend; tensors go to
    # the tensor backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.solarize(img, threshold)

    return F_t.solarize(img, threshold)
7 changes: 7 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,10 @@ def posterize(img, bits):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.posterize(img, bits)


@torch.jit.unused
def solarize(img, threshold):
    """Solarize a PIL Image by delegating to ``ImageOps.solarize``.

    Raises:
        TypeError: if ``img`` does not pass the ``_is_pil_image`` check.
    """
    if _is_pil_image(img):
        return ImageOps.solarize(img, threshold)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
21 changes: 20 additions & 1 deletion torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,7 @@ def invert(img: Tensor) -> Tensor:

bound = 1.0 if img.is_floating_point() else 255.0
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
return (bound - img.to(dtype)).to(img.dtype)
return (bound - img.to(dtype)).clamp(0, bound).to(img.dtype)


def posterize(img: Tensor, bits: int) -> Tensor:
Expand All @@ -1207,3 +1207,22 @@ def posterize(img: Tensor, bits: int) -> Tensor:
_assert_channels(img, [1, 3])
mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Invert every element of a tensor image whose value is at or above ``threshold``.

    Args:
        img (Tensor): Image tensor with at least 3 dimensions. Its channel count
            is checked with ``_assert_channels(img, [1, 3])``.
        threshold (float): Elements greater than or equal to this value are
            replaced by ``bound - value``, where ``bound`` is 1.0 for floating
            point images and 255.0 otherwise.

    Returns:
        Tensor: A new tensor with the same shape and dtype as ``img``.

    Raises:
        TypeError: if ``img`` is not a torch image or has fewer than 3 dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    if img.ndim < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))

    _assert_channels(img, [1, 3])

    bound = 1.0 if img.is_floating_point() else 255.0
    # Integer images are inverted in float32 and cast back afterwards.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    inverted = (bound - img.to(dtype)).clamp(0, bound).to(img.dtype)

    # Select element-wise rather than flattening and index-assigning: a flat
    # view requires compatible strides and fails for non-contiguous inputs
    # (e.g. a permuted HWC -> CHW tensor), whereas torch.where works for any
    # memory layout and never aliases the input.
    return torch.where(img >= threshold, inverted, img)
49 changes: 46 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize"]
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize"]


class Compose:
Expand Down Expand Up @@ -1705,7 +1706,7 @@ class RandomInvert(torch.nn.Module):
"""Inverts the colors of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
dimensions.
Args:
p (float): probability of the image being color inverted. Default value is 0.5
Expand Down Expand Up @@ -1745,7 +1746,7 @@ class RandomPosterize(torch.nn.Module):
"""Posterize the image randomly with a given probability by reducing the
number of bits for each color channel. The image can be a PIL Image or a torch
Tensor, in which case it is expected to have [..., H, W] shape, where ... means
an arbitrary number of leading dimensions
an arbitrary number of leading dimensions.
Args:
bits (int): number of bits to keep for each channel (0-8)
Expand Down Expand Up @@ -1781,3 +1782,45 @@ def forward(self, img):

def __repr__(self):
return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)


class RandomSolarize(torch.nn.Module):
    """Randomly solarize an image, i.e. invert every pixel value at or above a threshold.

    With probability ``p`` the input is passed through ``F.solarize``; otherwise it
    is returned untouched. The input may be a PIL Image or a torch Tensor; a Tensor
    is expected to have [..., H, W] shape, where ... means an arbitrary number of
    leading dimensions.

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        self.threshold = threshold
        self.p = p

    @staticmethod
    def get_params() -> float:
        """Draw the random number that decides whether the transform is applied.

        Returns:
            float: A single value sampled with ``torch.rand``.
        """
        return torch.rand(1).item()

    def forward(self, img):
        """Solarize ``img`` when the random draw falls below ``p``.

        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        draw = self.get_params()
        if draw < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self):
        arguments = '(threshold={},p={})'.format(self.threshold, self.p)
        return self.__class__.__name__ + arguments

0 comments on commit 17f3c25

Please sign in to comment.