Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MNIST 503 error by changing URL to AWS S3 #633

Merged
merged 15 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [Unreleased] - yyyy-mm-dd

### Fixed

- Fixed the MNIST download giving HTTP 503 ([#633](https://github.com/PyTorchLightning/lightning-bolts/pull/633))


## [0.3.3] - 2021-04-17

### Changed
Expand Down
3 changes: 1 addition & 2 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Any, Callable, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST
else: # pragma: no cover
warn_missing_pkg('torchvision')
MNIST = None


class MNISTDataModule(VisionDataModule):
Expand Down
3 changes: 2 additions & 1 deletion pl_bolts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from pl_bolts.datasets.imagenet_dataset import extract_archive, parse_devkit_archive, UnlabeledImagenet
from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.datasets.mnist_dataset import BinaryMNIST, MNIST
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin

__all__ = [
Expand All @@ -22,6 +22,7 @@
"ConcatDataset",
"DummyDataset",
"DummyDetectionDataset",
"MNIST",
"RandomDataset",
"RandomDictDataset",
"RandomDictStringDataset",
Expand Down
13 changes: 12 additions & 1 deletion pl_bolts/datasets/mnist_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -12,6 +12,17 @@
else: # pragma: no cover
warn_missing_pkg('PIL', pypi_name='Pillow')

# TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset
# from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`.
# See https://github.com/pytorch/vision/issues/3549 for details.
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1:
MNIST.resources = [
("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), # noqa: E501
("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), # noqa: E501
("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), # noqa: E501
("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), # noqa: E501
]


class BinaryMNIST(MNIST):

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/mnist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.datasets import MNIST
else: # pragma: no cover
warn_missing_pkg('torchvision')

Expand Down
27 changes: 27 additions & 0 deletions pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,33 @@
import importlib
import operator

import torch
from packaging.version import Version
from pkg_resources import DistributionNotFound
from pytorch_lightning.utilities import _module_available

from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore


# Ported from https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py
def _compare_version(package: str, op, version) -> bool:
"""
Compare package version with some requirements
>>> _compare_version("torch", operator.ge, "0.1")
True
"""
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
pkg_version = Version(pkg.__version__)
except TypeError:
# this is mock by sphinx, so it shall return True ro generate all summaries
return True
return op(pkg_version, Version(version))


_NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

_TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
Expand All @@ -12,5 +37,7 @@
_OPENCV_AVAILABLE: bool = _module_available("cv2")
_WANDB_AVAILABLE: bool = _module_available("wandb")
_MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.ge, "0.9.1")


__all__ = ["BatchGradientVerification"]
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torch>=1.6
torchmetrics>=0.2.0
pytorch-lightning>=1.1.1
dataclasses ; python_version <= "3.6"
dataclasses ; python_version <= "3.6"
packaging