Skip to content

Commit

Permalink
fix datamodule imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 23, 2020
1 parent b288c81 commit 264bc25
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 36 deletions.
69 changes: 69 additions & 0 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,98 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

__all__ = []

try:
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['BinaryMNISTDataModule']

try:
from pl_bolts.datamodules.cifar10_datamodule import (
CIFAR10DataModule,
TinyCIFAR10DataModule,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['CIFAR10DataModule', 'TinyCIFAR10DataModule']

try:
from pl_bolts.datamodules.experience_source import (
ExperienceSourceDataset,
ExperienceSource,
DiscountedExperienceSource,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['ExperienceSourceDataset', 'ExperienceSource', 'DiscountedExperienceSource']

try:
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['FashionMNISTDataModule']

try:
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['ImagenetDataModule']

try:
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['MNISTDataModule']

try:
from pl_bolts.datamodules.sklearn_datamodule import (
SklearnDataset,
SklearnDataModule,
TensorDataset,
)
except ModuleNotFoundError:
pass
else:
__all__ += ['SklearnDataset', 'SklearnDataModule', 'TensorDataset']

try:
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['SSLImagenetDataModule']

try:
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['STL10DataModule']

try:
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['VOCDetectionDataModule']

try:
from pl_bolts.datasets.kitti_dataset import KittiDataset
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataset']

try:
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
except ModuleNotFoundError:
pass
else:
__all__ += ['KittiDataModule']
8 changes: 8 additions & 0 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@

try:
from torchvision import transforms as transform_lib
except ModuleNotFoundError:
warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover
' install it with `pip install torchvision`.')
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True

try:
from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet
except ModuleNotFoundError:
warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover
Expand Down
5 changes: 2 additions & 3 deletions pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math
from typing import Any
from warnings import warn

import numpy as np
import torch
Expand All @@ -10,8 +9,8 @@
try:
from sklearn.utils import shuffle as sk_shuffle
except ModuleNotFoundError:
warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')
raise ModuleNotFoundError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')
_SKLEARN_AVAILABLE = False
else:
_SKLEARN_AVAILABLE = True
Expand Down
8 changes: 8 additions & 0 deletions pl_bolts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@
DummyDataset,
DummyDetectionDataset
)

__all__ = [
"RandomDictStringDataset",
"RandomDictDataset",
"RandomDataset",
"DummyDataset",
"DummyDetectionDataset",
]
25 changes: 8 additions & 17 deletions pl_bolts/datasets/imagenet_dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import gzip
import hashlib
import importlib
import os
import shutil
import tarfile
import tempfile
import zipfile
from contextlib import contextmanager
from warnings import warn

import numpy as np
import torch
from torch._six import PY3

_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
else:
warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')

try:
from torchvision.datasets import ImageNet
from torchvision.datasets.imagenet import load_meta_file
Expand Down Expand Up @@ -74,13 +66,9 @@ def __init__(
super(ImageNet, self).__init__(self.split_folder, **kwargs)
self.root = root

if not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
)

# shuffle images first
self.imgs = sk_shuffle(self.imgs, random_state=1234)

shuffle(self.imgs, random_state=1234)

# partition train set into [train, val]
if split == 'train':
Expand All @@ -105,7 +93,9 @@ def __init__(
# limit the number of classes
if num_classes != -1:
# choose the classes at random (but deterministic)
ok_classes = sk_shuffle(list(range(num_classes)), random_state=1234)
ok_classes = list(range(num_classes))
np.random.seed(1234)
np.random.shuffle(ok_classes)
ok_classes = ok_classes[:num_classes]
ok_classes = set(ok_classes)

Expand All @@ -117,7 +107,8 @@ def __init__(
self.imgs = clean_imgs

# shuffle again for final exit
self.imgs = sk_shuffle(self.imgs, random_state=1234)
np.random.seed(1234)
np.random.shuffle(self.imgs)

# list of class_nbs for each image
idcs = [idx for _, idx in self.imgs]
Expand Down
15 changes: 2 additions & 13 deletions pl_bolts/datasets/ssl_amdim_datasets.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from abc import ABC
import importlib
from typing import Callable, Optional
from warnings import warn

import numpy as np

_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
else:
warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')

try:
from torchvision.datasets import CIFAR10
except ModuleNotFoundError:
Expand Down Expand Up @@ -83,14 +75,11 @@ def select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val):

@classmethod
def deterministic_shuffle(cls, x, y):
if not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
)

n = len(x)
idxs = list(range(0, n))
idxs = sk_shuffle(idxs, random_state=1234)
np.random.seed(1234)
np.random.shuffle(idxs)

x = x[idxs]

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
from pl_bolts.models.mnist_module import LitMNIST
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.vision import PixelCNN
from pl_bolts.models.vision import UNet
from pl_bolts.models.vision import SemSegment
from pl_bolts.models.vision import UNet
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
2 changes: 1 addition & 1 deletion pl_bolts/models/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from pl_bolts.models.vision.pixel_cnn import PixelCNN
from pl_bolts.models.vision.unet import UNet
from pl_bolts.models.vision.segmentation import SemSegment
from pl_bolts.models.vision.unet import UNet
2 changes: 1 addition & 1 deletion pl_bolts/models/vision/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from argparse import ArgumentParser, Namespace
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
Expand Down

0 comments on commit 264bc25

Please sign in to comment.