[MXNET-580] Add SN-GAN example #12419

Merged
merged 19 commits into from
Sep 12, 2018
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions example/gluon/sn-gan/README.md
@@ -0,0 +1,38 @@
# Spectral Normalization GAN
This example implements [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) on the [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
> **Contributor:** Please refer to arxiv - https://arxiv.org/abs/1802.05957
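
For reference, spectral normalization rescales each weight matrix `W` by an estimate of its largest singular value, obtained with a step of power iteration (a sketch of the update from the paper):

```latex
\tilde{v} \leftarrow \frac{W^{\top}\tilde{u}}{\lVert W^{\top}\tilde{u} \rVert_{2}}, \qquad
\tilde{u} \leftarrow \frac{W\tilde{v}}{\lVert W\tilde{v} \rVert_{2}}, \qquad
\sigma(W) \approx \tilde{u}^{\top} W \tilde{v}, \qquad
\bar{W}_{\mathrm{SN}} = \frac{W}{\sigma(W)}
```

`SNConv2D` in `model.py` applies this update to its (flattened) convolution weight on every forward pass.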


## Usage

Example training run:

```bash
python train.py --use-gpu --data-path=data/CIFAR10
```

> **Contributor:** Write a note that user needs to download CIFAR10 dataset

Note: no manual download is needed - `mxnet.gluon.data.vision.CIFAR10` fetches the dataset automatically on first use.

`python train.py --help` gives the following arguments:

```bash
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
path of data.
--batch-size BATCH_SIZE
training batch size. default is 64.
--epochs EPOCHS number of training epochs. default is 100.
--lr LR learning rate. default is 0.0001.
--lr-beta LR_BETA     beta1 value for the Adam optimizer. default is 0.5.
--use-gpu use gpu for training.
--clip_gr CLIP_GR Clip the gradient by projecting onto the box. default
is 10.0.
--z-dim Z_DIM dimension of the latent z vector. default is 100.
```
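
For instance, a quick CPU smoke test with a smaller batch and fewer epochs could look like this (illustrative values; flags as listed above):

```bash
python train.py --batch-size 32 --epochs 5 --lr 0.0002
```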

## Learned Spectral Normalization

![learned spectral norm](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png)

> **Contributor:** Could you add a few samples of the generated images? It always makes it more appealing for people looking to try new models. Thanks!

> **Contributor (author):** I am running the model with Xavier initializer and will update the image if it's better

## Reference

[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow)
42 changes: 42 additions & 0 deletions example/gluon/sn-gan/data.py
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py

import numpy as np

import mxnet as mx
from mxnet import gluon
from mxnet.gluon.data.vision import CIFAR10

IMAGE_SIZE = 64

def transformer(data, label):
    """ resize to IMAGE_SIZE, transpose HWC -> CHW, and scale pixels to [-1, 1) """
    data = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE)
    data = mx.nd.transpose(data, (2, 0, 1))
    data = data.astype(np.float32) / 128.0 - 1
    return data, label


def get_training_data(batch_size):
""" hepler function to get dataloader"""
> **Contributor:** helper
return gluon.data.DataLoader(
CIFAR10(train=True, transform=transformer),
batch_size=batch_size, shuffle=True, last_batch='discard')
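
As a quick sanity check of the pipeline (a minimal sketch, assuming this module is importable as `data`), each batch should come out as `(batch_size, 3, 64, 64)` with pixel values roughly in `[-1, 1)`:

```python
# Sanity-check the CIFAR10 loader (illustrative; not part of the example).
from data import get_training_data

loader = get_training_data(batch_size=8)
for images, labels in loader:
    print(images.shape, labels.shape)  # expect (8, 3, 64, 64) and (8,)
    break
```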
135 changes: 135 additions & 0 deletions example/gluon/sn-gan/model.py
@@ -0,0 +1,135 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import Block


EPSILON = 1e-08
POWER_ITERATION = 1

class SNConv2D(Block):
""" Customized Conv2D to feed the conv with the weight that we apply spectral normalization """

    def __init__(self, num_filter, kernel_size,
                 strides, padding, in_channels,
                 ctx=mx.cpu(), iterations=1):
> **Contributor:** set default ctx=mx.cpu()


super(SNConv2D, self).__init__()

self.num_filter = num_filter
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.in_channels = in_channels
self.iterations = iterations
self.ctx = ctx

with self.name_scope():
# init the weight
self.weight = self.params.get('weight', shape=(
num_filter, in_channels, kernel_size, kernel_size))
self.u = self.params.get(
'u', init=mx.init.Normal(), shape=(1, num_filter))

    def _spectral_norm(self):
> **Contributor:** I would suggest using _spectral_norm(self) as it is a private function

""" spectral normalization """
w = self.params.get('weight').data(self.ctx)
w_mat = w
> **Contributor:** I don't think this is necessary, you can simply use w in your nd.reshape(w_mat..

        w_mat = nd.reshape(w, [w.shape[0], -1])

_u = self.u.data(self.ctx)
_v = None

        # power iteration: alternately refine the left/right singular-vector estimates
        for _ in range(self.iterations):
            _v = nd.L2Normalization(nd.dot(_u, w_mat))
            _u = nd.L2Normalization(nd.dot(_v, w_mat.T))

sigma = nd.sum(nd.dot(_u, w_mat) * _v)
if sigma == 0.:
sigma = EPSILON

self.params.setattr('u', _u)

return w / sigma

def forward(self, x):
# x shape is batch_size x in_channels x height x width
return nd.Convolution(
data=x,
            weight=self._spectral_norm(),
kernel=(self.kernel_size, self.kernel_size),
pad=(self.padding, self.padding),
stride=(self.strides, self.strides),
num_filter=self.num_filter,
no_bias=True
)


def get_generator():
""" construct and return generator """
g_net = gluon.nn.Sequential()
with g_net.name_scope():

        g_net.add(gluon.nn.Conv2DTranspose(channels=512, kernel_size=4, strides=1, padding=0, use_bias=False))
> **Contributor:** It is always more clear for readers in these kind of examples to have named parameters for these layers.
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

        g_net.add(gluon.nn.Conv2DTranspose(channels=256, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

        g_net.add(gluon.nn.Conv2DTranspose(channels=128, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

        g_net.add(gluon.nn.Conv2DTranspose(channels=64, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

        g_net.add(gluon.nn.Conv2DTranspose(channels=3, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.Activation('tanh'))

return g_net


def get_descriptor(ctx):
""" construct and return descriptor """
d_net = gluon.nn.Sequential()
with d_net.name_scope():

d_net.add(SNConv2D(64, 4, 2, 1, 3, ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(128, 4, 2, 1, 64, ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(256, 4, 2, 1, 128, ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(512, 4, 2, 1, 256, ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(1, 4, 1, 0, 512, ctx))

return d_net
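
The two networks are shape-compatible end to end: the generator upsamples a `(N, 100, 1, 1)` noise vector through five transposed convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64) to a `(N, 3, 64, 64)` image, and the discriminator downsamples it back to a single score. A minimal sketch to verify this (assuming this module is importable as `model`):

```python
# Verify generator/discriminator output shapes (illustrative sketch).
import mxnet as mx
from mxnet import nd
from model import get_generator, get_descriptor

ctx = mx.cpu()
g_net = get_generator()
d_net = get_descriptor(ctx)
g_net.collect_params().initialize(mx.init.Normal(0.02), ctx=ctx)
d_net.collect_params().initialize(mx.init.Normal(0.02), ctx=ctx)

noise = nd.random.normal(shape=(4, 100, 1, 1), ctx=ctx)
fake = g_net(noise)
print(fake.shape)         # (4, 3, 64, 64)
print(d_net(fake).shape)  # (4, 1, 1, 1)
```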
149 changes: 149 additions & 0 deletions example/gluon/sn-gan/train.py
@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py


import os
import random
import logging
import argparse

from data import get_training_data
from model import get_generator, get_descriptor
from utils import save_image

import mxnet as mx
from mxnet import nd, autograd
from mxnet import gluon

# CLI
parser = argparse.ArgumentParser(
description='train a model for Spectral Normalization GAN.')
parser.add_argument('--data-path', type=str, default='./data',
help='path of data.')
parser.add_argument('--batch-size', type=int, default=64,
help='training batch size. default is 64.')
parser.add_argument('--epochs', type=int, default=100,
help='number of training epochs. default is 100.')
parser.add_argument('--lr', type=float, default=0.0001,
help='learning rate. default is 0.0001.')
parser.add_argument('--lr-beta', type=float, default=0.5,
                    help='beta1 value for the Adam optimizer. default is 0.5.')
> **Contributor:** 0.5s ? what does s stand for here?

parser.add_argument('--use-gpu', action='store_true',
help='use gpu for training.')
parser.add_argument('--clip_gr', type=float, default=10.0,
help='Clip the gradient by projecting onto the box. default is 10.0.')
parser.add_argument('--z-dim', type=int, default=100,
help='dimension of the latent z vector. default is 100.')
opt = parser.parse_args()

BATCH_SIZE = opt.batch_size
Z_DIM = opt.z_dim
NUM_EPOCHS = opt.epochs
LEARNING_RATE = opt.lr
BETA = opt.lr_beta
OUTPUT_DIR = opt.data_path
CTX = mx.gpu() if opt.use_gpu else mx.cpu()
CLIP_GRADIENT = opt.clip_gr
IMAGE_SIZE = 64


def facc(label, pred):
""" evaluate accuracy """
pred = pred.ravel()
label = label.ravel()
return ((pred > 0.5) == label).mean()
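
# Worked example (illustrative): scores [0.9, 0.2] threshold at 0.5 into
# predictions [1, 0]; against labels [1, 0], facc returns 1.0. Note that the
# discriminator emits raw (pre-sigmoid) scores here, so the 0.5 threshold is
# a rough proxy rather than an exact probability cut-off.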


# setting
mx.random.seed(random.randint(1, 10000))
logging.basicConfig(level=logging.DEBUG)

# create output dir
try:
os.makedirs(opt.data_path)
except OSError:
pass

# get training data
train_data = get_training_data(opt.batch_size)

# get model
g_net = get_generator()
d_net = get_descriptor(CTX)

# define loss function
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# initialization
g_net.collect_params().initialize(mx.init.Normal(0.02), ctx=CTX)
d_net.collect_params().initialize(mx.init.Normal(0.02), ctx=CTX)
g_trainer = gluon.Trainer(
g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
d_trainer = gluon.Trainer(
d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
g_net.collect_params().zero_grad()
d_net.collect_params().zero_grad()
# define evaluation metric
metric = mx.metric.CustomMetric(facc)
# initialize labels
real_label = nd.ones(BATCH_SIZE, CTX)
fake_label = nd.zeros(BATCH_SIZE, CTX)

for epoch in range(NUM_EPOCHS):
for i, (d, _) in enumerate(train_data):
# update D
data = d.as_in_context(CTX)
noise = nd.normal(loc=0, scale=1, shape=(
BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX)
with autograd.record():
# train with real image
output = d_net(data).reshape((-1, 1))
errD_real = loss(output, real_label)
metric.update([real_label, ], [output, ])

# train with fake image
fake_image = g_net(noise)
output = d_net(fake_image.detach()).reshape((-1, 1))
errD_fake = loss(output, fake_label)
errD = errD_real + errD_fake
errD.backward()
metric.update([fake_label, ], [output, ])

d_trainer.step(BATCH_SIZE)
# update G
with autograd.record():
fake_image = g_net(noise)
            output = d_net(fake_image).reshape((-1, 1))
errG = loss(output, real_label)
errG.backward()

g_trainer.step(BATCH_SIZE)

        # print log information every 100 batches
if i % 100 == 0:
name, acc = metric.get()
            logging.info('discriminator loss = %f, generator loss = %f, '
                         'binary training acc = %f at iter %d epoch %d',
                         nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch)
if i == 0:
save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR)

metric.reset()
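
After training, additional samples can be drawn from the generator alone; a minimal sketch reusing the names defined above (`save_image` is the helper from `utils.py` used in the loop):

```python
# Draw a fresh batch of samples from the trained generator (illustrative).
noise = nd.normal(loc=0, scale=1, shape=(BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX)
samples = g_net(noise)
save_image(samples, NUM_EPOCHS, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR)
```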