Skip to content

Commit

Permalink
Revision of dataset normalizations (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
matsumotosan authored Oct 11, 2022
1 parent a8bbfc9 commit 7f8ede8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
17 changes: 7 additions & 10 deletions pl_bolts/transforms/dataset_normalizations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -8,19 +9,17 @@
warn_missing_pkg("torchvision")


@under_review()
def imagenet_normalization():
def imagenet_normalization() -> Callable:
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
return normalize


@under_review()
def cifar10_normalization():
def cifar10_normalization() -> Callable:
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
Expand All @@ -33,8 +32,7 @@ def cifar10_normalization():
return normalize


@under_review()
def stl10_normalization():
def stl10_normalization() -> Callable:
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
Expand All @@ -44,8 +42,7 @@ def stl10_normalization():
return normalize


@under_review()
def emnist_normalization(split: str):
def emnist_normalization(split: str) -> Callable:
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
Expand Down
35 changes: 35 additions & 0 deletions tests/transforms/test_normalizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
import torch
from pytorch_lightning import seed_everything

from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
emnist_normalization,
imagenet_normalization,
stl10_normalization,
)


@pytest.mark.parametrize(
"normalization",
[cifar10_normalization, imagenet_normalization, stl10_normalization],
)
def test_normalizations(normalization, catch_warnings):
"""Test normalizations for CIFAR10, ImageNet, STL10."""
seed_everything(1234)
x = torch.rand(3, 32, 32)
assert normalization()(x).shape == (3, 32, 32)
assert x.min() >= 0.0
assert x.max() <= 1.0


@pytest.mark.parametrize(
"split",
["balanced", "byclass", "bymerge", "digits", "letters", "mnist"],
)
def test_emnist_normalizations(split, catch_warnings):
"""Test normalizations for each EMNIST dataset split."""
x = torch.rand(1, 28, 28)
assert emnist_normalization(split)(x).shape == (1, 28, 28)
assert x.min() >= 0.0
assert x.max() <= 1.0

0 comments on commit 7f8ede8

Please sign in to comment.