Implement all AutoAugment transforms + Policies (#3123)
* Invert Transform (#3104)

* Adding invert operator.

* Make use of _assert_channels().

* Update upper bound value.

* Remove private doc from invert; create or reuse generic testing methods to avoid duplicating code in the tests. (#3106)

* Create posterize transformation and refactor common methods to assist reuse. (#3108)

* Implement the solarize transform. (#3112)

* Implement the adjust_sharpness transform (#3114)

* Adding functional operator for sharpness.

* Adding transforms for sharpness.

* Handling tiny images and adding a test.

* Implement the autocontrast transform. (#3117)

* Implement the equalize transform (#3119)

* Implement the equalize transform.

* Turn off deterministic for histogram.

* Fixing test. (#3126)

* Force ratio to be float to avoid numeric overflows on blend. (#3127)

* Separate the tests of Adjust Sharpness from ColorJitter. (#3128)

* Add AutoAugment Policies and main Transform (#3142)

* Separate the tests of Adjust Sharpness from ColorJitter.

* Initial implementation, not yet JIT-scriptable.

* AutoAugment passing JIT.

* Adding tests/docs, changing formatting.

* Update test.

* Fix formatting.

* Fix documentation and imports.

* Apply changes from code review:
- Move the transformations outside of AutoAugment into a separate method.
- Rename the degenerate method for sharpness for clarity.

* Update torchvision/transforms/functional.py

Co-authored-by: vfdev <[email protected]>

* Apply more changes from code review:
- Add InterpolationMode parameter.
- Move all declarations away from AutoAugment constructor and into the private method.

* Update documentation.

* Apply suggestions from code review

Co-authored-by: Francisco Massa <[email protected]>

* Apply changes from code review:
- Refactor code to eliminate as many to() and clamp() calls as possible.
- Reuse methods where possible.
- Apply speed-ups.

* Replacing pad.

Co-authored-by: vfdev <[email protected]>
Co-authored-by: Francisco Massa <[email protected]>
3 people authored Dec 14, 2020
1 parent 4eab7a6 commit 83171d6
Showing 9 changed files with 968 additions and 4 deletions.
76 changes: 74 additions & 2 deletions test/test_functional_tensor.py
@@ -289,13 +289,14 @@ def test_pad(self):

self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

- def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
+ def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
+                     dts=(None, torch.float32, torch.float64)):
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

- for dt in [None, torch.float32, torch.float64]:
+ for dt in dts:

if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
@@ -862,6 +863,77 @@ def test_gaussian_blur(self):
msg="{}, {}".format(ksize, sigma)
)

def test_invert(self):
self._test_adjust_fn(
F.invert,
F_pil.invert,
F_t.invert,
[{}],
tol=1.0,
agg_method="max"
)

def test_posterize(self):
self._test_adjust_fn(
F.posterize,
F_pil.posterize,
F_t.posterize,
[{"bits": bits} for bits in range(0, 8)],
tol=1.0,
agg_method="max",
dts=(None,)
)

def test_solarize(self):
self._test_adjust_fn(
F.solarize,
F_pil.solarize,
F_t.solarize,
[{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
tol=1.0,
agg_method="max",
dts=(None,)
)
self._test_adjust_fn(
F.solarize,
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
F_t.solarize,
[{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
tol=1.0,
agg_method="max",
dts=(torch.float32, torch.float64)
)

def test_adjust_sharpness(self):
self._test_adjust_fn(
F.adjust_sharpness,
F_pil.adjust_sharpness,
F_t.adjust_sharpness,
[{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)

def test_autocontrast(self):
self._test_adjust_fn(
F.autocontrast,
F_pil.autocontrast,
F_t.autocontrast,
[{}],
tol=1.0,
agg_method="max"
)

def test_equalize(self):
torch.set_deterministic(False)  # histogram ops used by equalize have no deterministic implementation
self._test_adjust_fn(
F.equalize,
F_pil.equalize,
F_t.equalize,
[{}],
tol=1.0,
agg_method="max",
dts=(None,)
)
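
For orientation, the tensor kernels exercised by these tests can be sketched in a few lines each (illustrative versions assuming uint8 or float images, not necessarily the exact torchvision source):

import torch

def invert(img: torch.Tensor) -> torch.Tensor:
    # The maximum value is 255 for uint8 images and 1.0 for float images.
    bound = 255 if img.dtype == torch.uint8 else 1.0
    return bound - img

def posterize(img: torch.Tensor, bits: int) -> torch.Tensor:
    # Keep only the top `bits` bits of each uint8 channel value.
    mask = -int(2 ** (8 - bits))  # e.g. bits=4 -> 0b11110000
    return img & mask

def solarize(img: torch.Tensor, threshold: float) -> torch.Tensor:
    # Invert every pixel at or above the threshold; leave the rest untouched.
    return torch.where(img >= threshold, invert(img), img)

This is also why the posterize and equalize tests restrict dts to (None,): bit masking and histogram equalization are only meaningful on uint8 inputs.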


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
123 changes: 123 additions & 0 deletions test/test_transforms.py
@@ -1234,6 +1234,48 @@ def test_adjust_hue(self):
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))

def test_adjust_sharpness(self):
x_shape = [4, 4, 3]
x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')

# test 0
y_pil = F.adjust_sharpness(x_pil, 1)
y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np))

# test 1
y_pil = F.adjust_sharpness(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))

# test 2
y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))

# test 3
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
x_th = torch.tensor(x_np.transpose(2, 0, 1))
y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil).transpose(2, 0, 1)
y_th = F.adjust_sharpness(x_th, 2)
self.assertTrue(np.allclose(y_np, y_th.numpy()))
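
The blend fix noted in the commit list (#3127) amounts to keeping the interpolation ratio a Python float, so that uint8 inputs are promoted to floating point before the weighted sum instead of overflowing. A sketch of such a helper (illustrative, not necessarily the exact torchvision source):

import torch

def blend(img1: torch.Tensor, img2: torch.Tensor, ratio: float) -> torch.Tensor:
    ratio = float(ratio)  # a float ratio promotes uint8 math and avoids overflow
    bound = 255.0 if img1.dtype == torch.uint8 else 1.0
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)

adjust_sharpness(img, f) blends the input with a smoothed copy of itself: f = 1 returns the input unchanged (test 0 above), f < 1 moves toward the blurred copy (test 1), and f > 1 extrapolates past the original to sharpen (test 2).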

def test_adjust_gamma(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
@@ -1270,6 +1312,7 @@ def test_adjusts_L_mode(self):
self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

def test_color_jitter(self):
@@ -1751,6 +1794,86 @@ def test_gaussian_blur_asserts(self):
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
transforms.GaussianBlur(3, "sigma_string")

def _test_randomness(self, fn, trans, configs):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))

for p in [0.5, 0.7]:
for config in configs:
inv_img = fn(img, **config)

num_samples = 250
counts = 0
for _ in range(num_samples):
transformation = trans(p=p, **config)
transformation.__repr__()  # smoke-test that repr does not raise
out = transformation(img)
if out == inv_img:
counts += 1

p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
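
The helper above is a statistical smoke test: over 250 draws, a transform constructed with probability p should return the transformed output roughly a fraction p of the time, and a two-sided binomial test flags gross miscalibration. A standalone illustration (assuming scipy is available):

from scipy import stats

# A transform with p=0.5 applied 250 times: ~125 hits is consistent, ~40 is not.
print(stats.binom_test(125, n=250, p=0.5) > 0.0001)  # True: consistent with p=0.5
print(stats.binom_test(40, n=250, p=0.5) > 0.0001)   # False: reject p=0.5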

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_invert(self):
self._test_randomness(
F.invert,
transforms.RandomInvert,
[{}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_posterize(self):
self._test_randomness(
F.posterize,
transforms.RandomPosterize,
[{"bits": 4}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
self._test_randomness(
F.solarize,
transforms.RandomSolarize,
[{"threshold": 192}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
self._test_randomness(
F.adjust_sharpness,
transforms.RandomAdjustSharpness,
[{"sharpness_factor": 2.0}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
F.autocontrast,
transforms.RandomAutocontrast,
[{}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_equalize(self):
self._test_randomness(
F.equalize,
transforms.RandomEqualize,
[{}]
)

def test_autoaugment(self):
for policy in transforms.AutoAugmentPolicy:
for fill in [None, 85, (128, 128, 128)]:
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.AutoAugment(policy=policy, fill=fill)
for _ in range(100):
img = transform(img)
transform.__repr__()
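
For reference, typical eager-mode usage of the new transform looks like this (a minimal sketch; the policy and the blank test image are placeholders):

from PIL import Image
import torchvision.transforms as T

img = Image.new('RGB', (224, 224))
augmenter = T.AutoAugment(policy=T.AutoAugmentPolicy.IMAGENET)
out = augmenter(img)  # randomly picks and applies one of the policy's sub-policies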


if __name__ == '__main__':
unittest.main()
44 changes: 44 additions & 0 deletions test/test_transforms_tensor.py
@@ -89,6 +89,34 @@ def test_random_horizontal_flip(self):
def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip')

def test_random_invert(self):
self._test_op('invert', 'RandomInvert')

def test_random_posterize(self):
fn_kwargs = meth_kwargs = {"bits": 4}
self._test_op(
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_solarize(self):
fn_kwargs = meth_kwargs = {"threshold": 192.0}
self._test_op(
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_adjust_sharpness(self):
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
self._test_op(
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_random_autocontrast(self):
self._test_op('autocontrast', 'RandomAutocontrast')

def test_random_equalize(self):
torch.set_deterministic(False)  # histogram ops used by equalize have no deterministic implementation
self._test_op('equalize', 'RandomEqualize')
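
The six new wrappers above follow the same p-based convention as RandomHorizontalFlip, so they compose directly (a hedged usage sketch):

import torchvision.transforms as T

augment = T.Compose([
    T.RandomPosterize(bits=4, p=0.5),
    T.RandomSolarize(threshold=192.0, p=0.5),
    T.RandomAdjustSharpness(sharpness_factor=2.0, p=0.5),
])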

def test_color_jitter(self):

tol = 1.0 + 1e-10
@@ -598,6 +626,22 @@ def test_convert_image_dtype(self):
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))

def test_autoaugment(self):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

for policy in T.AutoAugmentPolicy:
for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
for _ in range(100):
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)

self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
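
The scripted path exercised here corresponds to user code along these lines (a sketch; the policy and fill values are arbitrary):

import torch
import torchvision.transforms as T

transform = T.AutoAugment(policy=T.AutoAugmentPolicy.CIFAR10, fill=0)
scripted = torch.jit.script(transform)  # the new transform is TorchScript-compatible

img = torch.randint(0, 256, (3, 44, 56), dtype=torch.uint8)
assert scripted(img).shape == img.shape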


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
1 change: 1 addition & 0 deletions torchvision/transforms/__init__.py
@@ -1 +1,2 @@
from .transforms import *
from .autoaugment import *
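
With this re-export, the new classes are importable directly from the package namespace, e.g.:

from torchvision.transforms import AutoAugment, AutoAugmentPolicy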
