[MXNET-580] Add SN-GAN example #12419
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Spectral Normalization GAN | ||
|
||
This example implements [Spectral Normalization for Generative Adversarial Networks](https://openreview.net/pdf?id=B1QRgziT-) and trains it on the [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. | ||
Review comment: Please refer to arxiv - https://arxiv.org/abs/1802.05957 |
||
|
||
## Usage | ||
|
||
Example runs and the results: | ||
|
||
```bash | ||
python train.py --use-gpu --data-path=data/CIFAR10 | ||
Review comment: Write a note that user needs to download CIFAR10 dataset |
||
``` | ||
|
||
`python train.py --help` gives the following arguments: | ||
|
||
```bash | ||
optional arguments: | ||
  -h, --help            show this help message and exit | ||
  --data-path DATA_PATH | ||
                        path of data. | ||
  --batch-size BATCH_SIZE | ||
                        training batch size. default is 64. | ||
  --epochs EPOCHS       number of training epochs. default is 100. | ||
  --lr LR               learning rate. default is 0.0001. | ||
  --lr-beta LR_BETA     learning rate for the beta in margin based loss. | ||
                        default is 0.5s. | ||
  --use-gpu             use gpu for training. | ||
  --clip_gr CLIP_GR     Clip the gradient by projecting onto the box. default | ||
                        is 10.0. | ||
  --z-dim Z_DIM         dimension of the latent z vector. default is 100. | ||
``` | ||
|
||
## Learned Spectral Normalization | ||
|
||
![alt text](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png) | ||
|
||
Review comment: Could you add a few samples of the generated images? It always makes it more appealing for people looking to try new models. Thanks! |
Reply: I am running the model with Xavier initializer and will update the image if it's better |
||
## Reference | ||
|
||
[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, | ||
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb | ||
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py | ||
|
||
import numpy as np | ||
|
||
import mxnet as mx | ||
from mxnet import gluon | ||
from mxnet.gluon.data.vision import CIFAR10 | ||
|
||
IMAGE_SIZE = 64 | ||
|
||
def transformer(data, label): | ||
    """ data preparation """ | ||
    data = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE) | ||
    data = mx.nd.transpose(data, (2, 0, 1)) | ||
    data = data.astype(np.float32) / 128.0 - 1 | ||
    return data, label | ||
|
||
|
||
def get_training_data(batch_size): | ||
    """ hepler function to get dataloader""" | ||
Review comment: helper |
||
    return gluon.data.DataLoader( | ||
        CIFAR10(train=True, transform=transformer), | ||
        batch_size=batch_size, shuffle=True, last_batch='discard') |
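The `transformer` function above rescales 8-bit pixel values with `x / 128.0 - 1`. A minimal sketch of that mapping in plain Python shows the resulting range; the helper name `scale_pixel` is hypothetical and not part of the PR.

```python
def scale_pixel(value):
    """Map an 8-bit pixel value (0..255) the same way transformer() does."""
    return value / 128.0 - 1


if __name__ == "__main__":
    # Lowest, middle and highest 8-bit intensities.
    for pixel in (0, 128, 255):
        print(pixel, scale_pixel(pixel))
```

The range is roughly that of the generator's final `tanh` activation. The upper end is 0.9921875 rather than exactly 1; dividing by 127.5 instead would make the range symmetric, though that is a design choice the PR does not discuss.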
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, | ||
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb | ||
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py | ||
|
||
import mxnet as mx | ||
from mxnet import nd | ||
from mxnet import gluon | ||
from mxnet.gluon import Block | ||
|
||
|
||
EPSILON = 1e-08 | ||
POWER_ITERATION = 1 | ||
|
||
class SNConv2D(Block): | ||
    """ Customized Conv2D that convolves with a spectrally normalized copy of its weight """ | ||
|
||
    def __init__(self, num_filter, kernel_size, | ||
                 strides, padding, in_channels, | ||
                 ctx, iterations=1): | ||
Review comment: set default ctx=mx.cpu() |
||
|
||
        super(SNConv2D, self).__init__() | ||
|
||
        self.num_filter = num_filter | ||
        self.kernel_size = kernel_size | ||
        self.strides = strides | ||
        self.padding = padding | ||
        self.in_channels = in_channels | ||
        self.iterations = iterations | ||
        self.ctx = ctx | ||
|
||
        with self.name_scope(): | ||
            # init the weight | ||
            self.weight = self.params.get('weight', shape=( | ||
                num_filter, in_channels, kernel_size, kernel_size)) | ||
            self.u = self.params.get( | ||
                'u', init=mx.init.Normal(), shape=(1, num_filter)) | ||
|
||
    def spectral_norm(self): | ||
Review comment: I would suggest using |
||
        """ spectral normalization """ | ||
        w = self.params.get('weight').data(self.ctx) | ||
        w_mat = w | ||
Review comment: I don't think this is necessary, you can simply use |
||
        w_mat = nd.reshape(w_mat, [w_mat.shape[0], -1]) | ||
|
||
        _u = self.u.data(self.ctx) | ||
        _v = None | ||
|
||
        for _ in range(POWER_ITERATION): | ||
            _v = nd.L2Normalization(nd.dot(_u, w_mat)) | ||
            _u = nd.L2Normalization(nd.dot(_v, w_mat.T)) | ||
|
||
        sigma = nd.sum(nd.dot(_u, w_mat) * _v) | ||
        if sigma == 0.: | ||
            sigma = EPSILON | ||
|
||
        # persist the updated u for the next call; ParameterDict.setattr would | ||
        # set an attribute on every Parameter rather than update the value of u | ||
        with mx.autograd.pause(): | ||
            self.u.set_data(_u) | ||
|
||
        return w / sigma | ||
|
||
    def forward(self, x): | ||
        # x shape is batch_size x in_channels x height x width | ||
        return nd.Convolution( | ||
            data=x, | ||
            weight=self.spectral_norm(), | ||
            kernel=(self.kernel_size, self.kernel_size), | ||
            pad=(self.padding, self.padding), | ||
            stride=(self.strides, self.strides), | ||
            num_filter=self.num_filter, | ||
            no_bias=True | ||
        ) | ||
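The `spectral_norm` method above estimates the largest singular value of the reshaped weight matrix by power iteration and divides the weight by it. The following is a minimal, framework-free sketch of that estimate using plain Python lists instead of `NDArray`; all helper names (`l2_normalize`, `mat_vec`, `transpose`, `estimate_sigma`) are hypothetical and not part of the PR.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale vec to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / (norm + eps) for x in vec]


def mat_vec(mat, vec):
    """Return mat @ vec for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, vec)) for row in mat]


def transpose(mat):
    """Return the transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*mat)]


def estimate_sigma(w_mat, u, iterations=1):
    """Estimate the largest singular value of w_mat by power iteration.

    u has one entry per row of w_mat. Returns (sigma, updated_u) so the
    caller can reuse u on the next call, as SNConv2D does with its 'u'.
    """
    w_t = transpose(w_mat)
    v = None
    for _ in range(iterations):
        v = l2_normalize(mat_vec(w_t, u))    # v <- W^T u / ||W^T u||
        u = l2_normalize(mat_vec(w_mat, v))  # u <- W v / ||W v||
    sigma = sum(a * b for a, b in zip(u, mat_vec(w_mat, v)))  # u^T W v
    return sigma, u


if __name__ == "__main__":
    # A diagonal matrix whose singular values are 3 and 1.
    weight = [[3.0, 0.0], [0.0, 1.0]]
    sigma, u = estimate_sigma(weight, [1.0, 1.0], iterations=50)
    print(sigma)
```

With a single iteration per call, as in the PR, the estimate is rough at first and only improves over training steps if the updated `u` is carried over between calls.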
|
||
|
||
def get_generator(): | ||
    """ construct and return generator """ | ||
    g_net = gluon.nn.Sequential() | ||
    with g_net.name_scope(): | ||
|
||
        g_net.add(gluon.nn.Conv2DTranspose(512, 4, 1, 0, use_bias=False)) | ||
Review comment: It is always clearer for readers of these kinds of examples to have named parameters for these layers. |
||
        g_net.add(gluon.nn.BatchNorm()) | ||
        g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        g_net.add(gluon.nn.Conv2DTranspose(256, 4, 2, 1, use_bias=False)) | ||
        g_net.add(gluon.nn.BatchNorm()) | ||
        g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        g_net.add(gluon.nn.Conv2DTranspose(128, 4, 2, 1, use_bias=False)) | ||
        g_net.add(gluon.nn.BatchNorm()) | ||
        g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        g_net.add(gluon.nn.Conv2DTranspose(64, 4, 2, 1, use_bias=False)) | ||
        g_net.add(gluon.nn.BatchNorm()) | ||
        g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        g_net.add(gluon.nn.Conv2DTranspose(3, 4, 2, 1, use_bias=False)) | ||
        g_net.add(gluon.nn.Activation('tanh')) | ||
|
||
    return g_net | ||
|
||
|
||
def get_descriptor(ctx): | ||
    """ construct and return descriptor, i.e. the GAN discriminator """ | ||
    d_net = gluon.nn.Sequential() | ||
    with d_net.name_scope(): | ||
|
||
        d_net.add(SNConv2D(64, 4, 2, 1, 3, ctx)) | ||
        d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        d_net.add(SNConv2D(128, 4, 2, 1, 64, ctx)) | ||
        d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        d_net.add(SNConv2D(256, 4, 2, 1, 128, ctx)) | ||
        d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        d_net.add(SNConv2D(512, 4, 2, 1, 256, ctx)) | ||
        d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
        d_net.add(SNConv2D(1, 4, 1, 0, 512, ctx)) | ||
|
||
    return d_net |
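The kernel, stride and padding values in `get_generator` and `get_descriptor` determine the spatial size at every layer. The sketch below traces those sizes with the standard convolution and transposed-convolution formulas, assuming no dilation and no output padding; the helper names and the two layer tables are hypothetical and only copy the numbers from the definitions above.

```python
def conv_out(size, kernel, stride, pad):
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * pad - kernel) // stride + 1


def deconv_out(size, kernel, stride, pad):
    """Spatial output size of a transposed convolution (no output padding)."""
    return (size - 1) * stride - 2 * pad + kernel


# (kernel, stride, pad) per layer, copied from the model definitions.
GENERATOR_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
DISCRIMINATOR_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size, layers, size_fn):
    """Return the list of spatial sizes, starting with the input size."""
    sizes = [size]
    for kernel, stride, pad in layers:
        size = size_fn(size, kernel, stride, pad)
        sizes.append(size)
    return sizes


if __name__ == "__main__":
    print(trace(1, GENERATOR_LAYERS, deconv_out))       # [1, 4, 8, 16, 32, 64]
    print(trace(64, DISCRIMINATOR_LAYERS, conv_out))    # [64, 32, 16, 8, 4, 1]
```

Under these formulas the generator maps a 1x1 latent input to a 64x64 image and the discriminator reduces a 64x64 image to a single value, which is consistent with `data.py` resizing the CIFAR10 images to `IMAGE_SIZE = 64`.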
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, | ||
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb | ||
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py | ||
|
||
|
||
import os | ||
import random | ||
import logging | ||
import argparse | ||
|
||
from data import get_training_data | ||
from model import get_generator, get_descriptor | ||
from utils import save_image | ||
|
||
import mxnet as mx | ||
from mxnet import nd, autograd | ||
from mxnet import gluon | ||
|
||
# CLI | ||
parser = argparse.ArgumentParser( | ||
    description='train a model for Spectral Normalization GAN.') | ||
parser.add_argument('--data-path', type=str, default='./data', | ||
                    help='path of data.') | ||
parser.add_argument('--batch-size', type=int, default=64, | ||
                    help='training batch size. default is 64.') | ||
parser.add_argument('--epochs', type=int, default=100, | ||
                    help='number of training epochs. default is 100.') | ||
parser.add_argument('--lr', type=float, default=0.0001, | ||
                    help='learning rate. default is 0.0001.') | ||
parser.add_argument('--lr-beta', type=float, default=0.5, | ||
                    help='learning rate for the beta in margin based loss. default is 0.5s.') | ||
Review comment: 0.5s ? what does |
||
parser.add_argument('--use-gpu', action='store_true', | ||
                    help='use gpu for training.') | ||
parser.add_argument('--clip_gr', type=float, default=10.0, | ||
                    help='Clip the gradient by projecting onto the box. default is 10.0.') | ||
# default set to 100 to match the help text and the README | ||
parser.add_argument('--z-dim', type=int, default=100, | ||
                    help='dimension of the latent z vector. default is 100.') | ||
opt = parser.parse_args() | ||
|
||
BATCH_SIZE = opt.batch_size | ||
Z_DIM = opt.z_dim | ||
NUM_EPOCHS = opt.epochs | ||
LEARNING_RATE = opt.lr | ||
BETA = opt.lr_beta | ||
OUTPUT_DIR = opt.data_path | ||
CTX = mx.gpu() if opt.use_gpu else mx.cpu() | ||
CLIP_GRADIENT = opt.clip_gr | ||
IMAGE_SIZE = 64 | ||
|
||
|
||
def facc(label, pred): | ||
    """ evaluate accuracy """ | ||
    pred = pred.ravel() | ||
    label = label.ravel() | ||
    return ((pred > 0.5) == label).mean() | ||
|
||
|
||
# setting | ||
mx.random.seed(random.randint(1, 10000)) | ||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
# create output dir | ||
try: | ||
    os.makedirs(opt.data_path) | ||
except OSError: | ||
    pass | ||
|
||
# get training data | ||
train_data = get_training_data(opt.batch_size) | ||
|
||
# get model | ||
g_net = get_generator() | ||
d_net = get_descriptor(CTX) | ||
|
||
# define loss function | ||
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss() | ||
|
||
# initialization | ||
g_net.collect_params().initialize(mx.init.Normal(0.02), ctx=CTX) | ||
d_net.collect_params().initialize(mx.init.Normal(0.02), ctx=CTX) | ||
g_trainer = gluon.Trainer( | ||
    g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT}) | ||
d_trainer = gluon.Trainer( | ||
    d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT}) | ||
g_net.collect_params().zero_grad() | ||
d_net.collect_params().zero_grad() | ||
# define evaluation metric | ||
metric = mx.metric.CustomMetric(facc) | ||
# initialize labels | ||
real_label = nd.ones(BATCH_SIZE, CTX) | ||
fake_label = nd.zeros(BATCH_SIZE, CTX) | ||
|
||
for epoch in range(NUM_EPOCHS): | ||
    for i, (d, _) in enumerate(train_data): | ||
        # update D | ||
        data = d.as_in_context(CTX) | ||
        noise = nd.normal(loc=0, scale=1, shape=( | ||
            BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX) | ||
        with autograd.record(): | ||
            # train with real image | ||
            output = d_net(data).reshape((-1, 1)) | ||
            errD_real = loss(output, real_label) | ||
            metric.update([real_label, ], [output, ]) | ||
|
||
            # train with fake image | ||
            fake_image = g_net(noise) | ||
            output = d_net(fake_image.detach()).reshape((-1, 1)) | ||
            errD_fake = loss(output, fake_label) | ||
            errD = errD_real + errD_fake | ||
            errD.backward() | ||
            metric.update([fake_label, ], [output, ]) | ||
|
||
        d_trainer.step(BATCH_SIZE) | ||
        # update G | ||
        with autograd.record(): | ||
            fake_image = g_net(noise) | ||
            output = d_net(fake_image).reshape((-1, 1)) | ||
            errG = loss(output, real_label) | ||
            errG.backward() | ||
|
||
        g_trainer.step(BATCH_SIZE) | ||
|
||
        # print log information every 100 batches | ||
        if i % 100 == 0: | ||
            name, acc = metric.get() | ||
            logging.info('discriminator loss = %f, generator loss = %f, ' | ||
                         'binary training acc = %f at iter %d epoch %d', | ||
                         nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch) | ||
        if i == 0: | ||
            save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR) | ||
|
||
    metric.reset() |
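The training loop passes the raw discriminator outputs to `gluon.loss.SigmoidBinaryCrossEntropyLoss`, which by default applies the sigmoid itself, so the network output is treated as a logit. As a rough illustration of what such a loss computes per sample, here is a plain-Python sketch of binary cross-entropy on a logit in a numerically stable form; the function name `sigmoid_bce` is hypothetical and this is not the library's source.

```python
import math


def sigmoid_bce(logit, label):
    """Binary cross-entropy on a raw logit, written to avoid overflow.

    Mathematically equal to -[label * log(s) + (1 - label) * log(1 - s)]
    with s = sigmoid(logit), but never evaluates exp() of a large positive
    number or log() of zero.
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


if __name__ == "__main__":
    # An undecided discriminator (logit 0) pays log(2) for either label.
    print(sigmoid_bce(0.0, 1.0))
    # A confident and wrong prediction is penalised roughly linearly.
    print(sigmoid_bce(1000.0, 0.0))
```

In the loop above the discriminator is trained with label 1 for real images and 0 for generated ones, while the generator step reuses the same loss with label 1 on generated images.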
Review comment: Please update https://github.com/apache/incubator-mxnet/tree/master/example#deep-learning-examples-in-the-mxnet-project-repository with your example, thanks!