Skip to content

Commit

Permalink
Decouple DataModules from Models - GAN (#206)
Browse files Browse the repository at this point in the history
* 🎨 decouple dms from gan

* ✅ update tests

* ✅ update tests

* 💄 style

* 🐛 GAN now has required args

* 🐛 use argparse not hyperopt parser
  • Loading branch information
nateraw authored Sep 12, 2020
1 parent d1b3119 commit 25f8b03
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 60 deletions.
82 changes: 33 additions & 49 deletions pl_bolts/models/gans/basic/basic_gan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
import torch
from torch.nn import functional as F

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models.gans.basic.components import Generator, Discriminator


class GAN(pl.LightningModule):

def __init__(self,
datamodule: pl.LightningDataModule = None,
latent_dim: int = 32,
batch_size: int = 100,
learning_rate: float = 0.0002,
data_dir: str = '',
num_workers: int = 8,
**kwargs):
def __init__(
self,
input_channels: int,
input_height: int,
input_width: int,
latent_dim: int = 32,
learning_rate: float = 0.0002,
**kwargs
):
"""
Vanilla GAN implementation.
Expand Down Expand Up @@ -53,24 +53,12 @@ def __init__(self,

# makes self.hparams under the hood and saves to ckpt
self.save_hyperparameters()

self._set_default_datamodule(datamodule)
self.img_dim = (input_channels, input_height, input_width)

# networks
self.generator = self.init_generator(self.img_dim)
self.discriminator = self.init_discriminator(self.img_dim)

def _set_default_datamodule(self, datamodule):
    """Attach the given datamodule (or a default MNIST one) and record its image size.

    Args:
        datamodule: A ``LightningDataModule`` to train on, or ``None`` to fall
            back to a normalized MNIST datamodule built from the saved
            hyperparameters (``data_dir``, ``num_workers``).
    """
    if datamodule is not None:
        self.datamodule = datamodule
    else:
        # No datamodule supplied: default to MNIST, configured from hparams.
        self.datamodule = MNISTDataModule(
            data_dir=self.hparams.data_dir,
            num_workers=self.hparams.num_workers,
            normalize=True,
        )
    # img_dim is (channels, height, width) as reported by the datamodule.
    self.img_dim = self.datamodule.size()

def init_generator(self, img_dim):
    """Build and return the generator network for images of shape ``img_dim``."""
    return Generator(latent_dim=self.hparams.latent_dim, img_shape=img_dim)
Expand Down Expand Up @@ -179,44 +167,40 @@ def add_model_specific_args(parent_parser):
help="adam: decay of first order momentum of gradient")
parser.add_argument('--latent_dim', type=int, default=100,
help="generator embedding dim")
parser.add_argument('--batch_size', type=int, default=64, help="size of the batches")
parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers")
parser.add_argument('--data_dir', type=str, default=os.getcwd())

return parser


def cli_main():
def cli_main(args=None):
    """Command-line entry point: parse args, build the datamodule and GAN, train.

    Args:
        args: Optional list of CLI argument strings; ``None`` means read
            ``sys.argv`` (the usual argparse behavior).

    Returns:
        Tuple of ``(datamodule, model, trainer)`` after ``trainer.fit`` returns.

    Raises:
        ValueError: If ``--dataset`` names an unsupported dataset. (Previously
            an unknown name left ``dm_cls`` unbound and crashed with a
            ``NameError`` at the first use.)
    """
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule

    pl.seed_everything(1234)

    # First pass parses only --dataset, so we know which datamodule class's
    # argparse args to register before the full parse below.
    parser = ArgumentParser()
    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
    script_args, _ = parser.parse_known_args(args)

    datamodules = {
        "mnist": MNISTDataModule,
        "cifar10": CIFAR10DataModule,
        "stl10": STL10DataModule,
        "imagenet": ImagenetDataModule,
    }
    try:
        dm_cls = datamodules[script_args.dataset]
    except KeyError:
        raise ValueError(
            f"unknown --dataset {script_args.dataset!r}; choose from {sorted(datamodules)}"
        ) from None

    parser = dm_cls.add_argparse_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GAN.add_model_specific_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # GAN takes (channels, height, width) positionally, read off the datamodule.
    model = GAN(*dm.size(), **vars(args))
    callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
    trainer.fit(model, dm)
    return dm, model, trainer


if __name__ == '__main__':
    dm, model, trainer = cli_main()
2 changes: 1 addition & 1 deletion tests/callbacks/test_variational_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self):
self.global_step = 1
self.logger = DummyLogger()

model = GAN()
model = GAN(3, 28, 28)
cb = LatentDimInterpolator(interpolate_epoch_interval=2)

cb.on_epoch_end(FakeTrainer(), model)
21 changes: 15 additions & 6 deletions tests/models/test_executable_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
import pytest


@pytest.mark.parametrize(
    "dataset_name", [
        pytest.param('mnist', id="mnist"),
        pytest.param('cifar10', id="cifar10")
    ]
)
def test_cli_basic_gan(dataset_name):
    """Smoke-test the GAN CLI entry point with a tiny 1-epoch run per dataset."""
    from pl_bolts.models.gans.basic.basic_gan_module import cli_main

    # Keep the run minimal: 1 epoch, 3 train/val batches, batch size 3.
    cli_args = f"""
        --dataset {dataset_name}
        --max_epochs 1
        --limit_train_batches 3
        --limit_val_batches 3
        --batch_size 3
    """.strip().split()

    # Patch argv so cli_main's internal argparse sees only our arguments.
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        cli_main()

Expand Down
14 changes: 10 additions & 4 deletions tests/models/test_gans.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pytest
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule
from pl_bolts.models.gans import GAN


@pytest.mark.parametrize(
    "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
)
def test_gan(tmpdir, dm_cls):
    """Fast-dev-run the GAN against each supported datamodule."""
    seed_everything()

    dm = dm_cls()
    # GAN's required (channels, height, width) args come from the datamodule.
    model = GAN(*dm.size())
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, dm)
    trainer.test(datamodule=dm)

0 comments on commit 25f8b03

Please sign in to comment.