Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DCGAN module #403

Merged
merged 62 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
3a483ca
Add DCGAN module
chris-clem Nov 24, 2020
2bb60eb
Undo black on conf.py
chris-clem Nov 24, 2020
dac758b
Add tests for DCGAN
chris-clem Nov 25, 2020
1582816
Fix flake8 and codefactor
chris-clem Nov 25, 2020
0df4de5
Add types and small refactoring
chris-clem Nov 27, 2020
6b1e11a
Make image sampler callback work
chris-clem Nov 27, 2020
ec90342
Upgrade DQN to use .log (#404)
teddykoker Nov 26, 2020
6d3aaaa
fix loss test case for batch size variation (#402)
sidhantls Nov 26, 2020
ac56e5a
Decouple DataModules from Models - CPCV2 (#386)
akihironitta Nov 26, 2020
4fca092
Add docstrings, fix import, and update changelog
chris-clem Nov 27, 2020
4bc4d63
Update transforms
chris-clem Dec 15, 2020
0a13628
bugfix: batch_size parameter for DataModules remaining (#344)
hecoding Dec 1, 2020
07315ae
Fix a typo/copy paste error (#415)
bartolkaruza Dec 5, 2020
931dad8
Just a Typo (#413)
JonathanSum Dec 5, 2020
f8335f9
Remove unused arguments (#418)
akihironitta Dec 5, 2020
5b299ac
tests: Use cached datasets in LitMNIST and the doctests (#414)
akihironitta Dec 5, 2020
427b596
clear replay buffer after trajectory (#425)
sidhantls Dec 7, 2020
794aa20
stale: update label
Borda Dec 7, 2020
8aba29c
bugfix: Add missing imports to pl_bolts/__init__.py (#430)
akihironitta Dec 7, 2020
0f41495
Fix CIFAR num_samples (#432)
teddykoker Dec 8, 2020
4795612
Add static type checker mypy to the tests and pre-commit hooks (#433)
akihironitta Dec 8, 2020
e3f5e38
missing logo
Borda Dec 8, 2020
dbbcad6
Add type annotations to pl_bolts/__init__.py (#435)
akihironitta Dec 8, 2020
e8040a9
skip hanging (#437)
Borda Dec 8, 2020
bb37992
Option to normalize latent interpolation images (#438)
teddykoker Dec 8, 2020
f899c9f
0.2.6rc1
Borda Dec 11, 2020
17e7ae2
Warnings fix (#449)
ganprad Dec 14, 2020
57e2991
Refactor datamodules/datasets (#338)
akihironitta Dec 14, 2020
78c4d00
update min requirements - PL 1.1.1 (#448)
Borda Dec 16, 2020
77ed92e
Add missing optional packages to `requirements/*.txt` (#450)
akihironitta Dec 16, 2020
d81a70f
update Isort (#457)
Borda Dec 16, 2020
a52cd5c
Adding flags to datamodules (#388)
briankosw Dec 16, 2020
9b38f56
Add DCGAN module
chris-clem Nov 24, 2020
e41d9d4
Small fixes
chris-clem Dec 16, 2020
0f322fe
Remove DataModules
chris-clem Dec 16, 2020
3b52710
Update docs
chris-clem Dec 16, 2020
c1dcf70
Update docs
chris-clem Dec 16, 2020
d6d1dca
Update torchvision import
chris-clem Dec 16, 2020
add8bdf
Merge branch 'master' into feature/401_dcgan
chris-clem Dec 16, 2020
f058a3f
Import gym as optional package to build docs successfully (#458)
akihironitta Dec 17, 2020
141a7d3
bugfix: batch_size parameter for DataModules remaining (#344)
hecoding Dec 1, 2020
ce4db83
Option to normalize latent interpolation images (#438)
teddykoker Dec 8, 2020
ab89aa0
update min requirements - PL 1.1.1 (#448)
Borda Dec 16, 2020
54a246c
Apply suggestions from code review
akihironitta Dec 17, 2020
982b22e
Merge branch 'master' into feature/401_dcgan
akihironitta Dec 17, 2020
843eac3
Apply suggestions from code review
akihironitta Dec 17, 2020
e6420a6
Add docs
chris-clem Dec 18, 2020
3515d50
Merge branch 'master' into feature/401_dcgan
chris-clem Dec 24, 2020
6090b5c
Use LSUN instead of CIFAR10
Jan 3, 2021
dc3b502
Update TensorboardGenerativeModelImageSampler
Jan 3, 2021
b61574b
Update docs with lsun
Jan 3, 2021
e3bd495
Update test
Jan 3, 2021
a9197ae
Revert TensorboardGenerativeModelImageSampler changes
Jan 3, 2021
e5f1456
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 4, 2021
bd692a3
Remove ModelCheckpoint callback and nrow=5 arg
chris-clem Jan 4, 2021
cecb452
Apply suggestions from code review
akihironitta Jan 7, 2021
e460f09
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 7, 2021
0a238c1
Fix test_dcgan
Jan 8, 2021
9058db1
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 13, 2021
2c0d4d0
Merge branch 'master' into feature/401_dcgan
chris-clem Jan 17, 2021
01fea8a
Apply yapf
Jan 17, 2021
e2f6f46
Apply suggestions from code review
Borda Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#348](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/348),
[#323](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/323))
- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))
- Added DCGAN module ([#403](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/403))
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))
- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
Expand Down
Binary file added docs/source/_images/gans/dcgan_lsun_dloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_lsun_gloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_dloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_gloss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 46 additions & 1 deletion docs/source/gans.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,49 @@ Loss curves:


.. autoclass:: pl_bolts.models.gans.GAN
:noindex:
:noindex:

DCGAN
---------
DCGAN implementation from the paper `Unsupervised Representation Learning with Deep Convolutional Generative
Adversarial Networks <https://arxiv.org/pdf/1511.06434.pdf>`_. The implementation is based on the version from
PyTorch's `examples <https://github.com/pytorch/examples/blob/master/dcgan/main.py>`_.

Implemented by:

- `Christoph Clement <https://github.com/chris-clem>`_

Example MNIST outputs:

.. image:: _images/gans/dcgan_mnist_outputs.png
:width: 400
:alt: DCGAN generated MNIST samples

Example LSUN bedroom outputs:

.. image:: _images/gans/dcgan_lsun_outputs.png
:width: 400
:alt: DCGAN generated LSUN bedroom samples

MNIST Loss curves:

.. image:: _images/gans/dcgan_mnist_dloss.png
:width: 200
:alt: DCGAN MNIST disc loss

.. image:: _images/gans/dcgan_mnist_gloss.png
:width: 200
:alt: DCGAN MNIST gen loss

LSUN Loss curves:

.. image:: _images/gans/dcgan_lsun_dloss.png
:width: 200
:alt: DCGAN LSUN disc loss

.. image:: _images/gans/dcgan_lsun_gloss.png
:width: 200
:alt: DCGAN LSUN gen loss

.. autoclass:: pl_bolts.models.gans.DCGAN
:noindex:
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
try:
import torchvision
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
warn_missing_pkg("torchvision") # pragma: no-cover


class TensorboardGenerativeModelImageSampler(Callback):
Expand Down
5 changes: 0 additions & 5 deletions pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ def __init__(
self.pin_memory = pin_memory
self.drop_last = drop_last

self.default_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
])

# split into train, val, test
kitti_dataset = KittiDataset(self.data_dir, transform=self._default_transforms())

Expand Down
1 change: 1 addition & 0 deletions pl_bolts/models/gans/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401
Empty file.
95 changes: 95 additions & 0 deletions pl_bolts/models/gans/dcgan/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
import torch
from torch import nn

akihironitta marked this conversation as resolved.
Show resolved Hide resolved

class DCGANGenerator(nn.Module):

def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
"""
Args:
latent_dim: Dimension of the latent space
feature_maps: Number of feature maps to use
image_channels: Number of channels of the images from the dataset
"""
super().__init__()
self.gen = nn.Sequential(
self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
self._make_gen_block(feature_maps * 8, feature_maps * 4),
self._make_gen_block(feature_maps * 4, feature_maps * 2),
self._make_gen_block(feature_maps * 2, feature_maps),
self._make_gen_block(feature_maps, image_channels, last_block=True),
)

@staticmethod
def _make_gen_block(
in_channels: int,
out_channels: int,
kernel_size: int = 4,
stride: int = 2,
padding: int = 1,
bias: bool = False,
last_block: bool = False,
) -> nn.Sequential:
if not last_block:
gen_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
else:
gen_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.Tanh(),
)

return gen_block

def forward(self, noise: torch.Tensor) -> torch.Tensor:
return self.gen(noise)


class DCGANDiscriminator(nn.Module):

def __init__(self, feature_maps: int, image_channels: int) -> None:
"""
Args:
feature_maps: Number of feature maps to use
image_channels: Number of channels of the images from the dataset
"""
super().__init__()
self.disc = nn.Sequential(
self._make_disc_block(image_channels, feature_maps, batch_norm=False),
self._make_disc_block(feature_maps, feature_maps * 2),
self._make_disc_block(feature_maps * 2, feature_maps * 4),
self._make_disc_block(feature_maps * 4, feature_maps * 8),
self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
)

@staticmethod
def _make_disc_block(
in_channels: int,
out_channels: int,
kernel_size: int = 4,
stride: int = 2,
padding: int = 1,
bias: bool = False,
batch_norm: bool = True,
last_block: bool = False,
) -> nn.Sequential:
if not last_block:
disc_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
nn.LeakyReLU(0.2, inplace=True),
)
else:
disc_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
nn.Sigmoid(),
)

return disc_block

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.disc(x).view(-1, 1).squeeze(1)
Loading