Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement all AutoAugment transforms + Policies #3123

Merged
merged 27 commits into from
Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
10c3efa
Invert Transform (#3104)
datumbox Dec 3, 2020
cd03c18
Remove private doc from invert, create or reuse generic testing metho…
datumbox Dec 3, 2020
4b800b9
Create posterize transformation and refactor common methods to assist…
datumbox Dec 3, 2020
63b8a27
Implement the solarize transform. (#3112)
datumbox Dec 3, 2020
b4e9a2f
Implement the adjust_sharpness transform (#3114)
datumbox Dec 4, 2020
94fc573
Implement the autocontrast transform. (#3117)
datumbox Dec 4, 2020
64a3e1b
Implement the equalize transform (#3119)
datumbox Dec 4, 2020
05cf567
Fixing test. (#3126)
datumbox Dec 5, 2020
a055539
Merge branch 'master' into autoaugment_transforms
datumbox Dec 6, 2020
c7337b9
Force ratio to be float to avoid numeric overflows on blend. (#3127)
datumbox Dec 6, 2020
ff4bfbb
Separate the tests of Adjust Sharpness from ColorJitter. (#3128)
datumbox Dec 6, 2020
f6669f6
Merge branch 'master' into autoaugment_transforms
datumbox Dec 8, 2020
70f2042
Add AutoAugment Policies and main Transform (#3142)
datumbox Dec 8, 2020
8bca2f8
Merge branch 'master' into autoaugment_transforms
datumbox Dec 9, 2020
19f49c6
Merge branch 'master' into autoaugment_transforms
datumbox Dec 9, 2020
32c7f39
Merge branch 'master' into autoaugment_transforms
datumbox Dec 11, 2020
1b3d645
Apply changes from code review:
datumbox Dec 11, 2020
b62ec5d
Update torchvision/transforms/functional.py
datumbox Dec 11, 2020
061db1f
Apply more changes from code review:
datumbox Dec 11, 2020
74b9eb2
Merge remote-tracking branch 'upstream/autoaugment_transforms' into a…
datumbox Dec 11, 2020
6439dbc
Update documentation.
datumbox Dec 11, 2020
3071d53
Merge branch 'master' into autoaugment_transforms
datumbox Dec 12, 2020
a5bc492
Apply suggestions from code review
datumbox Dec 14, 2020
48cfe22
Apply changes from code review:
datumbox Dec 14, 2020
a8e6c45
Merge branch 'master' into autoaugment_transforms
datumbox Dec 14, 2020
a9a8537
Replacing pad.
datumbox Dec 14, 2020
8bf0f34
Merge remote-tracking branch 'upstream/autoaugment_transforms' into a…
datumbox Dec 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,14 @@ def test_pad(self):

self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"):
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
dts=(None, torch.float32, torch.float64)):
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

for dt in [None, torch.float32, torch.float64]:
for dt in dts:

if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
Expand Down Expand Up @@ -862,6 +863,77 @@ def test_gaussian_blur(self):
msg="{}, {}".format(ksize, sigma)
)

def test_invert(self):
    """Run the shared ``_test_adjust_fn`` harness on the invert op.

    The dispatcher (``F``), PIL (``F_pil``) and tensor (``F_t``) variants are
    handed to the harness with a tolerance of 1.0 aggregated with ``"max"``.
    """
    # invert takes no extra arguments, hence a single empty kwargs dict
    no_arg_configs = [{}]
    self._test_adjust_fn(
        fn=F.invert,
        fn_pil=F_pil.invert,
        fn_t=F_t.invert,
        configs=no_arg_configs,
        tol=1.0,
        agg_method="max",
    )

def test_posterize(self):
    """Run the shared ``_test_adjust_fn`` harness on posterize for ``bits`` 0 through 7.

    Only the ``None`` dtype entry is requested, i.e. the harness performs no
    dtype conversion of the generated tensor for this op.
    """
    bit_configs = [{"bits": n_bits} for n_bits in range(8)]
    self._test_adjust_fn(
        fn=F.posterize,
        fn_pil=F_pil.posterize,
        fn_t=F_t.posterize,
        configs=bit_configs,
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )

def test_solarize(self):
    """Run the shared ``_test_adjust_fn`` harness on solarize in two passes.

    The first pass uses thresholds 0..255 with no dtype conversion; the second
    uses thresholds 0.0..1.0 with float32/float64 tensors and multiplies the
    threshold by 255 before handing it to the PIL implementation.
    """

    def _pil_solarize_scaled(img, threshold):
        # The PIL reference is called with the threshold rescaled by 255.
        return F_pil.solarize(img, 255 * threshold)

    integer_thresholds = [0, 64, 128, 192, 255]
    self._test_adjust_fn(
        fn=F.solarize,
        fn_pil=F_pil.solarize,
        fn_t=F_t.solarize,
        configs=[{"threshold": value} for value in integer_thresholds],
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )

    fractional_thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
    self._test_adjust_fn(
        fn=F.solarize,
        fn_pil=_pil_solarize_scaled,
        fn_t=F_t.solarize,
        configs=[{"threshold": value} for value in fractional_thresholds],
        tol=1.0,
        agg_method="max",
        dts=(torch.float32, torch.float64),
    )

def test_adjust_sharpness(self):
    """Run the shared ``_test_adjust_fn`` harness on adjust_sharpness.

    Five sharpness factors are exercised; tolerance, aggregation and dtypes are
    left at the harness defaults.
    """
    sharpness_factors = [0.2, 0.5, 1.0, 1.5, 2.0]
    factor_configs = [{"sharpness_factor": factor} for factor in sharpness_factors]
    self._test_adjust_fn(
        fn=F.adjust_sharpness,
        fn_pil=F_pil.adjust_sharpness,
        fn_t=F_t.adjust_sharpness,
        configs=factor_configs,
    )

def test_autocontrast(self):
    """Run the shared ``_test_adjust_fn`` harness on the autocontrast op.

    Uses a tolerance of 1.0 aggregated with ``"max"`` and the default dtypes.
    """
    # autocontrast takes no extra arguments, hence a single empty kwargs dict
    no_arg_configs = [{}]
    self._test_adjust_fn(
        fn=F.autocontrast,
        fn_pil=F_pil.autocontrast,
        fn_t=F_t.autocontrast,
        configs=no_arg_configs,
        tol=1.0,
        agg_method="max",
    )

def test_equalize(self):
    """Run the shared ``_test_adjust_fn`` harness on the equalize op.

    Deterministic mode is switched off first, and only the ``None`` dtype entry
    is requested (no dtype conversion of the generated tensor).
    """
    # NOTE(review): this flips a global torch flag and nothing in this method
    # restores it afterwards -- confirm other tests do not rely on the old value.
    torch.set_deterministic(False)

    no_arg_configs = [{}]
    self._test_adjust_fn(
        fn=F.equalize,
        fn_pil=F_pil.equalize,
        fn_t=F_t.equalize,
        configs=no_arg_configs,
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
123 changes: 123 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,48 @@ def test_adjust_hue(self):
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))

def test_adjust_sharpness(self):
    """Check ``F.adjust_sharpness`` on PIL input against hard-coded values and tensor output.

    A 4x4 RGB image is sharpened with factors 1 (expected unchanged), 0.5 and 2
    (expected to match stored pixel values); then a 2x2 RGB image is sharpened
    with factor 2 both as a PIL image and as a CHW tensor and the results compared.
    """
    shape = [4, 4, 3]
    pixels = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
              0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
              111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
    src_np = np.array(pixels, dtype=np.uint8).reshape(shape)
    src_pil = Image.fromarray(src_np, mode='RGB')

    # test 0: a factor of 1 must leave the image untouched
    unchanged_np = np.array(F.adjust_sharpness(src_pil, 1))
    self.assertTrue(np.allclose(unchanged_np, src_np))

    # tests 1 and 2: stored expected outputs for factors 0.5 and 2
    golden_cases = [
        (0.5, [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
               30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
               107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]),
        (2, [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
             0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
             119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]),
    ]
    for factor, expected_pixels in golden_cases:
        result_np = np.array(F.adjust_sharpness(src_pil, factor))
        expected_np = np.array(expected_pixels, dtype=np.uint8).reshape(shape)
        self.assertTrue(np.allclose(result_np, expected_np))

    # test 3: PIL and tensor code paths must agree on a small image
    small_shape = [2, 2, 3]
    small_pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    small_np = np.array(small_pixels, dtype=np.uint8).reshape(small_shape)
    small_pil = Image.fromarray(small_np, mode='RGB')
    small_th = torch.tensor(small_np.transpose(2, 0, 1))  # HWC -> CHW
    pil_out_np = np.array(F.adjust_sharpness(small_pil, 2)).transpose(2, 0, 1)
    th_out = F.adjust_sharpness(small_th, 2)
    self.assertTrue(np.allclose(pil_out_np, th_out.numpy()))

def test_adjust_gamma(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
Expand Down Expand Up @@ -1270,6 +1312,7 @@ def test_adjusts_L_mode(self):
self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

def test_color_jitter(self):
Expand Down Expand Up @@ -1751,6 +1794,86 @@ def test_gaussian_blur_asserts(self):
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
transforms.GaussianBlur(3, "sigma_string")

def _test_randomness(self, fn, trans, configs):
    """Statistically check that a random transform applies ``fn`` with probability ``p``.

    For every ``p`` in ``[0.5, 0.7]`` and every kwargs dict in ``configs``, the
    transform ``trans(p=p, **config)`` is built and applied 250 times to one fixed
    PIL image. The number of outputs equal to ``fn(img, **config)`` is then
    checked against ``p`` with a binomial test, whose p-value must exceed 1e-4.

    Args:
        fn: functional op, called as ``fn(img, **config)`` to get the reference output.
        trans: random transform class, constructed as ``trans(p=p, **config)``.
        configs: iterable of kwargs dicts shared by ``fn`` and ``trans``.
    """
    # Prefer the current SciPy API; fall back to the legacy function on older
    # SciPy versions where ``binomtest`` does not exist yet.
    binomtest = getattr(stats, "binomtest", None)

    random_state = random.getstate()
    try:
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 16, 18))
        num_samples = 250

        for p in [0.5, 0.7]:
            for config in configs:
                inv_img = fn(img, **config)

                counts = 0
                for _ in range(num_samples):
                    transformation = trans(p=p, **config)
                    repr(transformation)  # smoke check: building the repr must not raise
                    out = transformation(img)
                    if out == inv_img:
                        counts += 1

                if binomtest is not None:
                    p_value = binomtest(counts, num_samples, p=p).pvalue
                else:
                    p_value = stats.binom_test(counts, num_samples, p=p)
                self.assertGreater(p_value, 0.0001)
    finally:
        # Restore the caller's Python RNG state exactly once, even when an
        # assertion or the transform itself raises.
        random.setstate(random_state)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_invert(self):
    """Run the ``_test_randomness`` check for ``RandomInvert`` against ``F.invert``."""
    # invert takes no extra arguments, hence a single empty kwargs dict
    no_arg_configs = [{}]
    self._test_randomness(fn=F.invert, trans=transforms.RandomInvert, configs=no_arg_configs)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_posterize(self):
    """Run the ``_test_randomness`` check for ``RandomPosterize`` with ``bits=4``."""
    posterize_configs = [{"bits": 4}]
    self._test_randomness(fn=F.posterize, trans=transforms.RandomPosterize, configs=posterize_configs)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
    """Run the ``_test_randomness`` check for ``RandomSolarize`` with ``threshold=192``."""
    solarize_configs = [{"threshold": 192}]
    self._test_randomness(fn=F.solarize, trans=transforms.RandomSolarize, configs=solarize_configs)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
    """Run the ``_test_randomness`` check for ``RandomAdjustSharpness`` with factor 2.0."""
    sharpness_configs = [{"sharpness_factor": 2.0}]
    self._test_randomness(
        fn=F.adjust_sharpness,
        trans=transforms.RandomAdjustSharpness,
        configs=sharpness_configs,
    )

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
    """Run the ``_test_randomness`` check for ``RandomAutocontrast`` against ``F.autocontrast``."""
    # autocontrast takes no extra arguments, hence a single empty kwargs dict
    no_arg_configs = [{}]
    self._test_randomness(fn=F.autocontrast, trans=transforms.RandomAutocontrast, configs=no_arg_configs)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_equalize(self):
    """Run the ``_test_randomness`` check for ``RandomEqualize`` against ``F.equalize``."""
    # equalize takes no extra arguments, hence a single empty kwargs dict
    no_arg_configs = [{}]
    self._test_randomness(fn=F.equalize, trans=transforms.RandomEqualize, configs=no_arg_configs)

def test_autoaugment(self):
    """Smoke-test ``AutoAugment`` on a PIL image for every policy and several fill values.

    For each member of ``transforms.AutoAugmentPolicy`` and each fill in
    ``[None, 85, (128, 128, 128)]`` the transform is applied 100 times in a row,
    feeding each output back in, and its repr is built once. The test only
    verifies that nothing raises.
    """
    # Decode the sample once inside a context manager so the file handle is
    # closed immediately; opening it per (policy, fill) left every handle open.
    with Image.open(GRACE_HOPPER) as src:
        base_img = src.copy()

    for policy in transforms.AutoAugmentPolicy:
        for fill in [None, 85, (128, 128, 128)]:
            random.seed(42)
            img = base_img.copy()  # every combination starts from the pristine image
            transform = transforms.AutoAugment(policy=policy, fill=fill)
            for _ in range(100):
                img = transform(img)
            repr(transform)  # smoke check: building the repr must not raise


if __name__ == '__main__':
unittest.main()
44 changes: 44 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,34 @@ def test_random_horizontal_flip(self):
def test_random_vertical_flip(self):
    """Run the shared ``_test_op`` check for ``'vflip'`` / ``'RandomVerticalFlip'``."""
    op_names = ('vflip', 'RandomVerticalFlip')
    self._test_op(*op_names)

def test_random_invert(self):
    """Run the shared ``_test_op`` check for ``'invert'`` / ``'RandomInvert'``."""
    op_names = ('invert', 'RandomInvert')
    self._test_op(*op_names)

def test_random_posterize(self):
    """Run the shared ``_test_op`` check for ``'posterize'`` / ``'RandomPosterize'`` with ``bits=4``."""
    # One dict object is deliberately passed for both keyword arguments.
    shared_kwargs = {"bits": 4}
    self._test_op('posterize', 'RandomPosterize', fn_kwargs=shared_kwargs, meth_kwargs=shared_kwargs)

def test_random_solarize(self):
    """Run the shared ``_test_op`` check for ``'solarize'`` / ``'RandomSolarize'`` with ``threshold=192.0``."""
    # One dict object is deliberately passed for both keyword arguments.
    shared_kwargs = {"threshold": 192.0}
    self._test_op('solarize', 'RandomSolarize', fn_kwargs=shared_kwargs, meth_kwargs=shared_kwargs)

def test_random_adjust_sharpness(self):
    """Run the shared ``_test_op`` check for ``'adjust_sharpness'`` / ``'RandomAdjustSharpness'``.

    Both the functional and the transform are given ``sharpness_factor=2.0``.
    """
    # One dict object is deliberately passed for both keyword arguments.
    shared_kwargs = {"sharpness_factor": 2.0}
    self._test_op(
        'adjust_sharpness',
        'RandomAdjustSharpness',
        fn_kwargs=shared_kwargs,
        meth_kwargs=shared_kwargs,
    )

def test_random_autocontrast(self):
    """Run the shared ``_test_op`` check for ``'autocontrast'`` / ``'RandomAutocontrast'``."""
    op_names = ('autocontrast', 'RandomAutocontrast')
    self._test_op(*op_names)

def test_random_equalize(self):
    """Run the shared ``_test_op`` check for ``'equalize'`` / ``'RandomEqualize'``.

    Deterministic mode is switched off before the check runs.
    """
    # NOTE(review): this flips a global torch flag and nothing in this method
    # restores it afterwards -- confirm other tests do not rely on the old value.
    torch.set_deterministic(False)
    op_names = ('equalize', 'RandomEqualize')
    self._test_op(*op_names)

def test_color_jitter(self):

tol = 1.0 + 1e-10
Expand Down Expand Up @@ -598,6 +626,22 @@ def test_convert_image_dtype(self):
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))

def test_autoaugment(self):
    """Compare eager and scripted ``AutoAugment`` on uint8 tensors, then save a scripted module.

    For every member of ``T.AutoAugmentPolicy`` and every fill value listed
    below, the transform is built and scripted once and then compared against
    its scripted version 100 times on a single image tensor and on a batch.
    The last scripted transform is finally serialized into a temporary directory.
    """
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

    s_transform = None
    for policy in T.AutoAugmentPolicy:
        for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
            # Construction and TorchScript compilation do not depend on the
            # repeat index, so do them once per (policy, fill) instead of 100 times.
            transform = T.AutoAugment(policy=policy, fill=fill)
            s_transform = torch.jit.script(transform)

            for _ in range(100):
                self._test_transform_vs_scripted(transform, s_transform, tensor)
                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    # Guard against an empty policy enum, which would leave nothing to save.
    if s_transform is not None:
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .transforms import *
from .autoaugment import *
Loading