* Add DCGAN module
* Undo black on conf.py
* Add tests for DCGAN
* Fix flake8 and codefactor
* Add types and small refactoring
* Make image sampler callback work
* Upgrade DQN to use .log (#404)
* Upgrade DQN to use .log
* remove unused
* pep8
* fixed other dqn
* fix loss test case for batch size variation (#402)
* Decouple DataModules from Models - CPCV2 (#386)
* Decouple dms from CPCV2
* Update tests
* Add docstrings, fix import, and update changelog
* Update transforms
* bugfix: batch_size parameter for DataModules remaining (#344)
* bugfix: batch_size for DataModules remaining
* Update sklearn datamodule tests
* Fix default_transforms. Keep internal for every data module
* fix typo on binary_mnist_datamodule, thanks @akihironitta
Co-authored-by: Akihiro Nitta <[email protected]>
* Fix a typo/copy-paste error (#415)
* Just a Typo (#413): missing a ' at the end of dataset='stl10
* Remove unused arguments (#418)
* tests: Use cached datasets in LitMNIST and the doctests (#414)
* Use cached datasets
* Use cached datasets in doctests
* clear replay buffer after trajectory (#425)
* stale: update label
* bugfix: Add missing imports to pl_bolts/__init__.py (#430)
* Add missing imports
* Add missing imports
* Apply isort
* Fix CIFAR num_samples (#432)
* Add static type checker mypy to the tests and pre-commit hooks (#433)
* Add mypy check to GitHub Actions
* Run mypy on pl_bolts only
* Add mypy check to pre-commit
* Add an empty line at the end of files
* Update mypy config
* Update mypy config
* Update mypy config
* show
Co-authored-by: Jirka Borovec <[email protected]>
* missing logo
* Add type annotations to pl_bolts/__init__.py (#435)
* Run mypy on pl_bolts only
* Update mypy config
* Add type hints to pl_bolts/__init__.py
* mypy
Co-authored-by: Jirka Borovec <[email protected]>
* skip hanging (#437)
* Option to normalize latent interpolation images (#438)
* add option to normalize latent interpolation images
* linspace
* update
Co-authored-by: ananyahjha93 <[email protected]>
* 0.2.6rc1
* Warnings fix (#449)
* Revert "Merge pull request #1 from ganprad/warnings_fix" (this reverts commit 7c5aaf0)
* Fixes warning related to np.integer in SklearnDataModule. Fixes this warning: "DeprecationWarning: Converting `np.integer` or `np.signedinteger` to a dtype is deprecated. The current result is `np.dtype(np.int_)` which is not strictly correct. Note that the result depends on the system. To ensure stable results you may want to use `np.int64` or `np.int32`."
* Refactor datamodules/datasets (#338)
* Remove try: ... except: ...
* Fix experience_source
* Fix imagenet
* Fix kitti
* Fix sklearn
* Fix vocdetection
* Fix typo
* Remove duplicate
* Fix by flake8
* Add optional packages availability vars
* binary_mnist
* Use pl_bolts._SKLEARN_AVAILABLE
* Apply isort
* cifar10
* mnist
* cityscapes
* fashion mnist
* ssl_imagenet
* stl10
* cifar10
* dummy
* fix city
* fix stl10
* fix mnist
* ssl_amdim
* remove unused DataLoader and fix docs
* use from ... import ...
* fix pragma: no cover
* Fix forward reference in annotations
* binmnist
* Same order as imports
* Move vars from __init__ to utils/__init__
* Remove vars from __init__
* Update vars
* Apply isort
* update min requirements - PL 1.1.1 (#448)
* update min requirements
* rc0
* imports
* isort
* flake8
* 1.1.1
* flake8
* docs
* Add missing optional packages to `requirements/*.txt` (#450)
* Import matplotlib at the top
* Add missing optional packages
* Update wandb
* Add mypy to requirements
* update Isort (#457)
* Adding flags to datamodules (#388)
* Adding flags to datamodules
* Finishing up changes
* Fixing syntax error
* More syntax errors
* More
* Adding drop_last flag to sklearn test
* Adding drop_last flag to sklearn test
* Updating doc to reflect drop_last=False
* Cleaning up parameters and docstring
* Fixing syntax error
* Fixing documentation
* Hardcoding shuffle=False for val and test
* Add DCGAN module
* Small fixes
* Remove DataModules
* Update docs
* Update docs
* Update torchvision import
* Import gym as optional package to build docs successfully (#458)
* Import gym as optional package
* Fix import
* Apply isort
* Apply suggestions from code review
* Apply suggestions from code review
* Add docs
* Use LSUN instead of CIFAR10
* Update TensorboardGenerativeModelImageSampler
* Update docs with lsun
* Update test
* Revert TensorboardGenerativeModelImageSampler changes
* Remove ModelCheckpoint callback and nrow=5 arg
* Apply suggestions from code review
* Fix test_dcgan
* Apply yapf
* Apply suggestions from code review

Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: Sidhant Sundrani <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Héctor Laria <[email protected]>
Co-authored-by: Bartol Karuza <[email protected]>
Co-authored-by: Happy Sugar Life <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: ananyahjha93 <[email protected]>
Co-authored-by: Pradeep Ganesan <[email protected]>
Co-authored-by: Brian Ko <[email protected]>
Co-authored-by: Christoph Clement <[email protected]>
1 parent 6e15643, commit 056f836
Showing 15 changed files with 393 additions and 7 deletions.
pl_bolts/models/gans/__init__.py
@@ -1 +1,2 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN  # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN  # noqa: F401
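With this export in place, the new model is importable from the package root:

```python
from pl_bolts.models.gans import DCGAN
```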
pl_bolts/models/gans/dcgan/__init__.py (new, empty file)
pl_bolts/models/gans/dcgan/components.py
@@ -0,0 +1,95 @@
# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
import torch
from torch import nn


class DCGANGenerator(nn.Module):

    def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            latent_dim: Dimension of the latent space
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        # Five transposed-conv blocks: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        self.gen = nn.Sequential(
            self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
            self._make_gen_block(feature_maps * 8, feature_maps * 4),
            self._make_gen_block(feature_maps * 4, feature_maps * 2),
            self._make_gen_block(feature_maps * 2, feature_maps),
            self._make_gen_block(feature_maps, image_channels, last_block=True),
        )

    @staticmethod
    def _make_gen_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        last_block: bool = False,
    ) -> nn.Sequential:
        # Intermediate blocks: transposed conv + BatchNorm + ReLU; the last block
        # maps to image_channels and squashes outputs to [-1, 1] with Tanh.
        if not last_block:
            gen_block = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            )
        else:
            gen_block = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.Tanh(),
            )

        return gen_block

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        return self.gen(noise)

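# Shape arithmetic for the blocks above: with dilation 1 and no output padding,
# a transposed conv gives H_out = (H_in - 1) * stride - 2 * padding + kernel_size,
# so the first generator block maps 1 -> 4 and every following block (k=4, s=2, p=1)
# doubles the resolution: 4 -> 8 -> 16 -> 32 -> 64.
# The discriminator below mirrors this with strided convs,
# H_out = (H_in + 2 * padding - kernel_size) // stride + 1, halving each time.
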
class DCGANDiscriminator(nn.Module):

    def __init__(self, feature_maps: int, image_channels: int) -> None:
        """
        Args:
            feature_maps: Number of feature maps to use
            image_channels: Number of channels of the images from the dataset
        """
        super().__init__()
        # Five conv blocks: 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4 -> 1x1 (single score)
        self.disc = nn.Sequential(
            self._make_disc_block(image_channels, feature_maps, batch_norm=False),
            self._make_disc_block(feature_maps, feature_maps * 2),
            self._make_disc_block(feature_maps * 2, feature_maps * 4),
            self._make_disc_block(feature_maps * 4, feature_maps * 8),
            self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
        )

    @staticmethod
    def _make_disc_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
        batch_norm: bool = True,
        last_block: bool = False,
    ) -> nn.Sequential:
        # Intermediate blocks: strided conv + (optional) BatchNorm + LeakyReLU(0.2);
        # the last block reduces to a single channel and applies Sigmoid.
        if not last_block:
            disc_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            disc_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.Sigmoid(),
            )

        return disc_block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten the (N, 1, 1, 1) output to a vector of N probabilities
        return self.disc(x).view(-1, 1).squeeze(1)
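A quick shape check of these components (a sketch, not part of the commit; it assumes the default 64x64 DCGAN geometry above):

```python
import torch

from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator

generator = DCGANGenerator(latent_dim=100, feature_maps=64, image_channels=3)
discriminator = DCGANDiscriminator(feature_maps=64, image_channels=3)

# The raw generator expects the latent vector with spatial dims: (N, latent_dim, 1, 1).
# (The DCGAN LightningModule's forward adds these dims for you.)
noise = torch.randn(16, 100, 1, 1)
fake = generator(noise)
print(fake.shape)                  # torch.Size([16, 3, 64, 64]), values in [-1, 1] from Tanh
print(discriminator(fake).shape)   # torch.Size([16]), probabilities in (0, 1)
```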
pl_bolts/models/gans/dcgan/dcgan_module.py
@@ -0,0 +1,224 @@
from argparse import ArgumentParser
from typing import Any

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib
    from torchvision.datasets import LSUN, MNIST
else:  # pragma: no-cover
    warn_missing_pkg("torchvision")


class DCGAN(pl.LightningModule):
    """
    DCGAN implementation.

    Example::

        from pl_bolts.models.gans import DCGAN

        m = DCGAN()
        Trainer(gpus=2).fit(m)

    Example CLI::

        # mnist
        python dcgan_module.py --gpus 1

        # lsun
        python dcgan_module.py --gpus 1 --dataset lsun --data_dir <path/to/lsun>
    """

    def __init__(
        self,
        beta1: float = 0.5,
        feature_maps_gen: int = 64,
        feature_maps_disc: int = 64,
        image_channels: int = 1,
        latent_dim: int = 100,
        learning_rate: float = 0.0002,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            beta1: Beta1 value for Adam optimizer
            feature_maps_gen: Number of feature maps to use for the generator
            feature_maps_disc: Number of feature maps to use for the discriminator
            image_channels: Number of channels of the images from the dataset
            latent_dim: Dimension of the latent space
            learning_rate: Learning rate
        """
        super().__init__()
        self.save_hyperparameters()

        self.generator = self._get_generator()
        self.discriminator = self._get_discriminator()

        self.criterion = nn.BCELoss()

    def _get_generator(self) -> nn.Module:
        generator = DCGANGenerator(self.hparams.latent_dim, self.hparams.feature_maps_gen, self.hparams.image_channels)
        generator.apply(self._weights_init)
        return generator

    def _get_discriminator(self) -> nn.Module:
        discriminator = DCGANDiscriminator(self.hparams.feature_maps_disc, self.hparams.image_channels)
        discriminator.apply(self._weights_init)
        return discriminator

    @staticmethod
    def _weights_init(m: nn.Module) -> None:
        # DCGAN paper initialization: conv weights ~ N(0, 0.02), batch-norm weights ~ N(1, 0.02)
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    def configure_optimizers(self):
        lr = self.hparams.learning_rate
        betas = (self.hparams.beta1, 0.999)
        opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        # Two optimizers: Lightning calls training_step with optimizer_idx 0 (disc), then 1 (gen)
        return [opt_disc, opt_gen], []

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """
        Generates an image given input noise.

        Example::

            noise = torch.randn(batch_size, latent_dim)
            gan = DCGAN.load_from_checkpoint(PATH)
            img = gan(noise)
        """
        # Add singleton spatial dims so the generator's transposed convs can consume the noise
        noise = noise.view(*noise.shape, 1, 1)
        return self.generator(noise)

    def training_step(self, batch, batch_idx, optimizer_idx):
        real, _ = batch

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real)

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(real)

        return result

    def _disc_step(self, real: torch.Tensor) -> torch.Tensor:
        disc_loss = self._get_disc_loss(real)
        self.log("loss/disc", disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real: torch.Tensor) -> torch.Tensor:
        gen_loss = self._get_gen_loss(real)
        self.log("loss/gen", gen_loss, on_epoch=True)
        return gen_loss

    def _get_disc_loss(self, real: torch.Tensor) -> torch.Tensor:
        # Train with real: real images should be classified as 1
        real_pred = self.discriminator(real)
        real_gt = torch.ones_like(real_pred)
        real_loss = self.criterion(real_pred, real_gt)

        # Train with fake: generated images should be classified as 0
        fake_pred = self._get_fake_pred(real)
        fake_gt = torch.zeros_like(fake_pred)
        fake_loss = self.criterion(fake_pred, fake_gt)

        disc_loss = real_loss + fake_loss

        return disc_loss

    def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
        # Train with fake: the generator wants its images classified as 1
        fake_pred = self._get_fake_pred(real)
        fake_gt = torch.ones_like(fake_pred)
        gen_loss = self.criterion(fake_pred, fake_gt)

        return gen_loss

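    # Together, _get_disc_loss and _get_gen_loss implement the standard DCGAN
    # objectives with BCE:
    #     L_D = -[log D(x) + log(1 - D(G(z)))]
    #     L_G = -log D(G(z))
    # i.e. the non-saturating generator loss: fakes are labeled as real in
    # _get_gen_loss rather than maximizing log(1 - D(G(z))).
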
    def _get_fake_pred(self, real: torch.Tensor) -> torch.Tensor:
        batch_size = len(real)
        noise = self._get_noise(batch_size, self.hparams.latent_dim)
        fake = self(noise)
        fake_pred = self.discriminator(fake)

        return fake_pred

    def _get_noise(self, n_samples: int, latent_dim: int) -> torch.Tensor:
        return torch.randn(n_samples, latent_dim, device=self.device)

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--beta1", default=0.5, type=float)
        parser.add_argument("--feature_maps_gen", default=64, type=int)
        parser.add_argument("--feature_maps_disc", default=64, type=int)
        parser.add_argument("--latent_dim", default=100, type=int)
        parser.add_argument("--learning_rate", default=0.0002, type=float)
        return parser

def cli_main(args=None):
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"])
    parser.add_argument("--data_dir", default="./", type=str)
    parser.add_argument("--image_size", default=64, type=int)
    parser.add_argument("--num_workers", default=8, type=int)

    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "lsun":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.CenterCrop(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = LSUN(root=script_args.data_dir, classes=["bedroom_train"], transform=transforms)
        image_channels = 3
    elif script_args.dataset == "mnist":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, ), (0.5, )),
        ])
        dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms)
        image_channels = 1

    dataloader = DataLoader(
        dataset, batch_size=script_args.batch_size, shuffle=True, num_workers=script_args.num_workers
    )

    parser = DCGAN.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    # image_channels is inferred from the chosen dataset rather than exposed as a CLI flag
    model = DCGAN(**vars(args), image_channels=image_channels)
    callbacks = [
        TensorboardGenerativeModelImageSampler(num_samples=5),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    cli_main()
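To try the module outside the CLI, here is a minimal smoke-test sketch (my example, not part of the commit; it downloads MNIST and runs a single batch via `fast_dev_run`):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from pl_bolts.models.gans import DCGAN

transform = transforms.Compose([
    transforms.Resize(64),                 # the DCGAN components assume 64x64 inputs
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # match the generator's Tanh output range
])
dataset = MNIST(root="./", download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

model = DCGAN(image_channels=1)
trainer = pl.Trainer(fast_dev_run=True)    # one training batch as a smoke test
trainer.fit(model, dataloader)
```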