Skip to content

Commit

Permalink
Separate the tests of Adjust Sharpness from ColorJitter. (#3128)
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox authored Dec 6, 2020
1 parent c7337b9 commit ff4bfbb
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 31 deletions.
10 changes: 9 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ def test_adjusts_L_mode(self):
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

def test_color_jitter(self):
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2)
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
Expand Down Expand Up @@ -1840,6 +1840,14 @@ def test_random_solarize(self):
[{"threshold": 192}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
self._test_randomness(
F.adjust_sharpness,
transforms.RandomAdjustSharpness,
[{"sharpness_factor": 2.0}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
Expand Down
14 changes: 7 additions & 7 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ def test_random_solarize(self):
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_adjust_sharpness(self):
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
self._test_op(
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_autocontrast(self):
self._test_op('autocontrast', 'RandomAutocontrast')

Expand Down Expand Up @@ -138,14 +144,8 @@ def test_color_jitter(self):
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
)

for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"sharpness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2}
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
)
Expand Down
77 changes: 54 additions & 23 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAutocontrast", "RandomEqualize"]
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]


class Compose:
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def __repr__(self):


class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast, saturation, hue and sharpness of an image.
"""Randomly change the brightness, contrast, saturation and hue of an image.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
Expand All @@ -1054,19 +1054,15 @@ class ColorJitter(torch.nn.Module):
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
or the given [min, max]. Should be non negative numbers.
"""

def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)
self.sharpness = self._check_input(sharpness, 'sharpness')

@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
Expand All @@ -1082,7 +1078,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
else:
raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))

# if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
Expand All @@ -1092,10 +1088,8 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
def get_params(brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]],
sharpness: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float],
Optional[float]]:
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.
Args:
Expand All @@ -1107,22 +1101,19 @@ def get_params(brightness: Optional[List[float]],
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen
uniformly. Pass None to turn off the transformation.
Returns:
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
fn_idx = torch.randperm(5)
fn_idx = torch.randperm(4)

b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1]))

return fn_idx, b, c, s, h, sp
return fn_idx, b, c, s, h

def forward(self, img):
"""
Expand All @@ -1132,8 +1123,8 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness)
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)

for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
Expand All @@ -1144,8 +1135,6 @@ def forward(self, img):
img = F.adjust_saturation(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)
elif fn_id == 4 and sharpness_factor is not None:
img = F.adjust_sharpness(img, sharpness_factor)

return img

Expand All @@ -1154,8 +1143,7 @@ def __repr__(self):
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0}'.format(self.hue)
format_string += ', sharpness={0})'.format(self.sharpness)
format_string += ', hue={0})'.format(self.hue)
return format_string


Expand Down Expand Up @@ -1838,6 +1826,49 @@ def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)


class RandomAdjustSharpness(torch.nn.Module):
"""Adjust the sharpness of the image randomly with a given probability. The image
can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W]
shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being color inverted. Default value is 0.5
"""

def __init__(self, sharpness_factor, p=0.5):
super().__init__()
self.sharpness_factor = sharpness_factor
self.p = p

@staticmethod
def get_params() -> float:
"""Choose a value for the random transformation.
Returns:
float: Random value which is used to determine whether the random transformation
should occur.
"""
return torch.rand(1).item()

def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be sharpened.
Returns:
PIL Image or Tensor: Randomly sharpened image.
"""
if self.get_params() < self.p:
return F.adjust_sharpness(img, self.sharpness_factor)
return img

def __repr__(self):
return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p)


class RandomAutocontrast(torch.nn.Module):
"""Autocontrast the pixels of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
Expand Down

0 comments on commit ff4bfbb

Please sign in to comment.