diff --git a/CHANGELOG.md b/CHANGELOG.md index 64adeb2add..55b752d796 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#348](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/348), [#323](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/323)) - Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285)) +- Added DCGAN module ([#403](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/403)) - Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`, and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400)) - Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347)) diff --git a/docs/source/_images/gans/dcgan_lsun_dloss.png b/docs/source/_images/gans/dcgan_lsun_dloss.png new file mode 100644 index 0000000000..7288e1ec8d Binary files /dev/null and b/docs/source/_images/gans/dcgan_lsun_dloss.png differ diff --git a/docs/source/_images/gans/dcgan_lsun_gloss.png b/docs/source/_images/gans/dcgan_lsun_gloss.png new file mode 100644 index 0000000000..4763535c5d Binary files /dev/null and b/docs/source/_images/gans/dcgan_lsun_gloss.png differ diff --git a/docs/source/_images/gans/dcgan_lsun_outputs.png b/docs/source/_images/gans/dcgan_lsun_outputs.png new file mode 100644 index 0000000000..fe10166190 Binary files /dev/null and b/docs/source/_images/gans/dcgan_lsun_outputs.png differ diff --git a/docs/source/_images/gans/dcgan_mnist_dloss.png b/docs/source/_images/gans/dcgan_mnist_dloss.png new file mode 100644 index 0000000000..826735ea5d Binary files /dev/null and b/docs/source/_images/gans/dcgan_mnist_dloss.png differ diff --git a/docs/source/_images/gans/dcgan_mnist_gloss.png 
b/docs/source/_images/gans/dcgan_mnist_gloss.png new file mode 100644 index 0000000000..d616d700bd Binary files /dev/null and b/docs/source/_images/gans/dcgan_mnist_gloss.png differ diff --git a/docs/source/_images/gans/dcgan_mnist_outputs.png b/docs/source/_images/gans/dcgan_mnist_outputs.png new file mode 100644 index 0000000000..8fbe130bed Binary files /dev/null and b/docs/source/_images/gans/dcgan_mnist_outputs.png differ diff --git a/docs/source/gans.rst b/docs/source/gans.rst index c76915c8bc..d25d671669 100644 --- a/docs/source/gans.rst +++ b/docs/source/gans.rst @@ -40,4 +40,49 @@ Loss curves: .. autoclass:: pl_bolts.models.gans.GAN - :noindex: \ No newline at end of file + :noindex: + +DCGAN +--------- +DCGAN implementation from the paper `Unsupervised Representation Learning with Deep Convolutional Generative +Adversarial Networks <https://arxiv.org/abs/1511.06434>`_. The implementation is based on the version from +PyTorch's `examples <https://github.com/pytorch/examples/blob/master/dcgan/main.py>`_. + +Implemented by: + + - `Christoph Clement <https://github.com/chris-clem>`_ + +Example MNIST outputs: + + .. image:: _images/gans/dcgan_mnist_outputs.png + :width: 400 + :alt: DCGAN generated MNIST samples + +Example LSUN bedroom outputs: + + .. image:: _images/gans/dcgan_lsun_outputs.png + :width: 400 + :alt: DCGAN generated LSUN bedroom samples + +MNIST Loss curves: + + .. image:: _images/gans/dcgan_mnist_dloss.png + :width: 200 + :alt: DCGAN MNIST disc loss + + .. image:: _images/gans/dcgan_mnist_gloss.png + :width: 200 + :alt: DCGAN MNIST gen loss + +LSUN Loss curves: + + .. image:: _images/gans/dcgan_lsun_dloss.png + :width: 200 + :alt: DCGAN LSUN disc loss + + .. image:: _images/gans/dcgan_lsun_gloss.png + :width: 200 + :alt: DCGAN LSUN gen loss + +.. 
autoclass:: pl_bolts.models.gans.DCGAN + :noindex: diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 4afe195f69..12012a5477 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -81,11 +81,6 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - self.default_transforms = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) - ]) - # split into train, val, test kitti_dataset = KittiDataset(self.data_dir, transform=self._default_transforms()) diff --git a/pl_bolts/models/gans/__init__.py b/pl_bolts/models/gans/__init__.py index c7f935ef6d..c28eb32124 100644 --- a/pl_bolts/models/gans/__init__.py +++ b/pl_bolts/models/gans/__init__.py @@ -1 +1,2 @@ from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401 +from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401 diff --git a/pl_bolts/models/gans/dcgan/__init__.py b/pl_bolts/models/gans/dcgan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py new file mode 100644 index 0000000000..c2432f3580 --- /dev/null +++ b/pl_bolts/models/gans/dcgan/components.py @@ -0,0 +1,95 @@ +# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py +import torch +from torch import nn + + +class DCGANGenerator(nn.Module): + + def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None: + """ + Args: + latent_dim: Dimension of the latent space + feature_maps: Number of feature maps to use + image_channels: Number of channels of the images from the dataset + """ + super().__init__() + self.gen = nn.Sequential( + self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0), + self._make_gen_block(feature_maps * 8, feature_maps * 
4), + self._make_gen_block(feature_maps * 4, feature_maps * 2), + self._make_gen_block(feature_maps * 2, feature_maps), + self._make_gen_block(feature_maps, image_channels, last_block=True), + ) + + @staticmethod + def _make_gen_block( + in_channels: int, + out_channels: int, + kernel_size: int = 4, + stride: int = 2, + padding: int = 1, + bias: bool = False, + last_block: bool = False, + ) -> nn.Sequential: + if not last_block: + gen_block = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + ) + else: + gen_block = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), + nn.Tanh(), + ) + + return gen_block + + def forward(self, noise: torch.Tensor) -> torch.Tensor: + return self.gen(noise) + + +class DCGANDiscriminator(nn.Module): + + def __init__(self, feature_maps: int, image_channels: int) -> None: + """ + Args: + feature_maps: Number of feature maps to use + image_channels: Number of channels of the images from the dataset + """ + super().__init__() + self.disc = nn.Sequential( + self._make_disc_block(image_channels, feature_maps, batch_norm=False), + self._make_disc_block(feature_maps, feature_maps * 2), + self._make_disc_block(feature_maps * 2, feature_maps * 4), + self._make_disc_block(feature_maps * 4, feature_maps * 8), + self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True), + ) + + @staticmethod + def _make_disc_block( + in_channels: int, + out_channels: int, + kernel_size: int = 4, + stride: int = 2, + padding: int = 1, + bias: bool = False, + batch_norm: bool = True, + last_block: bool = False, + ) -> nn.Sequential: + if not last_block: + disc_block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), + nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(), + nn.LeakyReLU(0.2, inplace=True), + ) + else: + 
disc_block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), + nn.Sigmoid(), + ) + + return disc_block + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.disc(x).view(-1, 1).squeeze(1) diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py new file mode 100644 index 0000000000..b99f7f4f99 --- /dev/null +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -0,0 +1,224 @@ +from argparse import ArgumentParser +from typing import Any + +import pytorch_lightning as pl +import torch +from torch import nn +from torch.utils.data import DataLoader + +from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler +from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib + from torchvision.datasets import LSUN, MNIST +else: # pragma: no-cover + warn_missing_pkg("torchvision") + + +class DCGAN(pl.LightningModule): + """ + DCGAN implementation. 
+ + Example:: + + from pl_bolts.models.gans import DCGAN + + m = DCGAN() + Trainer(gpus=2).fit(m) + + Example CLI:: + + # mnist + python dcgan_module.py --gpus 1 + + # lsun + python dcgan_module.py --gpus 1 --dataset lsun + """ + + def __init__( + self, + beta1: float = 0.5, + feature_maps_gen: int = 64, + feature_maps_disc: int = 64, + image_channels: int = 1, + latent_dim: int = 100, + learning_rate: float = 0.0002, + **kwargs: Any, + ) -> None: + """ + Args: + beta1: Beta1 value for Adam optimizer + feature_maps_gen: Number of feature maps to use for the generator + feature_maps_disc: Number of feature maps to use for the discriminator + image_channels: Number of channels of the images from the dataset + latent_dim: Dimension of the latent space + learning_rate: Learning rate + """ + super().__init__() + self.save_hyperparameters() + + self.generator = self._get_generator() + self.discriminator = self._get_discriminator() + + self.criterion = nn.BCELoss() + + def _get_generator(self) -> nn.Module: + generator = DCGANGenerator(self.hparams.latent_dim, self.hparams.feature_maps_gen, self.hparams.image_channels) + generator.apply(self._weights_init) + return generator + + def _get_discriminator(self) -> nn.Module: + discriminator = DCGANDiscriminator(self.hparams.feature_maps_disc, self.hparams.image_channels) + discriminator.apply(self._weights_init) + return discriminator + + @staticmethod + def _weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + torch.nn.init.normal_(m.weight, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + torch.nn.init.normal_(m.weight, 1.0, 0.02) + torch.nn.init.zeros_(m.bias) + + def configure_optimizers(self): + lr = self.hparams.learning_rate + betas = (self.hparams.beta1, 0.999) + opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas) + opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas) + return [opt_disc, opt_gen], [] + 
+ def forward(self, noise: torch.Tensor) -> torch.Tensor: + """ + Generates an image given input noise + + Example:: + + noise = torch.randn(batch_size, latent_dim) + gan = DCGAN.load_from_checkpoint(PATH) + img = gan(noise) + """ + noise = noise.view(*noise.shape, 1, 1) + return self.generator(noise) + + def training_step(self, batch, batch_idx, optimizer_idx): + real, _ = batch + + # Train discriminator + result = None + if optimizer_idx == 0: + result = self._disc_step(real) + + # Train generator + if optimizer_idx == 1: + result = self._gen_step(real) + + return result + + def _disc_step(self, real: torch.Tensor) -> torch.Tensor: + disc_loss = self._get_disc_loss(real) + self.log("loss/disc", disc_loss, on_epoch=True) + return disc_loss + + def _gen_step(self, real: torch.Tensor) -> torch.Tensor: + gen_loss = self._get_gen_loss(real) + self.log("loss/gen", gen_loss, on_epoch=True) + return gen_loss + + def _get_disc_loss(self, real: torch.Tensor) -> torch.Tensor: + # Train with real + real_pred = self.discriminator(real) + real_gt = torch.ones_like(real_pred) + real_loss = self.criterion(real_pred, real_gt) + + # Train with fake + fake_pred = self._get_fake_pred(real) + fake_gt = torch.zeros_like(fake_pred) + fake_loss = self.criterion(fake_pred, fake_gt) + + disc_loss = real_loss + fake_loss + + return disc_loss + + def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor: + # Train with fake + fake_pred = self._get_fake_pred(real) + fake_gt = torch.ones_like(fake_pred) + gen_loss = self.criterion(fake_pred, fake_gt) + + return gen_loss + + def _get_fake_pred(self, real: torch.Tensor) -> torch.Tensor: + batch_size = len(real) + noise = self._get_noise(batch_size, self.hparams.latent_dim) + fake = self(noise) + fake_pred = self.discriminator(fake) + + return fake_pred + + def _get_noise(self, n_samples: int, latent_dim: int) -> torch.Tensor: + return torch.randn(n_samples, latent_dim, device=self.device) + + @staticmethod + def 
add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--beta1", default=0.5, type=float) + parser.add_argument("--feature_maps_gen", default=64, type=int) + parser.add_argument("--feature_maps_disc", default=64, type=int) + parser.add_argument("--latent_dim", default=100, type=int) + parser.add_argument("--learning_rate", default=0.0002, type=float) + return parser + + +def cli_main(args=None): + pl.seed_everything(1234) + + parser = ArgumentParser() + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"]) + parser.add_argument("--data_dir", default="./", type=str) + parser.add_argument("--image_size", default=64, type=int) + parser.add_argument("--num_workers", default=8, type=int) + + script_args, _ = parser.parse_known_args(args) + + if script_args.dataset == "lsun": + transforms = transform_lib.Compose([ + transform_lib.Resize(script_args.image_size), + transform_lib.CenterCrop(script_args.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + dataset = LSUN(root=script_args.data_dir, classes=["bedroom_train"], transform=transforms) + image_channels = 3 + elif script_args.dataset == "mnist": + transforms = transform_lib.Compose([ + transform_lib.Resize(script_args.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize((0.5, ), (0.5, )), + ]) + dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms) + image_channels = 1 + + dataloader = DataLoader( + dataset, batch_size=script_args.batch_size, shuffle=True, num_workers=script_args.num_workers + ) + + parser = DCGAN.add_model_specific_args(parser) + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args(args) + + model = DCGAN(**vars(args), image_channels=image_channels) + callbacks = [ + 
TensorboardGenerativeModelImageSampler(num_samples=5), + LatentDimInterpolator(interpolate_epoch_interval=5), + ] + trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks) + trainer.fit(model, dataloader) + + +if __name__ == "__main__": + cli_main() diff --git a/tests/models/test_gans.py b/tests/models/test_gans.py index 70f0c9c00a..fe6feea4fa 100644 --- a/tests/models/test_gans.py +++ b/tests/models/test_gans.py @@ -1,9 +1,10 @@ import pytest import pytorch_lightning as pl from pytorch_lightning import seed_everything +from torchvision import transforms as transform_lib from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule -from pl_bolts.models.gans import GAN +from pl_bolts.models.gans import DCGAN, GAN @pytest.mark.parametrize( @@ -20,3 +21,19 @@ def test_gan(tmpdir, datadir, dm_cls): trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, datamodule=dm) trainer.test(datamodule=dm, ckpt_path=None) + + +@pytest.mark.parametrize( + "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), + pytest.param(CIFAR10DataModule, id="cifar10")] +) +def test_dcgan(tmpdir, datadir, dm_cls): + seed_everything() + + transforms = transform_lib.Compose([transform_lib.Resize(64), transform_lib.ToTensor()]) + dm = dm_cls(data_dir=datadir, train_transforms=transforms, val_transforms=transforms, test_transforms=transforms) + + model = DCGAN(image_channels=dm.dims[0]) + trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(model, dm) + trainer.test(datamodule=dm, ckpt_path=None) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index 5da6c976c8..e946b17943 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -25,6 +25,14 @@ def test_cli_run_basic_gan(cli_args, dataset_name): cli_main() +@pytest.mark.parametrize('cli_args', [f'--dataset mnist --data_dir {DATASETS_PATH} --fast_dev_run 1']) +def test_cli_run_dcgan(cli_args): + from 
pl_bolts.models.gans.dcgan.dcgan_module import cli_main + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + cli_main() + + # TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it... @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.parametrize(