Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DCGAN module #403

Merged
merged 62 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
3a483ca
Add DCGAN module
chris-clem Nov 24, 2020
2bb60eb
Undo black on conf.py
chris-clem Nov 24, 2020
dac758b
Add tests for DCGAN
chris-clem Nov 25, 2020
1582816
Fix flake8 and codefactor
chris-clem Nov 25, 2020
0df4de5
Add types and small refactoring
chris-clem Nov 27, 2020
6b1e11a
Make image sampler callback work
chris-clem Nov 27, 2020
ec90342
Upgrade DQN to use .log (#404)
teddykoker Nov 26, 2020
6d3aaaa
fix loss test case for batch size variation (#402)
sidhantls Nov 26, 2020
ac56e5a
Decouple DataModules from Models - CPCV2 (#386)
akihironitta Nov 26, 2020
4fca092
Add docstrings, fix import, and update changelog
chris-clem Nov 27, 2020
4bc4d63
Update transforms
chris-clem Dec 15, 2020
0a13628
bugfix: batch_size parameter for DataModules remaining (#344)
hecoding Dec 1, 2020
07315ae
Fix a typo/copy paste error (#415)
bartolkaruza Dec 5, 2020
931dad8
Just a Typo (#413)
JonathanSum Dec 5, 2020
f8335f9
Remove unused arguments (#418)
akihironitta Dec 5, 2020
5b299ac
tests: Use cached datasets in LitMNIST and the doctests (#414)
akihironitta Dec 5, 2020
427b596
clear replay buffer after trajectory (#425)
sidhantls Dec 7, 2020
794aa20
stale: update label
Borda Dec 7, 2020
8aba29c
bugfix: Add missing imports to pl_bolts/__init__.py (#430)
akihironitta Dec 7, 2020
0f41495
Fix CIFAR num_samples (#432)
teddykoker Dec 8, 2020
4795612
Add static type checker mypy to the tests and pre-commit hooks (#433)
akihironitta Dec 8, 2020
e3f5e38
missing logo
Borda Dec 8, 2020
dbbcad6
Add type annotations to pl_bolts/__init__.py (#435)
akihironitta Dec 8, 2020
e8040a9
skip hanging (#437)
Borda Dec 8, 2020
bb37992
Option to normalize latent interpolation images (#438)
teddykoker Dec 8, 2020
f899c9f
0.2.6rc1
Borda Dec 11, 2020
17e7ae2
Warnings fix (#449)
ganprad Dec 14, 2020
57e2991
Refactor datamodules/datasets (#338)
akihironitta Dec 14, 2020
78c4d00
update min requirements - PL 1.1.1 (#448)
Borda Dec 16, 2020
77ed92e
Add missing optional packages to `requirements/*.txt` (#450)
akihironitta Dec 16, 2020
d81a70f
update Isort (#457)
Borda Dec 16, 2020
a52cd5c
Adding flags to datamodules (#388)
briankosw Dec 16, 2020
9b38f56
Add DCGAN module
chris-clem Nov 24, 2020
e41d9d4
Small fixes
chris-clem Dec 16, 2020
0f322fe
Remove DataModules
chris-clem Dec 16, 2020
3b52710
Update docs
chris-clem Dec 16, 2020
c1dcf70
Update docs
chris-clem Dec 16, 2020
d6d1dca
Update torchvision import
chris-clem Dec 16, 2020
add8bdf
Merge branch 'master' into feature/401_dcgan
chris-clem Dec 16, 2020
f058a3f
Import gym as optional package to build docs successfully (#458)
akihironitta Dec 17, 2020
141a7d3
bugfix: batch_size parameter for DataModules remaining (#344)
hecoding Dec 1, 2020
ce4db83
Option to normalize latent interpolation images (#438)
teddykoker Dec 8, 2020
ab89aa0
update min requirements - PL 1.1.1 (#448)
Borda Dec 16, 2020
54a246c
Apply suggestions from code review
akihironitta Dec 17, 2020
982b22e
Merge branch 'master' into feature/401_dcgan
akihironitta Dec 17, 2020
843eac3
Apply suggestions from code review
akihironitta Dec 17, 2020
e6420a6
Add docs
chris-clem Dec 18, 2020
3515d50
Merge branch 'master' into feature/401_dcgan
chris-clem Dec 24, 2020
6090b5c
Use LSUN instead of CIFAR10
Jan 3, 2021
dc3b502
Update TensorboardGenerativeModelImageSampler
Jan 3, 2021
b61574b
Update docs with lsun
Jan 3, 2021
e3bd495
Update test
Jan 3, 2021
a9197ae
Revert TensorboardGenerativeModelImageSampler changes
Jan 3, 2021
e5f1456
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 4, 2021
bd692a3
Remove ModelCheckpoint callback and nrow=5 arg
chris-clem Jan 4, 2021
cecb452
Apply suggestions from code review
akihironitta Jan 7, 2021
e460f09
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 7, 2021
0a238c1
Fix test_dcgan
Jan 8, 2021
9058db1
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 13, 2021
2c0d4d0
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 17, 2021
01fea8a
Apply yapf
Jan 17, 2021
e2f6f46
Apply suggestions from code review
Borda Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
82 changes: 82 additions & 0 deletions pl_bolts/models/gans/dcgan/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
import torch
from torch import nn

akihironitta marked this conversation as resolved.
Show resolved Hide resolved

class DCGANGenerator(nn.Module):
def __init__(self, latent_dim: int, feature_maps: int, image_channels: int):
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
self.gen = nn.Sequential(
self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
self._make_gen_block(feature_maps * 8, feature_maps * 4),
self._make_gen_block(feature_maps * 4, feature_maps * 2),
self._make_gen_block(feature_maps * 2, feature_maps, 4),
self._make_gen_block(feature_maps, image_channels, last_block=True),
)

@staticmethod
def _make_gen_block(
in_channels: int,
out_channels: int,
kernel_size: int = 4,
stride: int = 2,
padding: int = 1,
bias: bool = False,
last_block: bool = False,
) -> nn.Sequential:
if not last_block:
gen_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
else:
gen_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.Tanh(),
)

return gen_block

def forward(self, noise: torch.Tensor) -> torch.Tensor:
return self.gen(noise)


class DCGANDiscriminator(nn.Module):
def __init__(self, feature_maps: int, image_channels: int):
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
self.disc = nn.Sequential(
self._make_disc_block(image_channels, feature_maps, batch_norm=False),
self._make_disc_block(feature_maps, feature_maps * 2),
self._make_disc_block(feature_maps * 2, feature_maps * 4),
self._make_disc_block(feature_maps * 4, feature_maps * 8),
self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
)

@staticmethod
def _make_disc_block(
in_channels: int,
out_channels: int,
kernel_size: int = 4,
stride: int = 2,
padding: int = 1,
bias: bool = False,
batch_norm: bool = True,
last_block: bool = False,
) -> nn.Sequential:
if not last_block:
disc_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
nn.LeakyReLU(0.2, inplace=True),
)
else:
disc_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.Sigmoid(),
)

return disc_block

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.disc(x)
181 changes: 181 additions & 0 deletions pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from argparse import ArgumentParser
akihironitta marked this conversation as resolved.
Show resolved Hide resolved

import pytorch_lightning as pl
import torch
from torch import nn

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.utils.warnings import warn_missing_pkg

try:
from torchvision import transforms as transform_lib
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
else:
_TORCHVISION_AVAILABLE = True
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator


class DCGAN(pl.LightningModule):
chris-clem marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
beta1: float = 0.5,
beta2: float = 0.999,
feature_maps_gen: int = 64,
feature_maps_disc: int = 64,
image_channels: int = 1,
latent_dim: int = 100,
learning_rate: float = 0.0002,
**kwargs
):
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
self.save_hyperparameters()

self.generator = self._get_generator()
self.discriminator = self._get_discriminator()

self.criterion = nn.BCEWithLogitsLoss()

def _get_generator(self):
generator = DCGANGenerator(self.hparams.latent_dim, self.hparams.feature_maps_gen, self.hparams.image_channels)
generator.apply(self._weights_init)
return generator

def _get_discriminator(self):
discriminator = DCGANDiscriminator(self.hparams.feature_maps_disc, self.hparams.image_channels)
discriminator.apply(self._weights_init)
return discriminator

@staticmethod
def _weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)

def configure_optimizers(self):
lr = self.hparams.learning_rate
betas = (self.hparams.beta1, self.hparams.beta2)
opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
return [opt_disc, opt_gen], []

def forward(self, noise: torch.Tensor) -> torch.Tensor:
noise = noise.view(*noise.shape, 1, 1)
return self.generator(noise)

def training_step(self, batch, batch_idx, optimizer_idx):
real, _ = batch

# Train discriminator
result = None
if optimizer_idx == 0:
result = self._disc_step(real)

# Train generator
if optimizer_idx == 1:
result = self._gen_step(real)

return result

def _disc_step(self, real: torch.Tensor) -> torch.Tensor:
disc_loss = self._get_disc_loss(real)
self.log("loss/disc", disc_loss, on_epoch=True, prog_bar=True)
return disc_loss

def _gen_step(self, real: torch.Tensor) -> torch.Tensor:
gen_loss = self._get_gen_loss(real)
self.log("loss/gen", gen_loss, on_epoch=True, prog_bar=True)
return gen_loss

def _get_disc_loss(self, real: torch.Tensor) -> torch.Tensor:
# Train with real
real_pred = self.discriminator(real)
real_gt = torch.ones_like(real_pred)
real_loss = self.criterion(real_pred, real_gt)

# Train with fake
fake_pred = self._get_fake_pred(real)
fake_gt = torch.zeros_like(fake_pred)
fake_loss = self.criterion(fake_pred, fake_gt)

disc_loss = real_loss + fake_loss

return disc_loss

def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
# Train with fake
fake_pred = self._get_fake_pred(real)
fake_gt = torch.ones_like(fake_pred)
gen_loss = self.criterion(fake_pred, fake_gt)

return gen_loss

def _get_fake_pred(self, real: torch.Tensor) -> torch.Tensor:
batch_size = self._get_batch_size(real)
noise = self._get_noise(batch_size, self.hparams.latent_dim)
fake = self(noise)
fake_pred = self.discriminator(fake)

return fake_pred

@staticmethod
def _get_batch_size(real: torch.Tensor) -> int:
batch_size = len(real)
return batch_size

def _get_noise(self, n_samples: int, latent_dim: int) -> torch.Tensor:
noise = torch.randn(n_samples, latent_dim, device=self.device)
return noise

@staticmethod
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--beta1", default=0.5, type=float)
parser.add_argument("--beta2", default=0.999, type=float)
parser.add_argument("--feature_maps_gen", default=64, type=int)
parser.add_argument("--feature_maps_disc", default=64, type=int)
parser.add_argument("--image_channels", default=1, type=int)
parser.add_argument("--latent_dim", default=100, type=int)
parser.add_argument("--learning_rate", default=0.0002, type=float)
return parser


def cli_main(args=None):
from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule

pl.seed_everything(1234)

parser = ArgumentParser()
parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10")
parser.add_argument("--image_size", default=64, type=int)
script_args, _ = parser.parse_known_args(args)

if script_args.dataset == "mnist":
dm_cls = MNISTDataModule
elif script_args.dataset == "cifar10":
dm_cls = CIFAR10DataModule

parser = dm_cls.add_argparse_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
parser = DCGAN.add_model_specific_args(parser)
args = parser.parse_args(args)

transforms = transform_lib.Compose([transform_lib.Resize(args.image_size), transform_lib.ToTensor()])
dm = dm_cls.from_argparse_args(args)
dm.train_transforms = transforms
dm.val_transforms = transforms
dm.test_transforms = transforms

model = DCGAN(**vars(args))
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
trainer.fit(model, dm)
return dm, model, trainer


if __name__ == "__main__":
dm, model, trainer = cli_main()
17 changes: 17 additions & 0 deletions tests/models/test_gans.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from torchvision import transforms as transform_lib

from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule
from pl_bolts.models.gans import GAN
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN
chris-clem marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
Expand All @@ -17,3 +19,18 @@ def test_gan(tmpdir, datadir, dm_cls):
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, dm)
trainer.test(datamodule=dm)


@pytest.mark.parametrize(
"dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
)
def test_dcgan(tmpdir, datadir, dm_cls):
seed_everything()

transforms = transform_lib.Compose([transform_lib.Resize(64), transform_lib.ToTensor()])
dm = dm_cls(data_dir=datadir, train_transforms=transforms, val_transforms=transforms, test_transforms=transforms)

model = DCGAN(image_channels=dm.dims[0])
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, dm)
trainer.test(datamodule=dm)
11 changes: 11 additions & 0 deletions tests/models/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ def test_cli_run_basic_gan(cli_args):
cli_main()


@pytest.mark.parametrize('cli_args', [
f'--dataset mnist --data_dir {DATASETS_PATH} --fast_dev_run --image_channels 1',
f'--dataset cifar10 --data_dir {DATASETS_PATH} --fast_dev_run --image_channels 3',
])
def test_cli_run_dcgan_gan(cli_args):
from pl_bolts.models.gans.dcgan.dcgan_module import cli_main

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
cli_main()


# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it...
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.parametrize('cli_args', [
Expand Down