diff --git a/pl_bolts/transforms/dataset_normalizations.py b/pl_bolts/transforms/dataset_normalizations.py index fcf919b5d7..71b3a37260 100644 --- a/pl_bolts/transforms/dataset_normalizations.py +++ b/pl_bolts/transforms/dataset_normalizations.py @@ -1,5 +1,6 @@ +from typing import Callable + from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -8,19 +9,17 @@ warn_missing_pkg("torchvision") -@under_review() -def imagenet_normalization(): +def imagenet_normalization() -> Callable: if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." ) - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) return normalize -@under_review() -def cifar10_normalization(): +def cifar10_normalization() -> Callable: if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." @@ -33,8 +32,7 @@ def cifar10_normalization(): return normalize -@under_review() -def stl10_normalization(): +def stl10_normalization() -> Callable: if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." @@ -44,8 +42,7 @@ def stl10_normalization(): return normalize -@under_review() -def emnist_normalization(split: str): +def emnist_normalization(split: str) -> Callable: if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." diff --git a/tests/transforms/test_normalizations.py b/tests/transforms/test_normalizations.py new file mode 100644 index 0000000000..bd3af643e3 --- /dev/null +++ b/tests/transforms/test_normalizations.py @@ -0,0 +1,35 @@ +import pytest +import torch +from pytorch_lightning import seed_everything + +from pl_bolts.transforms.dataset_normalizations import ( + cifar10_normalization, + emnist_normalization, + imagenet_normalization, + stl10_normalization, +) + + +@pytest.mark.parametrize( + "normalization", + [cifar10_normalization, imagenet_normalization, stl10_normalization], +) +def test_normalizations(normalization, catch_warnings): + """Test normalizations for CIFAR10, ImageNet, STL10.""" + seed_everything(1234) + x = torch.rand(3, 32, 32) + assert normalization()(x).shape == (3, 32, 32) + assert x.min() >= 0.0 + assert x.max() <= 1.0 + + +@pytest.mark.parametrize( + "split", + ["balanced", "byclass", "bymerge", "digits", "letters", "mnist"], +) +def test_emnist_normalizations(split, catch_warnings): + """Test normalizations for each EMNIST dataset split.""" + x = torch.rand(1, 28, 28) + assert emnist_normalization(split)(x).shape == (1, 28, 28) + assert x.min() >= 0.0 + assert x.max() <= 1.0