This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
[MXNET-580] Add SN-GAN example #12419
Merged
Changes from 18 commits
27019bb update sn-gan example (stu1130)
3d1601e fix naming (stu1130)
8113551 add more comments (stu1130)
e31ca88 fix naming and refine comments (stu1130)
b7e44c6 Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (stu1130)
0d1b80f make power iteration as one hyperparameter (stu1130)
4db5909 deal with divided by zero problem (stu1130)
3b944d2 replace 0.00000001 with EPSILON (stu1130)
b4ca8a3 refactor the example (stu1130)
09e366a add README (stu1130)
4085edb address the feedback (stu1130)
a467c68 refine the composing (stu1130)
c3db826 Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (stu1130)
dc982c3 Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (stu1130)
cc77f47 fix the typo, delete the redundant piece of code and update the resul… (stu1130)
f90190e update folder name to align with others (stu1130)
bd7a8bc update image name (stu1130)
d7b54fc add the variable back (stu1130)
4757682 remove the redundant piece of code and fix typo (stu1130)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Spectral Normalization GAN | ||
|
||
This example implements [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) and trains it on the [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. | ||
|
||
## Usage | ||
|
||
Example runs and the results: | ||
|
||
```bash | ||
python train.py --use-gpu --data-path=data | ||
``` | ||
|
||
* Note that the program downloads the CIFAR10 dataset for you. | ||
|
||
`python train.py --help` gives the following arguments: | ||
|
||
```bash | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--data-path DATA_PATH | ||
path of data. | ||
--batch-size BATCH_SIZE | ||
training batch size. default is 64. | ||
--epochs EPOCHS number of training epochs. default is 100. | ||
--lr LR learning rate. default is 0.0001. | ||
--lr-beta LR_BETA learning rate for the beta in margin based loss. | ||
default is 0.5. | ||
--use-gpu use gpu for training. | ||
--clip_gr CLIP_GR Clip the gradient by projecting onto the box. default | ||
is 10.0. | ||
--z-dim Z_DIM dimension of the latent z vector. default is 100. | ||
``` | ||
|
||
## Result | ||
|
||
![SN-GAN](sn_gan_output.png) | ||
|
||
## Learned Spectral Normalization | ||
|
||
![alt text](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png) | ||
|
||
## Reference | ||
|
||
[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow) |
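
The spectral normalization that this example applies to the discriminator weights can be illustrated without MXNet. The block below is a minimal pure-Python sketch, not the code from this pull request: it estimates the largest singular value of a small matrix by power iteration and divides the matrix by it, which is the same idea as `SNConv2D._spectral_norm`. The helper names (`matvec`, `transpose`, `l2_normalize`, `spectral_norm`) and the example matrix are assumptions made for illustration.

```python
import math
import random


def matvec(mat, vec):
    """Return mat @ vec for a matrix stored as a list of rows."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def transpose(mat):
    """Return the transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*mat)]


def l2_normalize(vec, eps=1e-12):
    """Scale vec to unit L2 norm; eps guards against division by zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / (norm + eps) for x in vec]


def spectral_norm(mat, iterations=1, seed=0):
    """Estimate the largest singular value of mat by power iteration.

    iterations must be at least 1.
    """
    rng = random.Random(seed)
    # Random start vector with one entry per row of mat.
    u = l2_normalize([rng.gauss(0.0, 1.0) for _ in mat])
    mat_t = transpose(mat)
    for _ in range(iterations):
        v = l2_normalize(matvec(mat_t, u))  # v is proportional to W^T u
        u = l2_normalize(matvec(mat, v))    # u is proportional to W v
    # sigma = u^T W v
    return sum(ui * wi for ui, wi in zip(u, matvec(mat, v)))


# Illustrative matrix whose singular values are 3 and 1.
weights = [[3.0, 0.0], [0.0, 1.0]]
sigma = spectral_norm(weights, iterations=50)
normalized = [[w / sigma for w in row] for row in weights]
```

After normalization the largest singular value of `normalized` is approximately 1, which is what bounds the Lipschitz constant of each discriminator layer. In training, a single iteration per forward pass is usually enough because the vector `u` is carried over between steps.
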
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, | ||
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb | ||
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py | ||
|
||
import numpy as np | ||
|
||
import mxnet as mx | ||
from mxnet import gluon | ||
from mxnet.gluon.data.vision import CIFAR10 | ||
|
||
IMAGE_SIZE = 64 | ||
|
||
def transformer(data, label): | ||
""" data preparation """ | ||
data = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE) | ||
data = mx.nd.transpose(data, (2, 0, 1)) | ||
data = data.astype(np.float32) / 128.0 - 1 | ||
return data, label | ||
|
||
|
||
def get_training_data(batch_size): | ||
""" helper function to get dataloader""" | ||
return gluon.data.DataLoader( | ||
CIFAR10(train=True, transform=transformer), | ||
batch_size=batch_size, shuffle=True, last_batch='discard') |
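
The scaling in `transformer` maps 8-bit pixel values into roughly the same range as the generator's `tanh` output. The following sketch shows that arithmetic in plain Python. The function names are hypothetical, and `to_pixel_range` is an assumed inverse added only for illustration; it is not part of this pull request.

```python
def to_model_range(pixel):
    """Map a pixel value in [0, 255] to [-1.0, 0.9921875], as transformer does."""
    return pixel / 128.0 - 1.0


def to_pixel_range(value):
    """Hypothetical inverse: map a value in [-1, 1] back to an integer in [0, 255]."""
    pixel = int(round((value + 1.0) * 128.0))
    return max(0, min(255, pixel))  # clamp, since tanh can reach 1.0 -> 256


print(to_model_range(0))    # -1.0
print(to_model_range(255))  # 0.9921875
```

Dividing by 128 rather than 127.5 means the upper end of the range stops just short of 1, which is harmless here because `tanh` only approaches 1 asymptotically.
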
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, | ||
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb | ||
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py | ||
|
||
import mxnet as mx | ||
from mxnet import nd | ||
from mxnet import gluon | ||
from mxnet.gluon import Block | ||
|
||
|
||
EPSILON = 1e-08 | ||
POWER_ITERATION = 1 | ||
|
||
class SNConv2D(Block): | ||
""" Custom Conv2D block that convolves with a spectrally normalized weight """ | ||
|
||
def __init__(self, num_filter, kernel_size, | ||
strides, padding, in_channels, | ||
ctx=mx.cpu(), iterations=1): | ||
|
||
super(SNConv2D, self).__init__() | ||
|
||
self.num_filter = num_filter | ||
self.kernel_size = kernel_size | ||
self.strides = strides | ||
self.padding = padding | ||
self.in_channels = in_channels | ||
self.iterations = iterations | ||
self.ctx = ctx | ||
|
||
with self.name_scope(): | ||
# init the weight | ||
self.weight = self.params.get('weight', shape=( | ||
num_filter, in_channels, kernel_size, kernel_size)) | ||
self.u = self.params.get( | ||
'u', init=mx.init.Normal(), shape=(1, num_filter)) | ||
|
||
def _spectral_norm(self): | ||
""" spectral normalization """ | ||
w = self.params.get('weight').data(self.ctx) | ||
# w preserves the original weight value, which is used in line 75 | ||
w_mat = w | ||
Review comment: w is needed to be used for calculation later in the line 75
Review comment: assignment does not create a copy, reshape creates a copy. you can use: |
||
w_mat = nd.reshape(w_mat, [w_mat.shape[0], -1]) | ||
|
||
_u = self.u.data(self.ctx) | ||
_v = None | ||
|
||
for _ in range(self.iterations): | ||
_v = nd.L2Normalization(nd.dot(_u, w_mat)) | ||
_u = nd.L2Normalization(nd.dot(_v, w_mat.T)) | ||
|
||
sigma = nd.sum(nd.dot(_u, w_mat) * _v) | ||
if sigma == 0.: | ||
sigma = EPSILON | ||
|
||
self.params.setattr('u', _u) | ||
|
||
return w / sigma | ||
|
||
def forward(self, x): | ||
# x shape is batch_size x in_channels x height x width | ||
return nd.Convolution( | ||
data=x, | ||
weight=self._spectral_norm(), | ||
kernel=(self.kernel_size, self.kernel_size), | ||
pad=(self.padding, self.padding), | ||
stride=(self.strides, self.strides), | ||
num_filter=self.num_filter, | ||
no_bias=True | ||
) | ||
|
||
|
||
def get_generator(): | ||
""" construct and return generator """ | ||
g_net = gluon.nn.Sequential() | ||
with g_net.name_scope(): | ||
|
||
g_net.add(gluon.nn.Conv2DTranspose( | ||
channels=512, kernel_size=4, strides=1, padding=0, use_bias=False)) | ||
g_net.add(gluon.nn.BatchNorm()) | ||
g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
g_net.add(gluon.nn.Conv2DTranspose( | ||
channels=256, kernel_size=4, strides=2, padding=1, use_bias=False)) | ||
g_net.add(gluon.nn.BatchNorm()) | ||
g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
g_net.add(gluon.nn.Conv2DTranspose( | ||
channels=128, kernel_size=4, strides=2, padding=1, use_bias=False)) | ||
g_net.add(gluon.nn.BatchNorm()) | ||
g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
g_net.add(gluon.nn.Conv2DTranspose( | ||
channels=64, kernel_size=4, strides=2, padding=1, use_bias=False)) | ||
g_net.add(gluon.nn.BatchNorm()) | ||
g_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
g_net.add(gluon.nn.Conv2DTranspose(channels=3, kernel_size=4, strides=2, padding=1, use_bias=False)) | ||
g_net.add(gluon.nn.Activation('tanh')) | ||
|
||
return g_net | ||
|
||
|
||
def get_descriptor(ctx): | ||
""" construct and return the discriminator (named descriptor in this example) """ | ||
d_net = gluon.nn.Sequential() | ||
with d_net.name_scope(): | ||
|
||
d_net.add(SNConv2D(num_filter=64, kernel_size=4, strides=2, padding=1, in_channels=3, ctx=ctx)) | ||
d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
d_net.add(SNConv2D(num_filter=128, kernel_size=4, strides=2, padding=1, in_channels=64, ctx=ctx)) | ||
d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
d_net.add(SNConv2D(num_filter=256, kernel_size=4, strides=2, padding=1, in_channels=128, ctx=ctx)) | ||
d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
d_net.add(SNConv2D(num_filter=512, kernel_size=4, strides=2, padding=1, in_channels=256, ctx=ctx)) | ||
d_net.add(gluon.nn.LeakyReLU(0.2)) | ||
|
||
d_net.add(SNConv2D(num_filter=1, kernel_size=4, strides=1, padding=0, in_channels=512, ctx=ctx)) | ||
|
||
return d_net |
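
The layer settings in `get_generator` and `get_descriptor` determine the spatial size at each stage. The sketch below traces those sizes with the standard output-size formulas for convolution and transposed convolution, assuming no dilation and no output padding. The helper names and the two layer lists are written out here for illustration; the `(kernel, stride, pad)` values are copied from the code above.

```python
def conv_out(size, kernel, stride, pad):
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * pad - kernel) // stride + 1


def deconv_out(size, kernel, stride, pad):
    """Spatial output size of a transposed convolution (no output padding)."""
    return (size - 1) * stride - 2 * pad + kernel


# (kernel, stride, pad) per layer, copied from get_generator / get_descriptor.
GENERATOR_LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
DISCRIMINATOR_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace(size, layers, out_fn):
    """Return the list of spatial sizes before and after every layer."""
    sizes = [size]
    for kernel, stride, pad in layers:
        size = out_fn(size, kernel, stride, pad)
        sizes.append(size)
    return sizes


gen_sizes = trace(1, GENERATOR_LAYERS, deconv_out)
disc_sizes = trace(64, DISCRIMINATOR_LAYERS, conv_out)
print(gen_sizes)   # [1, 4, 8, 16, 32, 64]
print(disc_sizes)  # [64, 32, 16, 8, 4, 1]
```

The generator therefore turns a `Z_DIM x 1 x 1` noise tensor into a `3 x 64 x 64` image, and the discriminator reduces a `64 x 64` input to a single value per image, which matches the `IMAGE_SIZE = 64` resize in the data pipeline.
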
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, | ||
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb | ||
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py | ||
|
||
|
||
import os | ||
import random | ||
import logging | ||
import argparse | ||
|
||
from data import get_training_data | ||
from model import get_generator, get_descriptor | ||
from utils import save_image | ||
|
||
import mxnet as mx | ||
from mxnet import nd, autograd | ||
from mxnet import gluon | ||
|
||
# CLI | ||
parser = argparse.ArgumentParser( | ||
description='train a model for Spectral Normalization GAN.') | ||
parser.add_argument('--data-path', type=str, default='./data', | ||
help='path of data.') | ||
parser.add_argument('--batch-size', type=int, default=64, | ||
help='training batch size. default is 64.') | ||
parser.add_argument('--epochs', type=int, default=100, | ||
help='number of training epochs. default is 100.') | ||
parser.add_argument('--lr', type=float, default=0.0001, | ||
help='learning rate. default is 0.0001.') | ||
parser.add_argument('--lr-beta', type=float, default=0.5, | ||
help='learning rate for the beta in margin based loss. default is 0.5.') | ||
parser.add_argument('--use-gpu', action='store_true', | ||
help='use gpu for training.') | ||
parser.add_argument('--clip_gr', type=float, default=10.0, | ||
help='Clip the gradient by projecting onto the box. default is 10.0.') | ||
parser.add_argument('--z-dim', type=int, default=100, | ||
help='dimension of the latent z vector. default is 100.') | ||
opt = parser.parse_args() | ||
|
||
BATCH_SIZE = opt.batch_size | ||
Z_DIM = opt.z_dim | ||
NUM_EPOCHS = opt.epochs | ||
LEARNING_RATE = opt.lr | ||
BETA = opt.lr_beta | ||
OUTPUT_DIR = opt.data_path | ||
CTX = mx.gpu() if opt.use_gpu else mx.cpu() | ||
CLIP_GRADIENT = opt.clip_gr | ||
IMAGE_SIZE = 64 | ||
|
||
|
||
def facc(label, pred): | ||
""" evaluate accuracy """ | ||
pred = pred.ravel() | ||
label = label.ravel() | ||
return ((pred > 0.5) == label).mean() | ||
|
||
|
||
# setting | ||
mx.random.seed(random.randint(1, 10000)) | ||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
# create output dir | ||
try: | ||
os.makedirs(opt.data_path) | ||
except OSError: | ||
pass | ||
|
||
# get training data | ||
train_data = get_training_data(opt.batch_size) | ||
|
||
# get model | ||
g_net = get_generator() | ||
d_net = get_descriptor(CTX) | ||
|
||
# define loss function | ||
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss() | ||
|
||
# initialization | ||
g_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX) | ||
d_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX) | ||
g_trainer = gluon.Trainer( | ||
g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT}) | ||
d_trainer = gluon.Trainer( | ||
d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT}) | ||
g_net.collect_params().zero_grad() | ||
d_net.collect_params().zero_grad() | ||
# define evaluation metric | ||
metric = mx.metric.CustomMetric(facc) | ||
# initialize labels | ||
real_label = nd.ones(BATCH_SIZE, CTX) | ||
fake_label = nd.zeros(BATCH_SIZE, CTX) | ||
|
||
for epoch in range(NUM_EPOCHS): | ||
for i, (d, _) in enumerate(train_data): | ||
# update D | ||
data = d.as_in_context(CTX) | ||
noise = nd.normal(loc=0, scale=1, shape=( | ||
BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX) | ||
with autograd.record(): | ||
# train with real image | ||
output = d_net(data).reshape((-1, 1)) | ||
errD_real = loss(output, real_label) | ||
metric.update([real_label, ], [output, ]) | ||
|
||
# train with fake image | ||
fake_image = g_net(noise) | ||
output = d_net(fake_image.detach()).reshape((-1, 1)) | ||
errD_fake = loss(output, fake_label) | ||
errD = errD_real + errD_fake | ||
errD.backward() | ||
metric.update([fake_label, ], [output, ]) | ||
|
||
d_trainer.step(BATCH_SIZE) | ||
# update G | ||
with autograd.record(): | ||
fake_image = g_net(noise) | ||
output = d_net(fake_image).reshape(-1, 1) | ||
errG = loss(output, real_label) | ||
errG.backward() | ||
|
||
g_trainer.step(BATCH_SIZE) | ||
|
||
# print log information every 100 batches | ||
if i % 100 == 0: | ||
name, acc = metric.get() | ||
logging.info('discriminator loss = %f, generator loss = %f, ' | ||
'binary training acc = %f at iter %d epoch %d', | ||
nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch) | ||
if i == 0: | ||
save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR) | ||
|
||
metric.reset() |
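
The training loop above feeds raw discriminator outputs to `SigmoidBinaryCrossEntropyLoss`, which by default applies the sigmoid internally. The sketch below reproduces that loss for scalar logits in plain Python, using a numerically stable form, and combines it the way the loop does for the discriminator and generator updates. The function names are hypothetical and the code is an illustration of the arithmetic, not the MXNet implementation.

```python
import math


def sigmoid(x):
    """Logistic function, split by sign so exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sigmoid_bce(logit, label):
    """Binary cross-entropy on a raw logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(real_logits, fake_logits):
    """Mean of (loss on real with label 1) + (loss on fake with label 0)."""
    per_sample = [sigmoid_bce(r, 1.0) + sigmoid_bce(f, 0.0)
                  for r, f in zip(real_logits, fake_logits)]
    return sum(per_sample) / len(per_sample)


def generator_loss(fake_logits):
    """Mean loss on fake samples scored against the real label 1."""
    return sum(sigmoid_bce(f, 1.0) for f in fake_logits) / len(fake_logits)


print(discriminator_loss([2.0, 1.0], [-2.0, -1.0]))
print(generator_loss([-2.0, -1.0]))
```

Because the outputs are logits, a probability of 0.5 corresponds to a logit of 0, so `sigmoid(x) > 0.5` holds exactly when `x > 0`. That is worth keeping in mind when reading the accuracy reported by `facc`, which compares the raw outputs against 0.5.
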
nit: the `s` is also here still