Skip to content

Commit

Permalink
add sagan
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Zhang authored and Han Zhang committed Oct 7, 2018
1 parent a3a72cf commit 7702dc5
Show file tree
Hide file tree
Showing 21 changed files with 560 additions and 1,355 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.idea/
__pycache__/
.DS_Store

51 changes: 46 additions & 5 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,47 @@
# Spectral Normalization and projection discriminator for Generative Adversarial Networks
# Self-Attention GAN
Tensorflow implementation for reproducing main results in the paper [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena.

* Implementation of these papers:
* Spectral Normalization. https://openreview.net/pdf?id=B1QRgziT-
* Projection Discriminator. https://openreview.net/pdf?id=ByS1VpgRZ
* Reference Chainer code: https://github.com/pfnet-research/sngan_projection
<img src="imgs/img1.png"/>


### Dependencies
python 3.6

TensorFlow 1.5


**Data**

Download Imagenet dataset and preprocess the images into tfrecord files as instructed in [improved gan](https://github.com/openai/improved-gan/blob/master/imagenet/convert_imagenet_to_records.py). Put the tfrecord files into ./data


**Training**

The current batch size is 64x4=256. Larger batch size seems to give better performance. But it might need to find new hyperparameters for G&D learning rate. Note: It usually takes several weeks to train one million steps.

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet.py --generator_type test --discriminator_type test --data_dir ./data

**Evaluation**

CUDA_VISIBLE_DEVICES=4 python eval_imagenet.py --generator_type test --data_dir ./data

### Citing Self-attention GAN
If you find Self-attention GAN is useful in your research, please consider citing:

```
@article{Han18,
author = {Han Zhang and
Ian J. Goodfellow and
Dimitris N. Metaxas and
Augustus Odena},
title = {Self-Attention Generative Adversarial Networks},
year = {2018},
journal = {arXiv:1805.08318},
}
```

**References**

- Spectral Normalization for Generative Adversarial Networks [Paper](https://arxiv.org/abs/1802.05957)
- cGANs with Projection Discriminator [Paper](https://arxiv.org/abs/1802.05637)
- Non-local Neural Networks [Paper](https://arxiv.org/abs/1711.07971)
15 changes: 0 additions & 15 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
import os
IMAGE_SIZE=128
Expand Down
131 changes: 115 additions & 16 deletions discriminator.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The discriminator of SNGAN."""
import tensorflow as tf
import ops
import non_local


def dsample(x):
Expand Down Expand Up @@ -92,7 +78,7 @@ def optimized_block(x, out_channels, name,
return x + x_0


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
def discriminator_old(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu, scope='Discriminator'):
"""Builds the discriminator graph.
Expand Down Expand Up @@ -126,4 +112,117 @@ def discriminator(image, labels, df_dim, number_classes, update_collection=None,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
return output


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.
Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
scope: Optional scope for `variable_op_scope`.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
print('Discriminator Structure')
return output

def discriminator_test(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.
Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
scope: Optional scope for `variable_op_scope`.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h1 = non_local.sn_non_local_block_sim(h1, update_collection, name='d_non_local') # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
print('Discriminator Test Structure')
return output

def discriminator_test_64(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.
Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
scope: Optional scope for `variable_op_scope`.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h0 = non_local.sn_non_local_block_sim(h0, update_collection, name='d_non_local') # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
return output
Loading

0 comments on commit 7702dc5

Please sign in to comment.