Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update examples - use DataModule #4740

Merged
merged 11 commits into from
Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pl_examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

from pytorch_lightning.utilities import _module_available

EXAMPLES_ROOT = os.path.dirname(__file__)
PACKAGE_ROOT = os.path.dirname(EXAMPLES_ROOT)
DATASETS_PATH = os.path.join(PACKAGE_ROOT, 'Datasets')

TORCHVISION_AVAILABLE = _module_available("torchvision")
DALI_AVAILABLE = _module_available("nvidia.dali")
11 changes: 7 additions & 4 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
# limitations under the License.

from argparse import ArgumentParser

import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split

try:
import pytorch_lightning as pl
from pl_examples import TORCHVISION_AVAILABLE

if TORCHVISION_AVAILABLE:
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
except ModuleNotFoundError:
else:
Borda marked this conversation as resolved.
Show resolved Hide resolved
from tests.base.datasets import MNIST


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

try:
import pytorch_lightning as pl
from pl_examples import DATASETS_PATH, TORCHVISION_AVAILABLE

if TORCHVISION_AVAILABLE:
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
except Exception as e:
else:
from tests.base.datasets import MNIST


Expand Down Expand Up @@ -96,8 +98,8 @@ def cli_main():
# ------------
# data
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST(DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@
from torch.utils.data import random_split

import pytorch_lightning as pl
from pl_examples import TORCHVISION_AVAILABLE, DALI_AVAILABLE

try:
if TORCHVISION_AVAILABLE:
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
except Exception:
else:
from tests.base.datasets import MNIST

try:
if DALI_AVAILABLE:
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
except (ImportError, ModuleNotFoundError):
else:
warn('NVIDIA DALI is not available')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this isn't related to this PR, but shouldn't this be a hard crash if DALI isn't available?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, the case is that with tests running on CPU where we do not install DALI the testing will fail just because of this raising error, an alternative would be to return from this script if DALI is missing without crashing...

ops, types, Pipeline, DALIClassificationIterator = ..., ..., ABC, ABC
ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC


class ExternalMNISTInputIterator(object):
Expand Down
132 changes: 132 additions & 0 deletions pl_examples/basic_examples/mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from torch.utils.data import DataLoader, random_split

from pl_examples import DATASETS_PATH, TORCHVISION_AVAILABLE
from pytorch_lightning import LightningDataModule

if TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST
else:
from tests.base.datasets import MNIST


class MNISTDataModule(LightningDataModule):
"""
Standard MNIST, train, val, test splits and transforms
"""

name = "mnist"

def __init__(
self,
data_dir: str = DATASETS_PATH,
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
"""
Args:
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
normalize: If true applies image normalize
"""
super().__init__(*args, **kwargs)

self.dims = (1, 28, 28)
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
self.normalize = normalize
self.seed = seed
self.batch_size = batch_size
self.dataset_train = ...
self.dataset_val = ...
self.test_transforms = self.default_transforms

@property
def num_classes(self):
return 10

def prepare_data(self):
"""Saves MNIST files to `data_dir`"""
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None):
"""Split the train and valid dataset"""
extra = dict(transform=self.default_transforms) if self.default_transforms else {}
dataset = MNIST(self.data_dir, train=True, download=False, **extra)
train_length = len(dataset)
self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])

def train_dataloader(self):
"""MNIST train set removes a subset to use for validation"""
loader = DataLoader(
self.dataset_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True,
)
return loader

def val_dataloader(self):
"""MNIST val set uses a subset of the training set for validation"""
loader = DataLoader(
self.dataset_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True,
)
return loader

def test_dataloader(self):
"""MNIST test set uses the test split"""
extra = dict(transform=self.test_transforms) if self.test_transforms else {}
dataset = MNIST(self.data_dir, train=False, download=False, **extra)
loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True,
)
return loader

@property
def default_transforms(self):
if not TORCHVISION_AVAILABLE:
return None
if self.normalize:
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
else:
mnist_transforms = transform_lib.ToTensor()

return mnist_transforms
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser
from pprint import pprint

import torch
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl
from torch.nn import functional as F

try:
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
except Exception as e:
from tests.base.datasets import MNIST
import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):
Expand Down Expand Up @@ -76,21 +72,15 @@ def cli_main():
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
parser = MNISTDataModule.add_argparse_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
dm = MNISTDataModule.from_argparse_args(args)

# ------------
# model
Expand All @@ -101,12 +91,13 @@ def cli_main():
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)
trainer.fit(model, datamodule=dm)

# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
result = trainer.test(datamodule=dm)
pprint(result)


if __name__ == '__main__':
Expand Down
21 changes: 8 additions & 13 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
import pytest
import torch

try:
from nvidia.dali import ops, types, pipeline, plugin
except (ImportError, ModuleNotFoundError):
DALI_AVAILABLE = False
else:
DALI_AVAILABLE = True
from pl_examples import DALI_AVAILABLE

ARGS_DEFAULT = """
--max_epochs 1 \
Expand Down Expand Up @@ -38,8 +33,8 @@

# ToDo: fix this failing example
# @pytest.mark.parametrize('import_cli', [
# 'pl_examples.basic_examples.mnist_classifier',
# 'pl_examples.basic_examples.image_classifier',
# 'pl_examples.basic_examples.simple_image_classifier',
# 'pl_examples.basic_examples.backbone_image_classifier',
# 'pl_examples.basic_examples.autoencoder',
# ])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
Expand All @@ -54,8 +49,8 @@

# ToDo: fix this failing example
# @pytest.mark.parametrize('import_cli', [
# 'pl_examples.basic_examples.mnist_classifier',
# 'pl_examples.basic_examples.image_classifier',
# 'pl_examples.basic_examples.simple_image_classifier',
# 'pl_examples.basic_examples.backbone_image_classifier',
# 'pl_examples.basic_examples.autoencoder',
# ])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
Expand All @@ -69,8 +64,8 @@


@pytest.mark.parametrize('import_cli', [
'pl_examples.basic_examples.mnist_classifier',
'pl_examples.basic_examples.image_classifier',
'pl_examples.basic_examples.simple_image_classifier',
'pl_examples.basic_examples.backbone_image_classifier',
'pl_examples.basic_examples.autoencoder',
])
@pytest.mark.parametrize('cli_args', [ARGS_DEFAULT])
Expand All @@ -87,7 +82,7 @@ def test_examples_cpu(import_cli, cli_args):
@pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.')
@pytest.mark.parametrize('cli_args', [ARGS_GPU])
def test_examples_mnist_dali(cli_args):
from pl_examples.basic_examples.mnist_classifier_dali import cli_main
from pl_examples.basic_examples.dali_image_classifier import cli_main

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
cli_main()
26 changes: 20 additions & 6 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import importlib
from enum import Enum

import numpy
Expand All @@ -21,13 +22,26 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable

try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True

def _module_available(module_path: str) -> bool:
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""Testing if given module is avalaible in your env

>>> _module_available('system')
True
>>> _module_available('bla.bla')
False
"""
mods = module_path.split('.')
assert mods, 'nothing given to test'
# it has to be tested as per partets
for i in range(1, len(mods)):
module_path = '.'.join(mods[:i])
if importlib.util.find_spec(module_path) is None:
return False
return True


APEX_AVAILABLE = _module_available("apex.amp")
NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
Expand Down