Skip to content

Commit

Permalink
Add DCGAN module (#403)
Browse files Browse the repository at this point in the history
* Add DCGAN module

* Undo black on conf.py

* Add tests for DCGAN

* Fix flake8 and codefactor

* Add types and small refactoring

* Make image sampler callback work

* Upgrade DQN to use .log (#404)

* Upgrade DQN to use .log

* remove unused

* pep8

* fixed other dqn

* fix loss test case for batch size variation (#402)

* Decouple DataModules from Models - CPCV2 (#386)

* Decouple dms from CPCV2

* Update tests

* Add docstrings, fix import, and update changelog

* Update transforms

* bugfix: batch_size parameter for DataModules remaining (#344)

* bugfix: batch_size for DataModules remaining

* Update sklearn datamodule tests

* Fix default_transforms. Keep internal for every data module

* fix typo on binary_mnist_datamodule

thanks @akihironitta

Co-authored-by: Akihiro Nitta <[email protected]>

Co-authored-by: Akihiro Nitta <[email protected]>

* Fix a typo/copy paste error (#415)

* Just a Typo (#413)

missing a ' at the end of dataset='stl10

* Remove unused arguments (#418)

* tests: Use cached datasets in LitMNIST and the doctests (#414)

* Use cached datasets

* Use cached datasets in doctests

* clear replay buffer after trajectory (#425)

* stale: update label

* bugfix: Add missing imports to pl_bolts/__init__.py (#430)

* Add missing imports

* Add missing imports

* Apply isort

* Fix CIFAR num_samples (#432)

* Add static type checker mypy to the tests and pre-commit hooks (#433)

* Add mypy check to GitHub Actions

* Run mypy on pl_bolts only

* Add mypy check to pre-commit

* Add an empty line at the end of files

* Update mypy config

* Update mypy config

* Update mypy config

* show

Co-authored-by: Jirka Borovec <[email protected]>

* missing logo

* Add type annotations to pl_bolts/__init__.py (#435)

* Run mypy on pl_bolts only

* Update mypy config

* Add type hints to pl_bolts/__init__.py

* mypy

Co-authored-by: Jirka Borovec <[email protected]>

* skip hanging (#437)

* Option to normalize latent interpolation images (#438)

* add option to normalize latent interpolation images

* linspace

* update

Co-authored-by: ananyahjha93 <[email protected]>

* 0.2.6rc1

* Warnings fix (#449)

* Revert "Merge pull request #1 from ganprad/warnings_fix"

This reverts commit 7c5aaf0.

* Fixes warning related np.integer in SklearnDataModule

Fixes this warning:
```DeprecationWarning: Converting `np.integer` or `np.signedinteger` to a dtype is deprecated. The current result is `np.dtype(np.int_)` which is not strictly correct. Note that the result depends on the system. To ensure stable results use may want to use `np.int64` or `np.int32````

* Refactor datamodules/datasets (#338)

* Remove try: ... except: ...

* Fix experience_source

* Fix imagenet

* Fix kitti

* Fix sklearn

* Fix vocdetection

* Fix typo

* Remove duplicate

* Fix by flake8

* Add optional packages availability vars

* binary_mnist

* Use pl_bolts._SKLEARN_AVAILABLE

* Apply isort

* cifar10

* mnist

* cityscapes

* fashion mnist

* ssl_imagenet

* stl10

* cifar10

* dummy

* fix city

* fix stl10

* fix mnist

* ssl_amdim

* remove unused DataLoader and fix docs

* use from ... import ...

* fix pragma: no cover

* Fix forward reference in annotations

* binmnist

* Same order as imports

* Move vars from __init__ to utils/__init__

* Remove vars from __init__

* Update vars

* Apply isort

* update min requirements - PL 1.1.1 (#448)

* update min requirements

* rc0

* imports

* isort

* flake8

* 1.1.1

* flake8

* docs

* Add missing optional packages to `requirements/*.txt` (#450)

* Import matplotlib at the top

* Add missing optional packages

* Update wandb

* Add mypy to requirements

* update Isort (#457)

* Adding flags to datamodules (#388)

* Adding flags to datamodules

* Finishing up changes

* Fixing syntax error

* More syntax errors

* More

* Adding drop_last flag to sklearn test

* Adding drop_last flag to sklearn test

* Updating doc for reflect drop_last=False

* Adding flags to datamodules

* Finishing up changes

* Fixing syntax error

* More syntax errors

* More

* Adding drop_last flag to sklearn test

* Adding drop_last flag to sklearn test

* Updating doc for reflect drop_last=False

* Cleaning up parameters and docstring

* Fixing syntax error

* Fixing documentation

* Hardcoding shuffle=False for val and test

* Add DCGAN module

* Small fixes

* Remove DataModules

* Update docs

* Update docs

* Update torchvision import

* Import gym as optional package to build docs successfully (#458)

* Import gym as optional package

* Fix import

* Apply isort

* bugfix: batch_size parameter for DataModules remaining (#344)

* bugfix: batch_size for DataModules remaining

* Update sklearn datamodule tests

* Fix default_transforms. Keep internal for every data module

* fix typo on binary_mnist_datamodule

thanks @akihironitta

Co-authored-by: Akihiro Nitta <[email protected]>

Co-authored-by: Akihiro Nitta <[email protected]>

* Option to normalize latent interpolation images (#438)

* add option to normalize latent interpolation images

* linspace

* update

Co-authored-by: ananyahjha93 <[email protected]>

* update min requirements - PL 1.1.1 (#448)

* update min requirements

* rc0

* imports

* isort

* flake8

* 1.1.1

* flake8

* docs

* Apply suggestions from code review

* Apply suggestions from code review

* Add docs

* Use LSUN instead of CIFAR10

* Update TensorboardGenerativeModelImageSampler

* Update docs with lsun

* Update test

* Revert TensorboardGenerativeModelImageSampler changes

* Remove ModelCheckpoint callback and nrow=5 arg

* Apply suggestions from code review

* Fix test_dcgan

* Apply yapf

* Apply suggestions from code review

Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: Sidhant Sundrani <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Héctor Laria <[email protected]>
Co-authored-by: Bartol Karuza <[email protected]>
Co-authored-by: Happy Sugar Life <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: ananyahjha93 <[email protected]>
Co-authored-by: Pradeep Ganesan <[email protected]>
Co-authored-by: Brian Ko <[email protected]>
Co-authored-by: Christoph Clement <[email protected]>
  • Loading branch information
13 people authored Jan 18, 2021
1 parent 6e15643 commit 056f836
Show file tree
Hide file tree
Showing 15 changed files with 393 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#348](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/348),
[#323](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/323))
- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))
- Added DCGAN module ([#403](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/403))
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))
- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
Expand Down
Binary file added docs/source/_images/gans/dcgan_lsun_dloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_lsun_gloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_lsun_outputs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_dloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_gloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_outputs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 46 additions & 1 deletion docs/source/gans.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,49 @@ Loss curves:
.. autoclass:: pl_bolts.models.gans.GAN
:noindex:
:noindex:

DCGAN
---------
DCGAN implementation from the paper `Unsupervised Representation Learning with Deep Convolutional Generative
Adversarial Networks <https://arxiv.org/pdf/1511.06434.pdf>`_. The implementation is based on the version from
PyTorch's `examples <https://github.com/pytorch/examples/blob/master/dcgan/main.py>`_.

Implemented by:

- `Christoph Clement <https://github.com/chris-clem>`_

Example MNIST outputs:

.. image:: _images/gans/dcgan_mnist_outputs.png
:width: 400
:alt: DCGAN generated MNIST samples

Example LSUN bedroom outputs:

.. image:: _images/gans/dcgan_lsun_outputs.png
:width: 400
:alt: DCGAN generated LSUN bedroom samples

MNIST Loss curves:

.. image:: _images/gans/dcgan_mnist_dloss.png
:width: 200
:alt: DCGAN MNIST disc loss

.. image:: _images/gans/dcgan_mnist_gloss.png
:width: 200
:alt: DCGAN MNIST gen loss

LSUN Loss curves:

.. image:: _images/gans/dcgan_lsun_dloss.png
:width: 200
:alt: DCGAN LSUN disc loss

.. image:: _images/gans/dcgan_lsun_gloss.png
:width: 200
:alt: DCGAN LSUN gen loss

.. autoclass:: pl_bolts.models.gans.DCGAN
:noindex:
5 changes: 0 additions & 5 deletions pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ def __init__(
self.pin_memory = pin_memory
self.drop_last = drop_last

self.default_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
])

# split into train, val, test
kitti_dataset = KittiDataset(self.data_dir, transform=self._default_transforms())

Expand Down
1 change: 1 addition & 0 deletions pl_bolts/models/gans/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401
Empty file.
95 changes: 95 additions & 0 deletions pl_bolts/models/gans/dcgan/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
import torch
from torch import nn


class DCGANGenerator(nn.Module):

def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
"""
Args:
latent_dim: Dimension of the latent space
feature_maps: Number of feature maps to use
image_channels: Number of channels of the images from the dataset
"""
super().__init__()
self.gen = nn.Sequential(
self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
self._make_gen_block(feature_maps * 8, feature_maps * 4),
self._make_gen_block(feature_maps * 4, feature_maps * 2),
self._make_gen_block(feature_maps * 2, feature_maps),
self._make_gen_block(feature_maps, image_channels, last_block=True),
)

@staticmethod
def _make_gen_block(
in_channels: int,
out_channels: int,
kernel_size: int = 4,
stride: int = 2,
padding: int = 1,
bias: bool = False,
last_block: bool = False,
) -> nn.Sequential:
if not last_block:
gen_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
else:
gen_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.Tanh(),
)

return gen_block

def forward(self, noise: torch.Tensor) -> torch.Tensor:
return self.gen(noise)


class DCGANDiscriminator(nn.Module):

def __init__(self, feature_maps: int, image_channels: int) -> None:
"""
Args:
feature_maps: Number of feature maps to use
image_channels: Number of channels of the images from the dataset
"""
super().__init__()
self.disc = nn.Sequential(
self._make_disc_block(image_channels, feature_maps, batch_norm=False),
self._make_disc_block(feature_maps, feature_maps * 2),
self._make_disc_block(feature_maps * 2, feature_maps * 4),
self._make_disc_block(feature_maps * 4, feature_maps * 8),
self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
)

@staticmethod
def _make_disc_block(
in_channels: int,
out_channels: int,
kernel_size: int = 4,
stride: int = 2,
padding: int = 1,
bias: bool = False,
batch_norm: bool = True,
last_block: bool = False,
) -> nn.Sequential:
if not last_block:
disc_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
nn.LeakyReLU(0.2, inplace=True),
)
else:
disc_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.Sigmoid(),
)

return disc_block

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.disc(x).view(-1, 1).squeeze(1)
224 changes: 224 additions & 0 deletions pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from argparse import ArgumentParser
from typing import Any

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import LSUN, MNIST
else: # pragma: no-cover
warn_missing_pkg("torchvision")


class DCGAN(pl.LightningModule):
"""
DCGAN implementation.
Example::
from pl_bolts.models.gan import DCGAN
m = DCGAN()
Trainer(gpus=2).fit(m)
Example CLI::
# mnist
python dcgan_module.py --gpus 1
# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
"""

def __init__(
self,
beta1: float = 0.5,
feature_maps_gen: int = 64,
feature_maps_disc: int = 64,
image_channels: int = 1,
latent_dim: int = 100,
learning_rate: float = 0.0002,
**kwargs: Any,
) -> None:
"""
Args:
beta1: Beta1 value for Adam optimizer
feature_maps_gen: Number of feature maps to use for the generator
feature_maps_disc: Number of feature maps to use for the discriminator
image_channels: Number of channels of the images from the dataset
latent_dim: Dimension of the latent space
learning_rate: Learning rate
"""
super().__init__()
self.save_hyperparameters()

self.generator = self._get_generator()
self.discriminator = self._get_discriminator()

self.criterion = nn.BCELoss()

def _get_generator(self) -> nn.Module:
generator = DCGANGenerator(self.hparams.latent_dim, self.hparams.feature_maps_gen, self.hparams.image_channels)
generator.apply(self._weights_init)
return generator

def _get_discriminator(self) -> nn.Module:
discriminator = DCGANDiscriminator(self.hparams.feature_maps_disc, self.hparams.image_channels)
discriminator.apply(self._weights_init)
return discriminator

@staticmethod
def _weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)

def configure_optimizers(self):
lr = self.hparams.learning_rate
betas = (self.hparams.beta1, 0.999)
opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
return [opt_disc, opt_gen], []

def forward(self, noise: torch.Tensor) -> torch.Tensor:
"""
Generates an image given input noise
Example::
noise = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(noise)
"""
noise = noise.view(*noise.shape, 1, 1)
return self.generator(noise)

def training_step(self, batch, batch_idx, optimizer_idx):
real, _ = batch

# Train discriminator
result = None
if optimizer_idx == 0:
result = self._disc_step(real)

# Train generator
if optimizer_idx == 1:
result = self._gen_step(real)

return result

def _disc_step(self, real: torch.Tensor) -> torch.Tensor:
disc_loss = self._get_disc_loss(real)
self.log("loss/disc", disc_loss, on_epoch=True)
return disc_loss

def _gen_step(self, real: torch.Tensor) -> torch.Tensor:
gen_loss = self._get_gen_loss(real)
self.log("loss/gen", gen_loss, on_epoch=True)
return gen_loss

def _get_disc_loss(self, real: torch.Tensor) -> torch.Tensor:
# Train with real
real_pred = self.discriminator(real)
real_gt = torch.ones_like(real_pred)
real_loss = self.criterion(real_pred, real_gt)

# Train with fake
fake_pred = self._get_fake_pred(real)
fake_gt = torch.zeros_like(fake_pred)
fake_loss = self.criterion(fake_pred, fake_gt)

disc_loss = real_loss + fake_loss

return disc_loss

def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
# Train with fake
fake_pred = self._get_fake_pred(real)
fake_gt = torch.ones_like(fake_pred)
gen_loss = self.criterion(fake_pred, fake_gt)

return gen_loss

def _get_fake_pred(self, real: torch.Tensor) -> torch.Tensor:
batch_size = len(real)
noise = self._get_noise(batch_size, self.hparams.latent_dim)
fake = self(noise)
fake_pred = self.discriminator(fake)

return fake_pred

def _get_noise(self, n_samples: int, latent_dim: int) -> torch.Tensor:
return torch.randn(n_samples, latent_dim, device=self.device)

@staticmethod
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--beta1", default=0.5, type=float)
parser.add_argument("--feature_maps_gen", default=64, type=int)
parser.add_argument("--feature_maps_disc", default=64, type=int)
parser.add_argument("--latent_dim", default=100, type=int)
parser.add_argument("--learning_rate", default=0.0002, type=float)
return parser


def cli_main(args=None):
pl.seed_everything(1234)

parser = ArgumentParser()
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"])
parser.add_argument("--data_dir", default="./", type=str)
parser.add_argument("--image_size", default=64, type=int)
parser.add_argument("--num_workers", default=8, type=int)

script_args, _ = parser.parse_known_args(args)

if script_args.dataset == "lsun":
transforms = transform_lib.Compose([
transform_lib.Resize(script_args.image_size),
transform_lib.CenterCrop(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = LSUN(root=script_args.data_dir, classes=["bedroom_train"], transform=transforms)
image_channels = 3
elif script_args.dataset == "mnist":
transforms = transform_lib.Compose([
transform_lib.Resize(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.5, ), (0.5, )),
])
dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms)
image_channels = 1

dataloader = DataLoader(
dataset, batch_size=script_args.batch_size, shuffle=True, num_workers=script_args.num_workers
)

parser = DCGAN.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args(args)

model = DCGAN(**vars(args), image_channels=image_channels)
callbacks = [
TensorboardGenerativeModelImageSampler(num_samples=5),
LatentDimInterpolator(interpolate_epoch_interval=5),
]
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
trainer.fit(model, dataloader)


if __name__ == "__main__":
cli_main()
Loading

0 comments on commit 056f836

Please sign in to comment.