diff --git a/test/test_transforms.py b/test/test_transforms.py
index e2232e2491b..21305f5fa8f 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -404,6 +404,12 @@ def test_to_tensor(self):
             expected_output = ndarray.transpose((2, 0, 1))
             assert np.allclose(output.numpy(), expected_output)
 
+        # separate test for mode '1' PIL images
+        input_data = torch.ByteTensor(1, height, width).bernoulli_()
+        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
+        output = trans(img)
+        assert np.allclose(input_data.numpy(), output.numpy())
+
     @unittest.skipIf(accimage is None, 'accimage not available')
     def test_accimage_to_tensor(self):
         trans = transforms.ToTensor()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index b9b7730df94..0b06e059b6c 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -66,9 +66,11 @@ def to_tensor(pic):
         img = torch.from_numpy(np.array(pic, np.int16, copy=False))
     elif pic.mode == 'F':
         img = torch.from_numpy(np.array(pic, np.float32, copy=False))
+    elif pic.mode == '1':
+        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
     else:
         img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
-    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
     if pic.mode == 'YCbCr':
         nchannel = 3
     elif pic.mode == 'I;16':
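
Note (not part of the patch): a minimal standalone sketch of what the new test exercises, assuming the rest of to_tensor is unchanged (i.e. the ByteTensor branch still ends in a float conversion that divides by 255). A mode '1' PIL image is read back as 0/1 uint8 values; multiplying by 255 in the new branch makes that final divide recover exactly 0.0 and 1.0, so a binary image round-trips through ToTensor unchanged.

    import numpy as np
    import torch
    from torchvision import transforms

    height, width = 4, 4
    trans = transforms.ToTensor()

    # Random 0/1 image, converted to a binary (mode '1') PIL image,
    # mirroring the new test case above.
    input_data = torch.ByteTensor(1, height, width).bernoulli_()
    img = transforms.ToPILImage()(input_data.mul(255)).convert('1')

    # With this patch, mode '1' pixels are read as uint8 {0, 1} and scaled
    # by 255, so the ByteTensor -> float conversion at the end of to_tensor
    # yields exactly 0.0 and 1.0 again.
    output = trans(img)
    assert np.allclose(input_data.numpy(), output.numpy())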