From f6ab10793eb7b1d0e079d0b6df8accfa824b21ed Mon Sep 17 00:00:00 2001 From: arturml Date: Mon, 16 Apr 2018 11:10:25 -0300 Subject: [PATCH] Add support in transforms.ToTensor for PIL Images mod '1' (#471) * Add case in test_to_tensor for PIL Images mode '1' * Add support in ToTensor for PIL Images mode '1' * Fix pep8 issues --- test/test_transforms.py | 6 ++++++ torchvision/transforms/functional.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index e2232e2491b..21305f5fa8f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -404,6 +404,12 @@ def test_to_tensor(self): expected_output = ndarray.transpose((2, 0, 1)) assert np.allclose(output.numpy(), expected_output) + # separate test for mode '1' PIL images + input_data = torch.ByteTensor(1, height, width).bernoulli_() + img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + output = trans(img) + assert np.allclose(input_data.numpy(), output.numpy()) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): trans = transforms.ToTensor() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5d5325078be..7d1a8921a6b 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -64,9 +64,11 @@ def to_tensor(pic): img = torch.from_numpy(np.array(pic, np.int16, copy=False)) elif pic.mode == 'F': img = torch.from_numpy(np.array(pic, np.float32, copy=False)) + elif pic.mode == '1': + img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) else: img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK if pic.mode == 'YCbCr': nchannel = 3 elif pic.mode == 'I;16':