Skip to content

Commit

Permalink
Binary MNIST/EMNIST Datasets and Datamodules (#866)
Browse files Browse the repository at this point in the history
Co-authored-by: otaj <[email protected]>
  • Loading branch information
Shion Matsumoto and otaj authored Aug 23, 2022
1 parent 645a66f commit 7d2a9a1
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 58 deletions.
2 changes: 0 additions & 2 deletions pl_bolts/datamodules/binary_emnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule
from pl_bolts.datasets import BinaryEMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review


@under_review()
class BinaryEMNISTDataModule(EMNISTDataModule):
"""
.. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png
Expand Down
2 changes: 0 additions & 2 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -12,7 +11,6 @@
warn_missing_pkg("torchvision")


@under_review()
class BinaryMNISTDataModule(VisionDataModule):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
Expand Down
47 changes: 38 additions & 9 deletions pl_bolts/datasets/emnist_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Tuple, Union

from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -14,9 +15,39 @@
warn_missing_pkg("PIL", pypi_name="Pillow")


@under_review()
class BinaryEMNIST(EMNIST):
def __getitem__(self, idx):
"""Binarized EMNIST Dataset.
EMNIST dataset binarized using a thresholding operation. Default threshold value is 127.
Note that the images are binarized prior to the application of any transforms.
Args:
root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
which one to use.
threshold (Union[int, float], optional): Threshold value for binarizing image.
Pixel value is set to 255 if value is greater than threshold, otherwise 0.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Note:
Documentation is based on https://pytorch.org/vision/main/generated/torchvision.datasets.EMNIST.html
"""

def __init__(self, root: str, split: str, threshold: Union[int, float] = 127.0, **kwargs: Any) -> None:
super().__init__(root, split, **kwargs)
self.threshold = threshold

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
"""
Args:
index: Index
Expand All @@ -29,18 +60,16 @@ def __getitem__(self, idx):

img, target = self.data[idx], int(self.targets[idx])

# doing this so that it is consistent with all other datasets
# to return a PIL Image
# Convert to PIL Image (8-bit BW)
img = Image.fromarray(img.numpy(), mode="L")

# Binarize image at threshold
img = img.point(lambda p: 255 if p > self.threshold else 0)

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

# binary
img[img < 0.5] = 0.0
img[img >= 0.5] = 1.0

return img, target
72 changes: 38 additions & 34 deletions pl_bolts/datasets/mnist_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1
from pl_bolts.utils.stability import under_review
from typing import Any, Tuple, Union

from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -13,36 +14,41 @@
else: # pragma: no cover
warn_missing_pkg("PIL", pypi_name="Pillow")

# TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset
# from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`.
# See https://github.com/pytorch/vision/issues/3549 for details.
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1:
MNIST.resources = [
(
"https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz",
"f68b3c2dcbeaaa9fbdd348bbdeb94873",
),
(
"https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz",
"d53e105ee54ea40749a09fcbcd1e9432",
),
(
"https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz",
"9fb629c4189551a2d022fa330f9573f3",
),
(
"https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz",
"ec29112dd5afa0611ce80d1b7f02629c",
),
]


@under_review()

class BinaryMNIST(MNIST):
def __getitem__(self, idx):
"""Binarized MNIST Dataset.
MNIST dataset binarized using a thresholding operation. Default threshold value is 127.
Note that the images are binarized prior to the application of any transforms.
Args:
root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
threshold (Union[int, float], optional): Threshold value for binarizing image.
Pixel value is set to 255 if value is greater than threshold, otherwise 0.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Note:
Documentation is based on https://pytorch.org/vision/main/generated/torchvision.datasets.EMNIST.html
"""

def __init__(self, root: str, threshold: Union[int, float] = 127.0, **kwargs: Any) -> None:
super().__init__(root, **kwargs)
self.threshold = threshold

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
Expand All @@ -51,18 +57,16 @@ def __getitem__(self, idx):

img, target = self.data[idx], int(self.targets[idx])

# doing this so that it is consistent with all other datasets
# to return a PIL Image
# Convert to PIL Image (8-bit BW)
img = Image.fromarray(img.numpy(), mode="L")

# Binarize image at threshold
img = img.point(lambda p: 255 if p > self.threshold else 0)

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

# binary
img[img < 0.5] = 0.0
img[img >= 0.5] = 1.0

return img, target
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ module = [
"pl_bolts.datasets.cifar10_dataset",
"pl_bolts.datasets.concat_dataset",
"pl_bolts.datasets.dummy_dataset",
"pl_bolts.datasets.emnist_dataset",
"pl_bolts.datasets.imagenet_dataset",
"pl_bolts.datasets.kitti_dataset",
"pl_bolts.datasets.mnist_dataset",
"pl_bolts.datasets.sr_celeba_dataset",
"pl_bolts.datasets.sr_mnist_dataset",
"pl_bolts.datasets.sr_stl10_dataset",
Expand Down
30 changes: 23 additions & 7 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,18 @@ def test_vision_data_module(datadir, val_split, catch_warnings, train_len):

@pytest.mark.parametrize("dm_cls", [BinaryMNISTDataModule, CIFAR10DataModule, FashionMNISTDataModule, MNISTDataModule])
def test_data_modules(datadir, catch_warnings, dm_cls):
"""Test datamodules train, val, and test dataloaders outputs have correct shape."""
dm = _create_dm(dm_cls, datadir)
loader = dm.train_dataloader()
img, _ = next(iter(loader))
train_loader = dm.train_dataloader()
img, _ = next(iter(train_loader))
assert img.size() == torch.Size([2, *dm.dims])

val_loader = dm.val_dataloader()
img, _ = next(iter(val_loader))
assert img.size() == torch.Size([2, *dm.dims])

test_loader = dm.test_dataloader()
img, _ = next(iter(test_loader))
assert img.size() == torch.Size([2, *dm.dims])


Expand All @@ -104,12 +113,19 @@ def test_sr_datamodule(datadir):
@pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"])
@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
def test_emnist_datamodules(datadir, dm_cls, split):
"""Test EMNIST datamodules download data and have the correct shape."""

"""Test BinaryEMNIST and EMNIST datamodules download data and have the correct shape."""
dm = _create_dm(dm_cls, datadir, split=split)
loader = dm.train_dataloader()
img, _ = next(iter(loader))
assert img.size() == torch.Size([2, 1, 28, 28])
train_loader = dm.train_dataloader()
img, _ = next(iter(train_loader))
assert img.size() == torch.Size([2, *dm.dims])

val_loader = dm.val_dataloader()
img, _ = next(iter(val_loader))
assert img.size() == torch.Size([2, *dm.dims])

test_loader = dm.test_dataloader()
img, _ = next(iter(test_loader))
assert img.size() == torch.Size([2, *dm.dims])


@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
Expand Down
39 changes: 37 additions & 2 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import pytest
import torch
from torch.utils.data import DataLoader, Dataset

from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset
from torchvision import transforms as transform_lib

from pl_bolts.datasets import (
BinaryEMNIST,
BinaryMNIST,
DummyDataset,
RandomDataset,
RandomDictDataset,
RandomDictStringDataset,
)
from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST

Expand Down Expand Up @@ -129,3 +137,30 @@ def test_sr_datasets(datadir, scale_factor):
assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol)
assert torch.allclose(lr_image.min(), torch.tensor(0.0), atol=atol)
assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol)


def test_binary_mnist_dataset(datadir):
"""Check BinaryMNIST image and target dimensions and value range."""
dl = DataLoader(BinaryMNIST(root=datadir, download=True, transform=transform_lib.ToTensor()))
img, target = next(iter(dl))

assert img.size() == torch.Size([1, 1, 28, 28])
assert target.size() == torch.Size([1])

assert torch.allclose(img.min(), torch.tensor(0.0))
assert torch.allclose(img.max(), torch.tensor(1.0))
assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0]))


@pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"])
def test_binary_emnist_dataset(datadir, split):
"""Check BinaryEMNIST image and target dimensions and value range for each split."""
dl = DataLoader(BinaryEMNIST(root=datadir, split=split, download=True, transform=transform_lib.ToTensor()))
img, target = next(iter(dl))

assert img.size() == torch.Size([1, 1, 28, 28])
assert target.size() == torch.Size([1])

assert torch.allclose(img.min(), torch.tensor(0.0))
assert torch.allclose(img.max(), torch.tensor(1.0))
assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0]))

0 comments on commit 7d2a9a1

Please sign in to comment.