diff --git a/test/test_transforms.py b/test/test_transforms.py index fbfab98be45..f2ca0dc9d1e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1316,7 +1316,7 @@ def test_adjusts_L_mode(self): self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') def test_color_jitter(self): - color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2) + color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -1840,6 +1840,14 @@ def test_random_solarize(self): [{"threshold": 192}] ) + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_adjust_sharpness(self): + self._test_randomness( + F.adjust_sharpness, + transforms.RandomAdjustSharpness, + [{"sharpness_factor": 2.0}] + ) + @unittest.skipIf(stats is None, 'scipy.stats not available') def test_random_autocontrast(self): self._test_randomness( diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7af1f1d4c46..ea3f818ad45 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -104,6 +104,12 @@ def test_random_solarize(self): 'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) + def test_random_adjust_sharpness(self): + fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0} + self._test_op( + 'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + def test_random_autocontrast(self): self._test_op('autocontrast', 'RandomAutocontrast') @@ -138,14 +144,8 @@ def test_color_jitter(self): "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max" ) - for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]: - meth_kwargs = {"sharpness": f} - self._test_class_op( - "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" - ) - # All 4 parameters together - meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2} + meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} self._test_class_op( "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max" ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index f4416b36acd..c6322ef71d4 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -22,7 +22,7 @@ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAutocontrast", "RandomEqualize"] + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] class Compose: @@ -1039,7 +1039,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast, saturation, hue and sharpness of an image. + """Randomly change the brightness, contrast, saturation and hue of an image. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. @@ -1054,19 +1054,15 @@ class ColorJitter(torch.nn.Module): hue (float or tuple of float (min, max)): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. - sharpness (float or tuple of float (min, max)): How much to jitter sharpness. - sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness] - or the given [min, max]. Should be non negative numbers. """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - self.sharpness = self._check_input(sharpness, 'sharpness') @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): @@ -1082,7 +1078,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs else: raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) - # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness + # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: value = None @@ -1092,10 +1088,8 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs def get_params(brightness: Optional[List[float]], contrast: Optional[List[float]], saturation: Optional[List[float]], - hue: Optional[List[float]], - sharpness: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float], - Optional[float]]: + hue: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1107,22 +1101,19 @@ def get_params(brightness: Optional[List[float]], uniformly. Pass None to turn off the transformation. hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation. - sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen - uniformly. Pass None to turn off the transformation. Returns: tuple: The parameters used to apply the randomized transform along with their random order. """ - fn_idx = torch.randperm(5) + fn_idx = torch.randperm(4) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1])) - return fn_idx, b, c, s, h, sp + return fn_idx, b, c, s, h def forward(self, img): """ @@ -1132,8 +1123,8 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1144,8 +1135,6 @@ def forward(self, img): img = F.adjust_saturation(img, saturation_factor) elif fn_id == 3 and hue_factor is not None: img = F.adjust_hue(img, hue_factor) - elif fn_id == 4 and sharpness_factor is not None: - img = F.adjust_sharpness(img, sharpness_factor) return img @@ -1154,8 +1143,7 @@ def __repr__(self): format_string += 'brightness={0}'.format(self.brightness) format_string += ', contrast={0}'.format(self.contrast) format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0}'.format(self.hue) - format_string += ', sharpness={0})'.format(self.sharpness) + format_string += ', hue={0})'.format(self.hue) return format_string @@ -1838,6 +1826,49 @@ def __repr__(self): return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) +class RandomAdjustSharpness(torch.nn.Module): + """Adjust the sharpness of the image randomly with a given probability. The image + can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] + shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, sharpness_factor, p=0.5): + super().__init__() + self.sharpness_factor = sharpness_factor + self.p = p + + @staticmethod + def get_params() -> float: + """Choose a value for the random transformation. + + Returns: + float: Random value which is used to determine whether the random transformation + should occur. + """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be sharpened. + + Returns: + PIL Image or Tensor: Randomly sharpened image. + """ + if self.get_params() < self.p: + return F.adjust_sharpness(img, self.sharpness_factor) + return img + + def __repr__(self): + return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + + class RandomAutocontrast(torch.nn.Module): """Autocontrast the pixels of the given image randomly with a given probability. The image can be a PIL Image or a torch Tensor, in which case it is expected