add sagan #4

Merged: 2 commits, Oct 9, 2018
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
.idea/
__pycache__/
.DS_Store

51 changes: 46 additions & 5 deletions README.md
100644 → 100755
@@ -1,6 +1,47 @@
# Spectral Normalization and projection discriminator for Generative Adversarial Networks
# Self-Attention GAN
Tensorflow implementation for reproducing main results in the paper [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena.

* Implementation of these papers:
* Spectral Normalization. https://openreview.net/pdf?id=B1QRgziT-
* Projection Discriminator. https://openreview.net/pdf?id=ByS1VpgRZ
* Reference Chainer code: https://github.com/pfnet-research/sngan_projection
<img src="imgs/img1.png"/>


### Dependencies
Python 3.6

TensorFlow 1.5


**Data**

Download the ImageNet dataset and preprocess the images into tfrecord files as instructed in [improved gan](https://github.com/openai/improved-gan/blob/master/imagenet/convert_imagenet_to_records.py). Put the tfrecord files into ./data

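For reference, here is a minimal sketch of reading those tfrecord files with the `tf.data` API. The feature keys and file naming below are assumptions based on common ImageNet converters; check the improved-gan script for the actual schema it writes.

```python
import tensorflow as tf

def parse_example(serialized):
    # Assumed feature keys; verify against the converter's output schema.
    features = tf.parse_single_example(serialized, {
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
    })
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    image = tf.image.resize_images(image, [128, 128])
    image = (tf.cast(image, tf.float32) / 127.5) - 1.0  # scale to [-1, 1]
    return image, features['image/class/label']

# Assumed file pattern under ./data; adjust to the converter's naming.
files = tf.gfile.Glob('./data/*.tfrecord')
dataset = (tf.data.TFRecordDataset(files)
           .map(parse_example)
           .shuffle(5000)
           .batch(64))
images, labels = dataset.make_one_shot_iterator().get_next()
```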

**Training**

The current batch size is 64x4=256. A larger batch size seems to give better performance, but it may require retuning the learning rates for G and D. Note: training for one million steps usually takes several weeks.

CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet.py --generator_type test --discriminator_type test --data_dir ./data
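
The 64x4 arithmetic reflects 64 images per GPU across the four visible GPUs. Below is a minimal sketch of that multi-tower pattern; it is illustrative only, and the actual loss and variable handling in train_imagenet.py may differ.

```python
import tensorflow as tf

BATCH_PER_GPU = 64
NUM_GPUS = 4  # matches CUDA_VISIBLE_DEVICES=0,1,2,3 above

# One global batch, split into one shard per GPU tower.
images = tf.placeholder(tf.float32, [BATCH_PER_GPU * NUM_GPUS, 128, 128, 3])
shards = tf.split(images, NUM_GPUS, axis=0)

tower_losses = []
for gpu in range(NUM_GPUS):
    with tf.device('/gpu:%d' % gpu):
        # Reuse one set of variables across all towers.
        with tf.variable_scope('model', reuse=(gpu > 0)):
            loss = tf.reduce_mean(tf.square(shards[gpu]))  # stand-in loss
            tower_losses.append(loss)

total_loss = tf.add_n(tower_losses) / NUM_GPUS  # average over towers
```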

**Evaluation**

CUDA_VISIBLE_DEVICES=4 python eval_imagenet.py --generator_type test --data_dir ./data

### Citing Self-attention GAN
If you find Self-Attention GAN useful in your research, please consider citing:

```
@article{Han18,
author = {Han Zhang and
Ian J. Goodfellow and
Dimitris N. Metaxas and
Augustus Odena},
title = {Self-Attention Generative Adversarial Networks},
year = {2018},
journal = {arXiv:1805.08318},
}
```

**References**

- Spectral Normalization for Generative Adversarial Networks [Paper](https://arxiv.org/abs/1802.05957)
- cGANs with Projection Discriminator [Paper](https://arxiv.org/abs/1802.05637)
- Non-local Neural Networks [Paper](https://arxiv.org/abs/1711.07971)
118 changes: 116 additions & 2 deletions discriminator.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The discriminator of SNGAN."""

import tensorflow as tf
import ops
import non_local


def dsample(x):
@@ -92,7 +93,7 @@ def optimized_block(x, out_channels, name,
return x + x_0


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
def discriminator_old(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu, scope='Discriminator'):
"""Builds the discriminator graph.

@@ -126,4 +127,117 @@ def discriminator(image, labels, df_dim, number_classes, update_collection=None,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
return output


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.

Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
print('Discriminator Structure')
return output

def discriminator_test(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.

Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h1 = non_local.sn_non_local_block_sim(h1, update_collection, name='d_non_local') # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
print('Discriminator Test Structure')
return output
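
# For intuition, a self-contained sketch of the self-attention computation that
# non_local.sn_non_local_block_sim performs above. This is illustrative only:
# the real block also applies spectral normalization to its 1x1 convolutions,
# and its exact channel reductions may differ.
def _self_attention_sketch(x, name='attention_sketch'):
  """x: [batch, h, w, c] feature map; returns x plus globally attended features."""
  with tf.variable_scope(name):
    h, w, c = x.get_shape().as_list()[1:]
    f = tf.layers.conv2d(x, c // 8, 1, name='f')  # queries
    g = tf.layers.conv2d(x, c // 8, 1, name='g')  # keys
    v = tf.layers.conv2d(x, c, 1, name='v')       # values
    f = tf.reshape(f, [-1, h * w, c // 8])
    g = tf.reshape(g, [-1, h * w, c // 8])
    v = tf.reshape(v, [-1, h * w, c])
    # Every spatial position attends to every other position.
    attn = tf.nn.softmax(tf.matmul(f, g, transpose_b=True))
    o = tf.reshape(tf.matmul(attn, v), [-1, h, w, c])
    gamma = tf.get_variable('gamma', [], initializer=tf.zeros_initializer())
    return x + gamma * o  # gamma starts at 0, so training starts purely local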

def discriminator_test_64(image, labels, df_dim, number_classes, update_collection=None,
act=tf.nn.relu):
"""Builds the discriminator graph.

Args:
image: The current batch of images to classify as fake or real.
labels: The corresponding labels for the images.
df_dim: The df dimension.
number_classes: The number of classes in the labels.
update_collection: The update collections used in the
spectral_normed_weight.
act: The activation function used in the discriminator.
Returns:
A `Tensor` representing the logits of the discriminator.
"""
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
h0 = optimized_block(image, df_dim, 'd_optimized_block1',
update_collection, act=act) # 64 * 64
h0 = non_local.sn_non_local_block_sim(h0, update_collection, name='d_non_local') # 64 * 64
h1 = block(h0, df_dim * 2, 'd_block2',
update_collection, act=act) # 32 * 32
h2 = block(h1, df_dim * 4, 'd_block3',
update_collection, act=act) # 16 * 16
h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
h5_act = act(h5)
h6 = tf.reduce_sum(h5_act, [1, 2])
output = ops.snlinear(h6, 1, update_collection=update_collection,
name='d_sn_linear')
h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
update_collection=update_collection,
name='d_embedding')
output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
return output
79 changes: 38 additions & 41 deletions eval_imagenet.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generic train."""
from __future__ import absolute_import
from __future__ import division
@@ -33,7 +34,8 @@


flags.DEFINE_string(
'data_dir', '/home/zhanghan/Data',
# 'data_dir', '/gpu/hz138/Data/imagenet', #'/home/hz138/Data/imagenet',
'data_dir', '/bigdata1/hz138/Data/imagenet',
'Directory with Imagenet input data as sharded recordio files of pre-'
'processed images.')
flags.DEFINE_integer('z_dim', 128, 'The dimension of z')
@@ -48,18 +50,17 @@
'image samples. [sample]')
flags.DEFINE_string('eval_dir', 'checkpoint/eval', 'Directory name to save the '
'eval summaries. [eval]')
flags.DEFINE_integer('batch_size', 32, 'Batch size of samples to feed into '
flags.DEFINE_integer('batch_size', 64, 'Batch size of samples to feed into '
'Inception models for evaluation. [16]')
flags.DEFINE_integer('shuffle_buffer_size', 5000, 'Number of records to load '
'before shuffling and yielding for consumption. [5000]')
flags.DEFINE_integer('dcgan_generator_batch_size', 100, 'Size of batch to feed '
'into generator -- we may stack multiple of these later.')
flags.DEFINE_integer('eval_sample_size', 1024,
flags.DEFINE_integer('eval_sample_size', 50000,
'Number of samples to sample from '
'generator and real data. [1024]')
flags.DEFINE_boolean('is_train', False, 'Use DCGAN only for evaluation.')
# TODO(olganw) Find the best way to clean up these flags for eval and train.
# These values need to be the same as they are in the training job.

flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]')
flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]')
flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]')
@@ -72,15 +73,17 @@
'and Frechet Inception Distance. [300]')

flags.DEFINE_integer('num_classes', 1000, 'The number of classes in the dataset')
flags.DEFINE_string('generator_type', 'baseline', 'test or baseline')
flags.DEFINE_string('generator_type', 'test', 'test or baseline')

FLAGS = flags.FLAGS


def main(_):
model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
FLAGS.eval_dir = FLAGS.checkpoint_dir + '/eval'
checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
log_dir = os.path.join(FLAGS.eval_dir, model_dir)
print('log_dir', log_dir)
graph_def = None # pylint: disable=protected-access

# Batch size to feed batches of images through Inception and the generator
@@ -105,31 +108,29 @@ def main(_):
label_offset=-1,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)

# Uniform distribution
# TODO(goodfellow) Use true distribution of ImageNet classes
num_classes = FLAGS.num_classes
gen_class_logits = tf.zeros((local_batch_size, num_classes))
gen_class_ints = tf.multinomial(gen_class_logits, 1)
gen_sparse_class = tf.squeeze(gen_class_ints)


with tf.variable_scope('model'):

# Generate the first batch of generated images and extract activations;
# this bootstraps the while_loop with a pools and logits tensor.


test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes)
test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes,
is_training=False)



pools, logits = utils.run_custom_inception(
generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def)
pools, logits = utils.run_custom_inception(
generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def)

# Set up while_loop to compute activations of generated images from generator.
def while_cond(g_pools, g_logits, i): # pylint: disable=unused-argument
@@ -144,24 +145,23 @@ def while_body(g_pools, g_logits, i):

test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
# Uniform distribution
# TODO(goodfellow) Use true distribution of ImageNet classes
gen_class_logits = tf.zeros((local_batch_size, num_classes))
gen_class_ints = tf.multinomial(gen_class_logits, 1)
gen_sparse_class = tf.squeeze(gen_class_ints)

with tf.variable_scope('model'):
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes)
generator = generator_fn(
test_zs[0],
gen_sparse_class,
FLAGS.gf_dim,
FLAGS.num_classes,
is_training=False)

pools, logits = utils.run_custom_inception(
generator,
output_tensor=['pool_3:0', 'logits:0'],
graph_def=graph_def)
g_pools = tf.concat([g_pools, pools], 0)
g_logits = tf.concat([g_logits, logits], 0)
pools, logits = utils.run_custom_inception(
generator,
output_tensor=['pool_3:0', 'logits:0'],
graph_def=graph_def)
g_pools = tf.concat([g_pools, pools], 0)
g_logits = tf.concat([g_logits, logits], 0)

return (g_pools, g_logits, tf.add(i, 1))

@@ -183,25 +183,22 @@ def while_body(g_pools, g_logits, i):
new_generator_pools_list.set_shape([FLAGS.eval_sample_size, 2048])
new_generator_logits_list.set_shape([FLAGS.eval_sample_size, 1008])

# TODO(sbhupatiraju) Why is FID negative?

# Get a small batch of samples from the generator to display in TensorBoard
vis_batch_size = 16
eval_vis_zs = utils.make_z_normal(
1, vis_batch_size, FLAGS.z_dim)
# Uniform distribution
# TODO(goodfellow) Use true distribution of ImageNet classes

gen_class_logits_vis = tf.zeros((vis_batch_size, num_classes))
gen_class_ints_vis = tf.multinomial(gen_class_logits_vis, 1)
gen_sparse_class_vis = tf.squeeze(gen_class_ints_vis)

with tf.variable_scope('model'):
eval_vis_images = generator_fn(
eval_vis_zs[0],
gen_sparse_class_vis,
FLAGS.gf_dim,
FLAGS.num_classes
)
eval_vis_images = generator_fn(
eval_vis_zs[0],
gen_sparse_class_vis,
FLAGS.gf_dim,
FLAGS.num_classes,
is_training=False
)
eval_vis_images = tf.cast((eval_vis_images + 1.) * 127.5, tf.uint8)

with tf.variable_scope('eval_vis'):