From f7bed1130c6e0098bd0696042de0ef8c34bfbb5f Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 6 May 2021 23:45:41 +0900 Subject: [PATCH 01/10] Use s3 url for mnist --- pl_bolts/datamodules/mnist_datamodule.py | 3 +-- pl_bolts/datasets/__init__.py | 3 ++- pl_bolts/datasets/mnist_dataset.py | 14 +++++++++++++- pl_bolts/models/mnist_module.py | 2 +- pl_bolts/utils/__init__.py | 5 +++++ 5 files changed, 22 insertions(+), 5 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 0889d71d09..0c0e9cb1a1 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -1,15 +1,14 @@ from typing import Any, Callable, Optional, Union +from pl_bolts.datasets import MNIST from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib - from torchvision.datasets import MNIST else: # pragma: no cover warn_missing_pkg('torchvision') - MNIST = None class MNISTDataModule(VisionDataModule): diff --git a/pl_bolts/datasets/__init__.py b/pl_bolts/datasets/__init__.py index 191f37c10a..b7ac6c5fee 100644 --- a/pl_bolts/datasets/__init__.py +++ b/pl_bolts/datasets/__init__.py @@ -12,7 +12,7 @@ ) from pl_bolts.datasets.imagenet_dataset import extract_archive, parse_devkit_archive, UnlabeledImagenet from pl_bolts.datasets.kitti_dataset import KittiDataset -from pl_bolts.datasets.mnist_dataset import BinaryMNIST +from pl_bolts.datasets.mnist_dataset import BinaryMNIST, MNIST from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin __all__ = [ @@ -22,6 +22,7 @@ "ConcatDataset", "DummyDataset", "DummyDetectionDataset", + "MNIST", "RandomDataset", "RandomDictDataset", "RandomDictStringDataset", diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index 
31019d6abd..6b65dacbb2 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -1,4 +1,4 @@ -from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,6 +13,18 @@ warn_missing_pkg('PIL', pypi_name='Pillow') +# TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset +# from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`. +# See https://github.com/pytorch/vision/issues/3549 for details. +if _TORCHVISION_LESS_THAN_0_9_1: + MNIST.resources = [ + ("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + ("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), + ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), + ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), + ] + + class BinaryMNIST(MNIST): def __getitem__(self, idx): diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index be6c4ee623..3ca742eeed 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -5,12 +5,12 @@ from torch.nn import functional as F from torch.utils.data import DataLoader, random_split +from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets import MNIST else: # pragma: no cover warn_missing_pkg('torchvision') diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index 0a49d730c4..3baf5a53ae 100644 --- a/pl_bolts/utils/__init__.py +++ 
b/pl_bolts/utils/__init__.py @@ -1,5 +1,8 @@ +import operator + import torch from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.imports import _compare_version from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore @@ -12,5 +15,7 @@ _OPENCV_AVAILABLE: bool = _module_available("cv2") _WANDB_AVAILABLE: bool = _module_available("wandb") _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib") +_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.ge, "0.9.1") + __all__ = ["BatchGradientVerification"] From dc3af28210acff63c8f879578ee396bf9edb9412 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 6 May 2021 23:54:39 +0900 Subject: [PATCH 02/10] flake8 --- pl_bolts/datasets/mnist_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index 6b65dacbb2..58746ac60e 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -18,10 +18,10 @@ # See https://github.com/pytorch/vision/issues/3549 for details. 
if _TORCHVISION_LESS_THAN_0_9_1: MNIST.resources = [ - ("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), - ("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), - ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), - ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), + ("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), # noqa: E501 + ("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), # noqa: E501 + ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), # noqa: E501 + ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), # noqa: E501 ] From 08c5ac4df745fb2d69687696d7ac44ccaade6eb6 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 6 May 2021 23:57:41 +0900 Subject: [PATCH 03/10] Update changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 324b1e7fd9..2a5c44f774 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## [Unreleased] - yyyy-mm-dd + +### Fixed + +- Fixed the MNIST download giving HTTP 503 ([#633](https://github.com/PyTorchLightning/lightning-bolts/pull/633)) + + ## [0.3.3] - 2021-04-17 ### Changed From 59f6c55ad13ab7ec923b608ebf766cb7017ef4c1 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 7 May 2021 00:11:00 +0900 Subject: [PATCH 04/10] formatting --- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datasets/mnist_dataset.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 0c0e9cb1a1..8ba86f2403 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Optional, Union -from pl_bolts.datasets import MNIST from pl_bolts.datamodules.vision_datamodule import VisionDataModule +from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index 58746ac60e..e83c4620ca 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -12,7 +12,6 @@ else: # pragma: no cover warn_missing_pkg('PIL', pypi_name='Pillow') - # TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset # from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`. # See https://github.com/pytorch/vision/issues/3549 for details. 
From 886b4a159a74f37b694940adc4c58b3844eac572 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 7 May 2021 00:41:09 +0900 Subject: [PATCH 05/10] Port _compare_version from PL --- pl_bolts/utils/__init__.py | 24 +++++++++++++++++++++++- requirements.txt | 3 ++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index 3baf5a53ae..2cfd6d470a 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -1,11 +1,33 @@ +import importlib import operator import torch +from packaging.version import Version +from pkg_resources import DistributionNotFound from pytorch_lightning.utilities import _module_available -from pytorch_lightning.utilities.imports import _compare_version from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore + +# Ported from https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py +def _compare_version(package: str, op, version) -> bool: + """ + Compare package version with some requirements + >>> _compare_version("torch", operator.ge, "0.1") + True + """ + try: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = Version(pkg.__version__) + except TypeError: + # this is mocked by sphinx, so it shall return True to generate all summaries + return True + return op(pkg_version, Version(version)) + + _NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") _TORCHVISION_AVAILABLE: bool = _module_available("torchvision") diff --git a/requirements.txt b/requirements.txt index 4c6f122d8c..2de8f7e1ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ torch>=1.6 torchmetrics>=0.2.0 pytorch-lightning>=1.1.1 -dataclasses ; python_version <= "3.6" \ No newline at end of file +dataclasses ; python_version <= "3.6" +packaging From 
625da7b43678805719d0f79cd3575436daed627d Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 7 May 2021 00:51:14 +0900 Subject: [PATCH 06/10] . --- pl_bolts/datasets/mnist_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index e83c4620ca..f3c143b825 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -15,7 +15,7 @@ # TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset # from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`. # See https://github.com/pytorch/vision/issues/3549 for details. -if _TORCHVISION_LESS_THAN_0_9_1: +if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1: MNIST.resources = [ ("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), # noqa: E501 ("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), # noqa: E501 From 7d97fed147e10d25cc6e4f1500b5e726ae12478b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 May 2021 09:06:52 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/datasets/mnist_dataset.py | 16 ++++++++++++---- pl_bolts/utils/__init__.py | 1 - requirements.txt | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index f3c143b825..b9d599579b 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -17,10 +17,18 @@ # See https://github.com/pytorch/vision/issues/3549 for details. 
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1: MNIST.resources = [ - ("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), # noqa: E501 - ("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), # noqa: E501 - ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), # noqa: E501 - ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), # noqa: E501 + ( + "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", + "f68b3c2dcbeaaa9fbdd348bbdeb94873" + ), # noqa: E501 + ( + "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", + "d53e105ee54ea40749a09fcbcd1e9432" + ), # noqa: E501 + ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", + "9fb629c4189551a2d022fa330f9573f3"), # noqa: E501 + ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", + "ec29112dd5afa0611ce80d1b7f02629c"), # noqa: E501 ] diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index 2cfd6d470a..d597d0bcf7 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -39,5 +39,4 @@ def _compare_version(package: str, op, version) -> bool: _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib") _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.ge, "0.9.1") - __all__ = ["BatchGradientVerification"] diff --git a/requirements.txt b/requirements.txt index 0168079c72..2de8f7e1ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ torch>=1.6 torchmetrics>=0.2.0 pytorch-lightning>=1.1.1 dataclasses ; python_version <= "3.6" -packaging \ No newline at end of file +packaging From 0b6d6d9f68ddac6a10750dc7d4e58bba86f86657 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 12 May 2021 02:21:41 +0900 Subject: 
[PATCH 08/10] Update tests for trainer.fit returning None --- tests/models/rl/integration/test_value_models.py | 12 ------------ tests/models/test_autoencoders.py | 6 ++---- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/models/rl/integration/test_value_models.py b/tests/models/rl/integration/test_value_models.py index a723a0a8f0..10bb91d6cd 100644 --- a/tests/models/rl/integration/test_value_models.py +++ b/tests/models/rl/integration/test_value_models.py @@ -39,39 +39,27 @@ def test_dqn(self): model = DQN(self.hparams.env, num_envs=5) result = self.trainer.fit(model) - self.assertEqual(result, 1) - def test_double_dqn(self): """Smoke test that the Double DQN model runs""" model = DoubleDQN(self.hparams.env) result = self.trainer.fit(model) - self.assertEqual(result, 1) - def test_dueling_dqn(self): """Smoke test that the Dueling DQN model runs""" model = DuelingDQN(self.hparams.env) result = self.trainer.fit(model) - self.assertEqual(result, 1) - def test_noisy_dqn(self): """Smoke test that the Noisy DQN model runs""" model = NoisyDQN(self.hparams.env) result = self.trainer.fit(model) - self.assertEqual(result, 1) - def test_per_dqn(self): """Smoke test that the PER DQN model runs""" model = PERDQN(self.hparams.env) result = self.trainer.fit(model) - self.assertEqual(result, 1) - # def test_n_step_dqn(self): # """Smoke test that the N Step DQN model runs""" # model = DQN(self.hparams.env, n_steps=self.hparams.n_steps) # result = self.trainer.fit(model) - # - # self.assertEqual(result, 1) diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index 322cb28774..36bfb7b1cb 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -19,8 +19,7 @@ def test_vae(tmpdir, datadir, dm_cls): gpus=None, ) - result = trainer.fit(model, datamodule=dm) - assert result == 1 + trainer.fit(model, datamodule=dm) @pytest.mark.parametrize("dm_cls", [pytest.param(CIFAR10DataModule, id="cifar10")]) @@ 
-35,8 +34,7 @@ def test_ae(tmpdir, datadir, dm_cls): gpus=None, ) - result = trainer.fit(model, datamodule=dm) - assert result == 1 + trainer.fit(model, datamodule=dm) @torch.no_grad() From 0e07d8bfbc86310509964f8fa87be177639d8f7b Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 12 May 2021 02:26:32 +0900 Subject: [PATCH 09/10] Remove unused refs --- tests/models/rl/integration/test_value_models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/rl/integration/test_value_models.py b/tests/models/rl/integration/test_value_models.py index 10bb91d6cd..c127b81aa3 100644 --- a/tests/models/rl/integration/test_value_models.py +++ b/tests/models/rl/integration/test_value_models.py @@ -37,27 +37,27 @@ def setUp(self) -> None: def test_dqn(self): """Smoke test that the DQN model runs""" model = DQN(self.hparams.env, num_envs=5) - result = self.trainer.fit(model) + self.trainer.fit(model) def test_double_dqn(self): """Smoke test that the Double DQN model runs""" model = DoubleDQN(self.hparams.env) - result = self.trainer.fit(model) + self.trainer.fit(model) def test_dueling_dqn(self): """Smoke test that the Dueling DQN model runs""" model = DuelingDQN(self.hparams.env) - result = self.trainer.fit(model) + self.trainer.fit(model) def test_noisy_dqn(self): """Smoke test that the Noisy DQN model runs""" model = NoisyDQN(self.hparams.env) - result = self.trainer.fit(model) + self.trainer.fit(model) def test_per_dqn(self): """Smoke test that the PER DQN model runs""" model = PERDQN(self.hparams.env) - result = self.trainer.fit(model) + self.trainer.fit(model) # def test_n_step_dqn(self): # """Smoke test that the N Step DQN model runs""" From 172a45c613f1492cb51986d08944112e9484cb50 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 12 May 2021 02:38:30 +0900 Subject: [PATCH 10/10] Update tests for trainer.fit returning None --- tests/models/rl/integration/test_policy_models.py | 8 ++------ 1 file changed, 2 insertions(+), 6 
deletions(-) diff --git a/tests/models/rl/integration/test_policy_models.py b/tests/models/rl/integration/test_policy_models.py index 23c8b510d2..440c1465c4 100644 --- a/tests/models/rl/integration/test_policy_models.py +++ b/tests/models/rl/integration/test_policy_models.py @@ -30,13 +30,9 @@ def test_reinforce(self): """Smoke test that the reinforce model runs""" model = Reinforce(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) def test_policy_gradient(self): """Smoke test that the policy gradient model runs""" model = VanillaPolicyGradient(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model)