Skip to content

Commit

Permalink
Merge pull request #4 from hanzhanggit/dev
Browse files Browse the repository at this point in the history
add sagan
  • Loading branch information
DoctorTeeth authored Oct 9, 2018
2 parents a3a72cf + 7a8590e commit ad9612e
Show file tree
Hide file tree
Showing 19 changed files with 528 additions and 1,188 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.idea/
__pycache__/
.DS_Store

51 changes: 46 additions & 5 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,47 @@
# Spectral Normalization and projection discriminator for Generative Adversarial Networks
# Self-Attention GAN
Tensorflow implementation for reproducing main results in the paper [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena.

* Implementation of these papers:
* Spectral Normalization. https://openreview.net/pdf?id=B1QRgziT-
* Projection Discriminator. https://openreview.net/pdf?id=ByS1VpgRZ
* Reference Chainer code: https://github.com/pfnet-research/sngan_projection
<img src="imgs/img1.png"/>


### Dependencies
python 3.6

TensorFlow 1.5


**Data**

Download Imagenet dataset and preprocess the images into tfrecord files as instructed in [improved gan](https://github.com/openai/improved-gan/blob/master/imagenet/convert_imagenet_to_records.py). Put the tfrecord files into ./data


**Training**

The current batch size is 64x4=256. Larger batch size seems to give better performance. But it might need to find new hyperparameters for G&D learning rate. Note: It usually takes several weeks to train one million steps.

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet.py --generator_type test --discriminator_type test --data_dir ./data

**Evaluation**

CUDA_VISIBLE_DEVICES=4 python eval_imagenet.py --generator_type test --data_dir ./data

### Citing Self-attention GAN
If you find Self-attention GAN is useful in your research, please consider citing:

```
@article{Han18,
author = {Han Zhang and
Ian J. Goodfellow and
Dimitris N. Metaxas and
Augustus Odena},
title = {Self-Attention Generative Adversarial Networks},
year = {2018},
journal = {arXiv:1805.08318},
}
```

**References**

- Spectral Normalization for Generative Adversarial Networks [Paper](https://arxiv.org/abs/1802.05957)
- cGANs with Projection Discriminator [Paper](https://arxiv.org/abs/1802.05637)
- Non-local Neural Networks [Paper](https://arxiv.org/abs/1711.07971)
118 changes: 116 additions & 2 deletions discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The discriminator of SNGAN."""

import tensorflow as tf
import ops
import non_local


def dsample(x):
Expand Down Expand Up @@ -92,7 +93,7 @@ def optimized_block(x, out_channels, name,
return x + x_0


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
def discriminator_old(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu, scope='Discriminator'):
"""Builds the discriminator graph.
Expand Down Expand Up @@ -126,4 +127,117 @@ def discriminator(image, labels, df_dim, number_classes, update_collection=None,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
return output


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.
Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
scope: Optional scope for `variable_op_scope`.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
print('Discriminator Structure')
return output

def discriminator_test(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.
Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
scope: Optional scope for `variable_op_scope`.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h1 = non_local.sn_non_local_block_sim(h1, update_collection, name='d_non_local') # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
print('Discriminator Test Structure')
return output

def discriminator_test_64(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.
Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
scope: Optional scope for `variable_op_scope`.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h0 = non_local.sn_non_local_block_sim(h0, update_collection, name='d_non_local') # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
return output
79 changes: 38 additions & 41 deletions eval_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generic train."""
from __future__ import absolute_import
from __future__ import division
Expand All @@ -33,7 +34,8 @@


flags.DEFINE_string(
'data_dir', '/home/zhanghan/Data',
# 'data_dir', '/gpu/hz138/Data/imagenet', #'/home/hz138/Data/imagenet',
'data_dir', '/bigdata1/hz138/Data/imagenet',
'Directory with Imagenet input data as sharded recordio files of pre-'
'processed images.')
flags.DEFINE_integer('z_dim', 128, 'The dimension of z')
Expand All @@ -48,18 +50,17 @@
'image samples. [sample]')
flags.DEFINE_string('eval_dir', 'checkpoint/eval', 'Directory name to save the '
'eval summaries . [eval]')
flags.DEFINE_integer('batch_size', 32, 'Batch size of samples to feed into '
flags.DEFINE_integer('batch_size', 64, 'Batch size of samples to feed into '
'Inception models for evaluation. [16]')
flags.DEFINE_integer('shuffle_buffer_size', 5000, 'Number of records to load '
'before shuffling and yielding for consumption. [5000]')
flags.DEFINE_integer('dcgan_generator_batch_size', 100, 'Size of batch to feed '
'into generator -- we may stack multiple of these later.')
flags.DEFINE_integer('eval_sample_size', 1024,
flags.DEFINE_integer('eval_sample_size', 50000,
'Number of samples to sample from '
'generator and real data. [1024]')
flags.DEFINE_boolean('is_train', False, 'Use DCGAN only for evaluation.')
# TODO(olganw) Find the best way to clean up these flags for eval and train.
# These values need to be the same as they are in the training job.

flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]')
flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]')
flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]')
Expand All @@ -72,15 +73,17 @@
'and Frechet Inception Distance. [300]')

flags.DEFINE_integer('num_classes', 1000, 'The number of classes in the dataset')
flags.DEFINE_string('generator_type', 'baseline', 'test or baseline')
flags.DEFINE_string('generator_type', 'test', 'test or baseline')

FLAGS = flags.FLAGS


def main(_):
model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
FLAGS.eval_dir = FLAGS.checkpoint_dir + '/eval'
checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
log_dir = os.path.join(FLAGS.eval_dir, model_dir)
print('log_dir', log_dir)
graph_def = None # pylint: disable=protected-access

# Batch size to feed batches of images through Inception and the generator
Expand All @@ -105,31 +108,29 @@ def main(_):
label_offset=-1,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)

# Uniform distribution
# TODO(goodfellow) Use true distribution of ImageNet classses
num_classes = FLAGS.num_classes
gen_class_logits = tf.zeros((local_batch_size, num_classes))
gen_class_ints = tf.multinomial(gen_class_logits, 1)
gen_sparse_class = tf.squeeze(gen_class_ints)


with tf.variable_scope('model'):

# Generate the first batch of generated images and extract activations;
# this bootstraps the while_loop with a pools and logits tensor.


test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes)
test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes,
is_training=False)



pools, logits = utils.run_custom_inception(
generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def)
pools, logits = utils.run_custom_inception(
generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def)

# Set up while_loop to compute activations of generated images from generator.
def while_cond(g_pools, g_logits, i): # pylint: disable=unused-argument
Expand All @@ -144,24 +145,23 @@ def while_body(g_pools, g_logits, i):

test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
# Uniform distribution
# TODO(goodfellow) Use true distribution of ImageNet classses
gen_class_logits = tf.zeros((local_batch_size, num_classes))
gen_class_ints = tf.multinomial(gen_class_logits, 1)
gen_sparse_class = tf.squeeze(gen_class_ints)

with tf.variable_scope('model'):
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes)
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes,
is_training=False)

pools, logits = utils.run_custom_inception(
generator,
output_tensor=['pool_3:0', 'logits:0'],
graph_def=graph_def)
g_pools = tf.concat([g_pools, pools], 0)
g_logits = tf.concat([g_logits, logits], 0)
pools, logits = utils.run_custom_inception(
generator,
output_tensor=['pool_3:0', 'logits:0'],
graph_def=graph_def)
g_pools = tf.concat([g_pools, pools], 0)
g_logits = tf.concat([g_logits, logits], 0)

return (g_pools, g_logits, tf.add(i, 1))

Expand All @@ -183,25 +183,22 @@ def while_body(g_pools, g_logits, i):
new_generator_pools_list.set_shape([FLAGS.eval_sample_size, 2048])
new_generator_logits_list.set_shape([FLAGS.eval_sample_size, 1008])

# TODO(sbhupatiraju) Why is FID negative?

# Get a small batch of samples from generator to dispaly in TensorBoard
vis_batch_size = 16
eval_vis_zs = utils.make_z_normal(
1, vis_batch_size, FLAGS.z_dim)
# Uniform distribution
# TODO(goodfellow) Use true distribution of ImageNet classses

gen_class_logits_vis = tf.zeros((vis_batch_size, num_classes))
gen_class_ints_vis = tf.multinomial(gen_class_logits_vis, 1)
gen_sparse_class_vis = tf.squeeze(gen_class_ints_vis)

with tf.variable_scope('model'):
eval_vis_images = generator_fn(
eval_vis_zs[0],
gen_sparse_class_vis,
FLAGS.gf_dim,
FLAGS.num_classes
)
eval_vis_images = generator_fn(
eval_vis_zs[0],
gen_sparse_class_vis,
FLAGS.gf_dim,
FLAGS.num_classes,
is_training=False
)
eval_vis_images = tf.cast((eval_vis_images + 1.) * 127.5, tf.uint8)

with tf.variable_scope('eval_vis'):
Expand Down
Loading

0 comments on commit ad9612e

Please sign in to comment.