diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..47990a0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea/ +__pycache__/ +.DS_Store + diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 470b4f4..a69dee8 --- a/README.md +++ b/README.md @@ -1,6 +1,47 @@ -# Spectral Normalization and projection discriminator for Generative Adversarial Networks +# Self-Attention GAN +TensorFlow implementation for reproducing the main results in the paper [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena. -* Implementation of these papers: -* Spectral Normalization. https://openreview.net/pdf?id=B1QRgziT- -* Projection Discriminator. https://openreview.net/pdf?id=ByS1VpgRZ -* Reference Chainer code: https://github.com/pfnet-research/sngan_projection + + + +### Dependencies +python 3.6 + +TensorFlow 1.5 + + +**Data** + +Download the ImageNet dataset and preprocess the images into tfrecord files as instructed in [improved gan](https://github.com/openai/improved-gan/blob/master/imagenet/convert_imagenet_to_records.py). Put the tfrecord files into ./data + + +**Training** + +The current batch size is 64x4=256. A larger batch size seems to give better performance, but it may require re-tuning the G and D learning rates. Note: It usually takes several weeks to train one million steps. + +CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet.py --generator_type test --discriminator_type test --data_dir ./data + +**Evaluation** + +CUDA_VISIBLE_DEVICES=4 python eval_imagenet.py --generator_type test --data_dir ./data + +### Citing Self-Attention GAN +If you find Self-Attention GAN useful in your research, please consider citing: + +``` +@article{Han18, + author = {Han Zhang and + Ian J. Goodfellow and + Dimitris N. 
Metaxas and + Augustus Odena}, + title = {Self-Attention Generative Adversarial Networks}, + year = {2018}, + journal = {arXiv:1805.08318}, +} +``` + +**References** + +- Spectral Normalization for Generative Adversarial Networks [Paper](https://arxiv.org/abs/1802.05957) +- cGANs with Projection Discriminator [Paper](https://arxiv.org/abs/1802.05637) +- Non-local Neural Networks [Paper](https://arxiv.org/abs/1711.07971) diff --git a/discriminator.py b/discriminator.py index 58de422..55e218f 100644 --- a/discriminator.py +++ b/discriminator.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """The discriminator of SNGAN.""" + import tensorflow as tf import ops +import non_local def dsample(x): @@ -92,7 +93,7 @@ def optimized_block(x, out_channels, name, return x + x_0 -def discriminator(image, labels, df_dim, number_classes, update_collection=None, +def discriminator_old(image, labels, df_dim, number_classes, update_collection=None, act=tf.nn.relu, scope='Discriminator'): """Builds the discriminator graph. @@ -126,4 +127,117 @@ def discriminator(image, labels, df_dim, number_classes, update_collection=None, update_collection=update_collection, name='d_embedding') output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True) + return output + + +def discriminator(image, labels, df_dim, number_classes, update_collection=None, + act=tf.nn.relu): + """Builds the discriminator graph. + + Args: + image: The current batch of images to classify as fake or real. + labels: The corresponding labels for the images. + df_dim: The df dimension. + number_classes: The number of classes in the labels. + update_collection: The update collections used in the + spectral_normed_weight. + act: The activation function used in the discriminator. + scope: Optional scope for `variable_op_scope`. 
+ Returns: + A `Tensor` representing the logits of the discriminator. + """ + with tf.variable_scope('model', reuse=tf.AUTO_REUSE): + h0 = optimized_block(image, df_dim, 'd_optimized_block1', + update_collection, act=act) # 64 * 64 + h1 = block(h0, df_dim * 2, 'd_block2', + update_collection, act=act) # 32 * 32 + h2 = block(h1, df_dim * 4, 'd_block3', + update_collection, act=act) # 16 * 16 + h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8 + h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4 + h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act) + h5_act = act(h5) + h6 = tf.reduce_sum(h5_act, [1, 2]) + output = ops.snlinear(h6, 1, update_collection=update_collection, + name='d_sn_linear') + h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16, + update_collection=update_collection, + name='d_embedding') + output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True) + print('Discriminator Structure') + return output + +def discriminator_test(image, labels, df_dim, number_classes, update_collection=None, + act=tf.nn.relu): + """Builds the discriminator graph. + + Args: + image: The current batch of images to classify as fake or real. + labels: The corresponding labels for the images. + df_dim: The df dimension. + number_classes: The number of classes in the labels. + update_collection: The update collections used in the + spectral_normed_weight. + act: The activation function used in the discriminator. + scope: Optional scope for `variable_op_scope`. + Returns: + A `Tensor` representing the logits of the discriminator. 
+ """ + with tf.variable_scope('model', reuse=tf.AUTO_REUSE): + h0 = optimized_block(image, df_dim, 'd_optimized_block1', + update_collection, act=act) # 64 * 64 + h1 = block(h0, df_dim * 2, 'd_block2', + update_collection, act=act) # 32 * 32 + h1 = non_local.sn_non_local_block_sim(h1, update_collection, name='d_non_local') # 32 * 32 + h2 = block(h1, df_dim * 4, 'd_block3', + update_collection, act=act) # 16 * 16 + h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8 + h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4 + h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act) + h5_act = act(h5) + h6 = tf.reduce_sum(h5_act, [1, 2]) + output = ops.snlinear(h6, 1, update_collection=update_collection, + name='d_sn_linear') + h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16, + update_collection=update_collection, + name='d_embedding') + output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True) + print('Discriminator Test Structure') + return output + +def discriminator_test_64(image, labels, df_dim, number_classes, update_collection=None, + act=tf.nn.relu): + """Builds the discriminator graph. + + Args: + image: The current batch of images to classify as fake or real. + labels: The corresponding labels for the images. + df_dim: The df dimension. + number_classes: The number of classes in the labels. + update_collection: The update collections used in the + spectral_normed_weight. + act: The activation function used in the discriminator. + scope: Optional scope for `variable_op_scope`. + Returns: + A `Tensor` representing the logits of the discriminator. 
+ """ + with tf.variable_scope('model', reuse=tf.AUTO_REUSE): + h0 = optimized_block(image, df_dim, 'd_optimized_block1', + update_collection, act=act) # 64 * 64 + h0 = non_local.sn_non_local_block_sim(h0, update_collection, name='d_non_local') # 64 * 64 + h1 = block(h0, df_dim * 2, 'd_block2', + update_collection, act=act) # 32 * 32 + h2 = block(h1, df_dim * 4, 'd_block3', + update_collection, act=act) # 16 * 16 + h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8 + h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4 + h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act) + h5_act = act(h5) + h6 = tf.reduce_sum(h5_act, [1, 2]) + output = ops.snlinear(h6, 1, update_collection=update_collection, + name='d_sn_linear') + h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16, + update_collection=update_collection, + name='d_embedding') + output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True) return output \ No newline at end of file diff --git a/eval_imagenet.py b/eval_imagenet.py index 5723181..e708f9c 100644 --- a/eval_imagenet.py +++ b/eval_imagenet.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + """Generic train.""" from __future__ import absolute_import from __future__ import division @@ -33,7 +34,8 @@ flags.DEFINE_string( - 'data_dir', '/home/zhanghan/Data', + # 'data_dir', '/gpu/hz138/Data/imagenet', #'/home/hz138/Data/imagenet', + 'data_dir', '/bigdata1/hz138/Data/imagenet', 'Directory with Imagenet input data as sharded recordio files of pre-' 'processed images.') flags.DEFINE_integer('z_dim', 128, 'The dimension of z') @@ -48,18 +50,17 @@ 'image samples. [sample]') flags.DEFINE_string('eval_dir', 'checkpoint/eval', 'Directory name to save the ' 'eval summaries . 
[eval]') -flags.DEFINE_integer('batch_size', 32, 'Batch size of samples to feed into ' +flags.DEFINE_integer('batch_size', 64, 'Batch size of samples to feed into ' 'Inception models for evaluation. [16]') flags.DEFINE_integer('shuffle_buffer_size', 5000, 'Number of records to load ' 'before shuffling and yielding for consumption. [5000]') flags.DEFINE_integer('dcgan_generator_batch_size', 100, 'Size of batch to feed ' 'into generator -- we may stack multiple of these later.') -flags.DEFINE_integer('eval_sample_size', 1024, +flags.DEFINE_integer('eval_sample_size', 50000, 'Number of samples to sample from ' 'generator and real data. [1024]') flags.DEFINE_boolean('is_train', False, 'Use DCGAN only for evaluation.') -# TODO(olganw) Find the best way to clean up these flags for eval and train. -# These values need to be the same as they are in the training job. + flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]') flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]') flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]') @@ -72,15 +73,17 @@ 'and Frechet Inception Distance. 
[300]') flags.DEFINE_integer('num_classes', 1000, 'The number of classes in the dataset') -flags.DEFINE_string('generator_type', 'baseline', 'test or baseline') +flags.DEFINE_string('generator_type', 'test', 'test or baseline') FLAGS = flags.FLAGS def main(_): model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size) + FLAGS.eval_dir = FLAGS.checkpoint_dir + '/eval' checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, model_dir) log_dir = os.path.join(FLAGS.eval_dir, model_dir) + print('log_dir', log_dir) graph_def = None # pylint: disable=protected-access # Batch size to feed batches of images through Inception and the generator @@ -105,31 +108,29 @@ def main(_): label_offset=-1, shuffle_buffer_size=FLAGS.shuffle_buffer_size) - # Uniform distribution - # TODO(goodfellow) Use true distribution of ImageNet classses num_classes = FLAGS.num_classes gen_class_logits = tf.zeros((local_batch_size, num_classes)) gen_class_ints = tf.multinomial(gen_class_logits, 1) gen_sparse_class = tf.squeeze(gen_class_ints) - with tf.variable_scope('model'): # Generate the first batch of generated images and extract activations; # this bootstraps the while_loop with a pools and logits tensor. - test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim) - generator = generator_fn( - test_zs[0], - gen_sparse_class, - FLAGS.gf_dim, - FLAGS.num_classes) + test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim) + generator = generator_fn( + test_zs[0], + gen_sparse_class, + FLAGS.gf_dim, + FLAGS.num_classes, + is_training=False) - pools, logits = utils.run_custom_inception( - generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def) + pools, logits = utils.run_custom_inception( + generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def) # Set up while_loop to compute activations of generated images from generator. 
def while_cond(g_pools, g_logits, i): # pylint: disable=unused-argument @@ -144,24 +145,23 @@ def while_body(g_pools, g_logits, i): test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim) # Uniform distribution - # TODO(goodfellow) Use true distribution of ImageNet classses gen_class_logits = tf.zeros((local_batch_size, num_classes)) gen_class_ints = tf.multinomial(gen_class_logits, 1) gen_sparse_class = tf.squeeze(gen_class_ints) - with tf.variable_scope('model'): - generator = generator_fn( - test_zs[0], - gen_sparse_class, - FLAGS.gf_dim, - FLAGS.num_classes) + generator = generator_fn( + test_zs[0], + gen_sparse_class, + FLAGS.gf_dim, + FLAGS.num_classes, + is_training=False) - pools, logits = utils.run_custom_inception( - generator, - output_tensor=['pool_3:0', 'logits:0'], - graph_def=graph_def) - g_pools = tf.concat([g_pools, pools], 0) - g_logits = tf.concat([g_logits, logits], 0) + pools, logits = utils.run_custom_inception( + generator, + output_tensor=['pool_3:0', 'logits:0'], + graph_def=graph_def) + g_pools = tf.concat([g_pools, pools], 0) + g_logits = tf.concat([g_logits, logits], 0) return (g_pools, g_logits, tf.add(i, 1)) @@ -183,25 +183,22 @@ def while_body(g_pools, g_logits, i): new_generator_pools_list.set_shape([FLAGS.eval_sample_size, 2048]) new_generator_logits_list.set_shape([FLAGS.eval_sample_size, 1008]) - # TODO(sbhupatiraju) Why is FID negative? 
- # Get a small batch of samples from generator to dispaly in TensorBoard vis_batch_size = 16 eval_vis_zs = utils.make_z_normal( 1, vis_batch_size, FLAGS.z_dim) - # Uniform distribution - # TODO(goodfellow) Use true distribution of ImageNet classses + gen_class_logits_vis = tf.zeros((vis_batch_size, num_classes)) gen_class_ints_vis = tf.multinomial(gen_class_logits_vis, 1) gen_sparse_class_vis = tf.squeeze(gen_class_ints_vis) - with tf.variable_scope('model'): - eval_vis_images = generator_fn( - eval_vis_zs[0], - gen_sparse_class_vis, - FLAGS.gf_dim, - FLAGS.num_classes - ) + eval_vis_images = generator_fn( + eval_vis_zs[0], + gen_sparse_class_vis, + FLAGS.gf_dim, + FLAGS.num_classes, + is_training=False + ) eval_vis_images = tf.cast((eval_vis_images + 1.) * 127.5, tf.uint8) with tf.variable_scope('eval_vis'): diff --git a/generator.py b/generator.py index c49e7ce..c357540 100644 --- a/generator.py +++ b/generator.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + """The generator of SNGAN.""" + import tensorflow as tf import ops +import non_local def upscale(x, n): @@ -42,7 +45,7 @@ def usample(x): x = tf.image.resize_nearest_neighbor(x, [nh * 2, nw * 2]) return x -def block(x, labels, out_channels, num_classes, name): +def block_no_sn(x, labels, out_channels, num_classes, is_training, name): """Builds the residual blocks used in the generator. 
Compared with block, optimized_block always downsamples the spatial resolution @@ -61,10 +64,10 @@ def block(x, labels, out_channels, num_classes, name): bn0 = ops.ConditionalBatchNorm(num_classes, name='cbn_0') bn1 = ops.ConditionalBatchNorm(num_classes, name='cbn_1') x_0 = x - x = tf.nn.relu(bn0(x, labels)) + x = tf.nn.relu(bn0(x, labels, is_training)) x = usample(x) x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv1') - x = tf.nn.relu(bn1(x, labels)) + x = tf.nn.relu(bn1(x, labels, is_training)) x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv2') x_0 = usample(x_0) @@ -72,11 +75,28 @@ def block(x, labels, out_channels, num_classes, name): return x_0 + x +def block(x, labels, out_channels, num_classes, is_training, name): + with tf.variable_scope(name): + bn0 = ops.ConditionalBatchNorm(num_classes, name='cbn_0') + bn1 = ops.ConditionalBatchNorm(num_classes, name='cbn_1') + x_0 = x + x = tf.nn.relu(bn0(x, labels, is_training)) + x = usample(x) + x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv1') + x = tf.nn.relu(bn1(x, labels, is_training)) + x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv2') -def generator(zs, + x_0 = usample(x_0) + x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, name='snconv3') + + return x_0 + x + + +def generator_old(zs, target_class, gf_dim, num_classes, + is_training=True, scope='Generator'): """Builds the generator graph propagating from z to x. 
@@ -93,22 +113,151 @@ def generator(zs, with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): # project `z` and reshape - act0 = ops.linear(zs, gf_dim * 16 * 4 * 4, 'g_h0') + act0 = ops.linear(zs, gf_dim * 16 * 4 * 4, scope='g_h0') + act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16]) + + act1 = block_no_sn(act0, target_class, gf_dim * 16, + num_classes, is_training, 'g_block1') # 8 * 8 + act2 = block_no_sn(act1, target_class, gf_dim * 8, + num_classes, is_training, 'g_block2') # 16 * 16 + act3 = block_no_sn(act2, target_class, gf_dim * 4, + num_classes, is_training, 'g_block3') # 32 * 32 + act4 = block_no_sn(act3, target_class, gf_dim * 2, + num_classes, is_training, 'g_block4') # 64 * 64 + act5 = block_no_sn(act4, target_class, gf_dim, + num_classes, is_training, 'g_block5') # 128 * 128 + bn = ops.batch_norm(name='g_bn') + + act5 = tf.nn.relu(bn(act5, is_training)) + act6 = ops.conv2d(act5, 3, 3, 3, 1, 1, name='g_conv_last') + out = tf.nn.tanh(act6) + print('GAN baseline with moving average') + return out + +def generator(zs, + target_class, + gf_dim, + num_classes, + is_training=True): + """Builds the generator graph propagating from z to x. + + Args: + zs: The list of noise tensors. + target_class: The conditional labels in the generation. + gf_dim: The gf dimension. + num_classes: Number of classes in the labels. + scope: Optional scope for `variable_op_scope`. + + Returns: + outputs: The output layer of the generator. 
+ """ + + with tf.variable_scope('model', reuse=tf.AUTO_REUSE): + # project `z` and reshape + act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0') act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16]) act1 = block(act0, target_class, gf_dim * 16, - num_classes, 'g_block1') # 8 * 8 + num_classes, is_training, 'g_block1') # 8 * 8 act2 = block(act1, target_class, gf_dim * 8, - num_classes, 'g_block2') # 16 * 16 + num_classes, is_training, 'g_block2') # 16 * 16 act3 = block(act2, target_class, gf_dim * 4, - num_classes, 'g_block3') # 32 * 32 + num_classes, is_training, 'g_block3') # 32 * 32 act4 = block(act3, target_class, gf_dim * 2, - num_classes, 'g_block4') # 64 * 64 + num_classes, is_training, 'g_block4') # 64 * 64 act5 = block(act4, target_class, gf_dim, - num_classes, 'g_block5') # 128 * 128 - bn = ops.BatchNorm(name='g_bn') + num_classes, is_training, 'g_block5') # 128 * 128 + bn = ops.batch_norm(name='g_bn') - act5 = tf.nn.relu(bn(act5)) - act6 = ops.conv2d(act5, 3, 3, 3, 1, 1, name='g_conv_last') + act5 = tf.nn.relu(bn(act5, is_training)) + act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last') + out = tf.nn.tanh(act6) + print('Generator Structure') + return out + + +def generator_test(zs, + target_class, + gf_dim, + num_classes, + is_training=True): + """Builds the generator graph propagating from z to x. + + Args: + zs: The list of noise tensors. + target_class: The conditional labels in the generation. + gf_dim: The gf dimension. + num_classes: Number of classes in the labels. + scope: Optional scope for `variable_op_scope`. + + Returns: + outputs: The output layer of the generator. 
+ """ + + with tf.variable_scope('model', reuse=tf.AUTO_REUSE): + # project `z` and reshape + act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0') + act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16]) + + act1 = block(act0, target_class, gf_dim * 16, + num_classes, is_training, 'g_block1') # 8 * 8 + act2 = block(act1, target_class, gf_dim * 8, + num_classes, is_training, 'g_block2') # 16 * 16 + act3 = block(act2, target_class, gf_dim * 4, + num_classes, is_training, 'g_block3') # 32 * 32 + act3 = non_local.sn_non_local_block_sim(act3, None, name='g_non_local') + act4 = block(act3, target_class, gf_dim * 2, + num_classes, is_training, 'g_block4') # 64 * 64 + act5 = block(act4, target_class, gf_dim, + num_classes, is_training, 'g_block5') # 128 * 128 + bn = ops.batch_norm(name='g_bn') + + act5 = tf.nn.relu(bn(act5, is_training)) + act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last') out = tf.nn.tanh(act6) - return out \ No newline at end of file + print('Generator TEST structure') + return out + +def generator_test_64(zs, + target_class, + gf_dim, + num_classes, + is_training=True): + """Builds the generator graph propagating from z to x. + + Args: + zs: The list of noise tensors. + target_class: The conditional labels in the generation. + gf_dim: The gf dimension. + num_classes: Number of classes in the labels. + scope: Optional scope for `variable_op_scope`. + + Returns: + outputs: The output layer of the generator. 
+ """ + + with tf.variable_scope('model', reuse=tf.AUTO_REUSE): + # project `z` and reshape + act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0') + act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16]) + + act1 = block(act0, target_class, gf_dim * 16, + num_classes, is_training, 'g_block1') # 8 * 8 + act2 = block(act1, target_class, gf_dim * 8, + num_classes, is_training, 'g_block2') # 16 * 16 + act3 = block(act2, target_class, gf_dim * 4, + num_classes, is_training, 'g_block3') # 32 * 32 + + act4 = block(act3, target_class, gf_dim * 2, + num_classes, is_training, 'g_block4') # 64 * 64 + act4 = non_local.sn_non_local_block_sim(act4, None, name='g_non_local') + act5 = block(act4, target_class, gf_dim, + num_classes, is_training, 'g_block5') # 128 * 128 + bn = ops.batch_norm(name='g_bn') + + act5 = tf.nn.relu(bn(act5, is_training)) + act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last') + out = tf.nn.tanh(act6) + print('GAN test with moving average') + return out + diff --git a/imgs/img1.png b/imgs/img1.png new file mode 100644 index 0000000..245be83 Binary files /dev/null and b/imgs/img1.png differ diff --git a/model.py b/model.py index 9219f48..70e2b15 100644 --- a/model.py +++ b/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + """The DCGAN Model.""" from __future__ import absolute_import from __future__ import division @@ -30,12 +31,13 @@ tfgan = tf.contrib.gan flags.DEFINE_string( - 'data_dir', '/home/zhanghan/Data', + # 'data_dir', '/gpu/hz138/Data/imagenet', #'/home/hz138/Data/imagenet', + 'data_dir', '/bigdata1/hz138/Data/imagenet', 'Directory with Imagenet input data as sharded recordio files of pre-' 'processed images.') flags.DEFINE_float('discriminator_learning_rate', 0.0004, 'Learning rate of for adam. 
[0.0004]') -flags.DEFINE_float('generator_learning_rate', 0.0004, +flags.DEFINE_float('generator_learning_rate', 0.0001, 'Learning rate of for adam. [0.0004]') flags.DEFINE_float('beta1', 0.0, 'Momentum term of adam. [0.5]') flags.DEFINE_integer('image_size', 128, 'The size of image to use ' @@ -50,8 +52,8 @@ flags.DEFINE_integer('number_classes', 1000, 'The number of classes in the dataset') flags.DEFINE_string('loss_type', 'hinge_loss', 'the loss type can be' ' hinge_loss or kl_loss') -flags.DEFINE_string('generator_type', 'baseline', 'test or baseline') -flags.DEFINE_string('discriminator_type', 'baseline', 'test or baseline') +flags.DEFINE_string('generator_type', 'test', 'test or baseline') +flags.DEFINE_string('discriminator_type', 'test', 'test or baseline') @@ -123,11 +125,12 @@ def build_model(self): # (non-local) replicas, the ReplicaDeviceSetter distributes the variables # across the different devices. current_step = tf.cast(self.global_step, tf.float32) - g_ratio = (1.0 + 2e-5 * tf.maximum((current_step - 100000.0), 0.0)) - g_ratio = tf.minimum(g_ratio, 4.0) + # g_ratio = (1.0 + 2e-5 * tf.maximum((current_step - 100000.0), 0.0)) + # g_ratio = tf.minimum(g_ratio, 4.0) self.d_learning_rate = FLAGS.discriminator_learning_rate + self.g_learning_rate = FLAGS.generator_learning_rate # self.g_learning_rate = FLAGS.generator_learning_rate / (1.0 + 2e-5 * tf.cast(self.global_step, tf.float32)) - self.g_learning_rate = FLAGS.generator_learning_rate / g_ratio + # self.g_learning_rate = FLAGS.generator_learning_rate / g_ratio with tf.device(tf.train.replica_device_setter(config.ps_tasks)): self.d_opt = tf.train.AdamOptimizer( self.d_learning_rate, beta1=FLAGS.beta1) @@ -139,47 +142,49 @@ def build_model(self): self.g_opt = tf.train.SyncReplicasOptimizer( opt=self.g_opt, replicas_to_aggregate=config.replicas_to_aggregate) - with tf.variable_scope('model') as model_scope: - if config.num_towers > 1: - all_d_grads = [] - all_g_grads = [] - for idx, device in 
enumerate(self.devices): - with tf.device('/%s' % device): - with tf.name_scope('device_%s' % idx): - with ops.variables_on_gpu0(): - self.build_model_single_gpu( - gpu_idx=idx, - batch_size=config.batch_size, - num_towers=config.num_towers) - d_grads = self.d_opt.compute_gradients(self.d_losses[-1], - var_list=self.d_vars) - g_grads = self.g_opt.compute_gradients(self.g_losses[-1], - var_list=self.g_vars) - all_d_grads.append(d_grads) - all_g_grads.append(g_grads) - model_scope.reuse_variables() - d_grads = ops.avg_grads(all_d_grads) - g_grads = ops.avg_grads(all_g_grads) - else: - with tf.device(tf.train.replica_device_setter(config.ps_tasks)): - # TODO(olganw): reusing virtual batchnorm doesn't work in the multi- - # replica case. - self.build_model_single_gpu(batch_size=config.batch_size, - num_towers=config.num_towers) - d_grads = self.d_opt.compute_gradients(self.d_losses[-1], - var_list=self.d_vars) - g_grads = self.g_opt.compute_gradients(self.g_losses[-1], - var_list=self.g_vars) + if config.num_towers > 1: + all_d_grads = [] + all_g_grads = [] + for idx, device in enumerate(self.devices): + with tf.device('/%s' % device): + with tf.name_scope('device_%s' % idx): + with ops.variables_on_gpu0(): + self.build_model_single_gpu( + gpu_idx=idx, + batch_size=config.batch_size, + num_towers=config.num_towers) + d_grads = self.d_opt.compute_gradients(self.d_losses[-1], + var_list=self.d_vars) + g_grads = self.g_opt.compute_gradients(self.g_losses[-1], + var_list=self.g_vars) + all_d_grads.append(d_grads) + all_g_grads.append(g_grads) + d_grads = ops.avg_grads(all_d_grads) + g_grads = ops.avg_grads(all_g_grads) + else: + with tf.device(tf.train.replica_device_setter(config.ps_tasks)): + # TODO(olganw): reusing virtual batchnorm doesn't work in the multi- + # replica case. 
+ self.build_model_single_gpu(batch_size=config.batch_size, + num_towers=config.num_towers) + d_grads = self.d_opt.compute_gradients(self.d_losses[-1], + var_list=self.d_vars) + g_grads = self.g_opt.compute_gradients(self.g_losses[-1], + var_list=self.g_vars) with tf.device(tf.train.replica_device_setter(config.ps_tasks)): + update_moving_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + print('update_moving_ops', update_moving_ops) if config.sync_replicas: - d_step = tf.get_variable('d_step', initializer=0, trainable=False) - self.d_optim = self.d_opt.apply_gradients(d_grads, global_step=d_step) - g_step = tf.get_variable('g_step', initializer=0, trainable=False) - self.g_optim = self.g_opt.apply_gradients(g_grads, global_step=g_step) + with tf.control_dependencies(update_moving_ops): + d_step = tf.get_variable('d_step', initializer=0, trainable=False) + self.d_optim = self.d_opt.apply_gradients(d_grads, global_step=d_step) + g_step = tf.get_variable('g_step', initializer=0, trainable=False) + self.g_optim = self.g_opt.apply_gradients(g_grads, global_step=g_step) else: # Don't create any additional counters, and don't update the global step - self.d_optim = self.d_opt.apply_gradients(d_grads) - self.g_optim = self.g_opt.apply_gradients(g_grads) + with tf.control_dependencies(update_moving_ops): + self.d_optim = self.d_opt.apply_gradients(d_grads) + self.g_optim = self.g_opt.apply_gradients(g_grads) def build_model_single_gpu(self, gpu_idx=0, batch_size=1, num_towers=1): """Builds a model for a single GPU. 
@@ -339,8 +344,8 @@ def get_vars(self): """Get variables.""" t_vars = tf.trainable_variables() # TODO(olganw): scoping or collections for this instead of name hack - self.d_vars = [var for var in t_vars if var.name.startswith('model/Discriminator')] - self.g_vars = [var for var in t_vars if var.name.startswith('model/Generator')] + self.d_vars = [var for var in t_vars if var.name.startswith('model/d_')] + self.g_vars = [var for var in t_vars if var.name.startswith('model/g_')] self.sigma_ratio_vars = [var for var in t_vars if 'sigma_ratio' in var.name] for x in self.d_vars: assert x not in self.g_vars diff --git a/tmp/non_local.py b/non_local.py similarity index 62% rename from tmp/non_local.py rename to non_local.py index d0f7a66..e54315c 100644 --- a/tmp/non_local.py +++ b/non_local.py @@ -13,12 +13,25 @@ # limitations under the License. # ============================================================================== -import sn_ops import tensorflow as tf +import numpy as np +import ops +def conv1x1(input_, output_dim, + init=tf.contrib.layers.xavier_initializer(), name='conv1x1'): + k_h = 1 + k_w = 1 + d_h = 1 + d_w = 1 + with tf.variable_scope(name): + w = tf.get_variable( + 'w', [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=init) + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') + return conv def sn_conv1x1(input_, output_dim, update_collection, - init=tf.contrib.layers.xavier_initializer(), name='sn_conv1x1'): + init=tf.contrib.layers.xavier_initializer(), name='sn_conv1x1'): with tf.variable_scope(name): k_h = 1 k_w = 1 @@ -27,44 +40,42 @@ def sn_conv1x1(input_, output_dim, update_collection, w = tf.get_variable( 'w', [k_h, k_w, input_.get_shape()[-1], output_dim], initializer=init) - w_bar = sn_ops.spectral_normed_weight( - w, num_iters=1, update_collection=update_collection) + w_bar = ops.spectral_normed_weight(w, num_iters=1, update_collection=update_collection) conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, 
d_w, 1], padding='SAME') return conv - -def sn_non_local_block_sim(x, update_collection, name, - init=tf.contrib.layers.xavier_initializer()): +def sn_non_local_block_sim(x, update_collection, name, init=tf.contrib.layers.xavier_initializer()): with tf.variable_scope(name): batch_size, h, w, num_channels = x.get_shape().as_list() location_num = h * w downsampled_num = location_num // 4 - theta = sn_conv1x1( - x, num_channels // 8, update_collection, init, 'sn_conv_theta') + # theta path + theta = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_theta') theta = tf.reshape( theta, [batch_size, location_num, num_channels // 8]) - phi = sn_conv1x1( - x, num_channels // 8, update_collection, init, 'sn_conv_phi') + # phi path + phi = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_phi') phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[2, 2], strides=2) phi = tf.reshape( phi, [batch_size, downsampled_num, num_channels // 8]) + attn = tf.matmul(theta, phi, transpose_b=True) attn = tf.nn.softmax(attn) + print(tf.reduce_sum(attn, axis=-1)) + # g path g = sn_conv1x1(x, num_channels // 2, update_collection, init, 'sn_conv_g') g = tf.layers.max_pooling2d(inputs=g, pool_size=[2, 2], strides=2) g = tf.reshape( - g, [batch_size, downsampled_num, num_channels // 2]) + g, [batch_size, downsampled_num, num_channels // 2]) attn_g = tf.matmul(attn, g) attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2]) sigma = tf.get_variable( 'sigma_ratio', [], initializer=tf.constant_initializer(0.0)) - attn_g = sn_conv1x1( - attn_g, num_channels, update_collection, init, 'sn_conv_attn') - return x + sigma * attn_g - + attn_g = sn_conv1x1(attn_g, num_channels, update_collection, init, 'sn_conv_attn') + return x + sigma * attn_g \ No newline at end of file diff --git a/ops.py b/ops.py index b27fa83..30ece1e 100644 --- a/ops.py +++ b/ops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under 
the License. # ============================================================================== + """The building block ops for Spectral Normalization GAN.""" from __future__ import absolute_import from __future__ import division @@ -261,7 +262,7 @@ def sn_embedding(x, number_classes, embedding_size, sn_iters=1, return tf.nn.embedding_lookup(embedding_map_bar, x) -class ConditionalBatchNorm(object): +class ConditionalBatchNorm_old(object): """Conditional BatchNorm. For each class, it has a specific gamma and beta as normalization variable. @@ -300,6 +301,79 @@ def __call__(self, inputs, labels): outputs.set_shape(inputs_shape) return outputs +class ConditionalBatchNorm(object): + """Conditional BatchNorm. + + For each class, it has a specific gamma and beta as normalization variable. + """ + + def __init__(self, num_categories, name='conditional_batch_norm', decay_rate=0.999, center=True, + scale=True): + with tf.variable_scope(name): + self.name = name + self.num_categories = num_categories + self.center = center + self.scale = scale + self.decay_rate = decay_rate + + def __call__(self, inputs, labels, is_training=True): + inputs = tf.convert_to_tensor(inputs) + inputs_shape = inputs.get_shape() + params_shape = inputs_shape[-1:] + axis = [0, 1, 2] + shape = tf.TensorShape([self.num_categories]).concatenate(params_shape) + moving_shape = tf.TensorShape([1, 1, 1]).concatenate(params_shape) + + with tf.variable_scope(self.name): + self.gamma = tf.get_variable( + 'gamma', shape, + initializer=tf.ones_initializer()) + self.beta = tf.get_variable( + 'beta', shape, + initializer=tf.zeros_initializer()) + self.moving_mean = tf.get_variable('mean', moving_shape, + initializer=tf.zeros_initializer(), + trainable=False) + self.moving_var = tf.get_variable('var', moving_shape, + initializer=tf.ones_initializer(), + trainable=False) + + beta = tf.gather(self.beta, labels) + beta = tf.expand_dims(tf.expand_dims(beta, 1), 1) + gamma = tf.gather(self.gamma, labels) + gamma = 
tf.expand_dims(tf.expand_dims(gamma, 1), 1) + decay = self.decay_rate + variance_epsilon = 1E-5 + if is_training: + mean, variance = tf.nn.moments(inputs, axis, keep_dims=True) + update_mean = tf.assign(self.moving_mean, self.moving_mean * decay + mean * (1 - decay)) + update_var = tf.assign(self.moving_var, self.moving_var * decay + variance * (1 - decay)) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean) + tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var) + #with tf.control_dependencies([update_mean, update_var]): + outputs = tf.nn.batch_normalization( + inputs, mean, variance, beta, gamma, variance_epsilon) + else: + outputs = tf.nn.batch_normalization( + inputs, self.moving_mean, self.moving_var, beta, gamma, variance_epsilon) + outputs.set_shape(inputs_shape) + return outputs + +class batch_norm(object): + def __init__(self, epsilon=1e-5, momentum = 0.9999, name="batch_norm"): + with tf.variable_scope(name): + self.epsilon = epsilon + self.momentum = momentum + self.name = name + + def __call__(self, x, train=True): + return tf.contrib.layers.batch_norm(x, + decay=self.momentum, + # updates_collections=None, + epsilon=self.epsilon, + scale=True, + is_training=train, + scope=self.name) class BatchNorm(object): """The Batch Normalization layer.""" diff --git a/tmp/discriminator.py b/tmp/discriminator.py deleted file mode 100644 index a27919a..0000000 --- a/tmp/discriminator.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import tensorflow as tf -import non_local -import sn_ops as ops - -slim = tf.contrib.slim - - -def dsample(x): - xd = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID') - return xd - - -def Block(x, out_channels, name, update_collection=tf.GraphKeys.UPDATE_OPS, - downsample=True, act=tf.nn.relu): - with tf.variable_scope(name): - input_channels = x.shape.as_list()[-1] - x_0 = x - x = act(x) - x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv1') - x = act(x) - x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv2') - if downsample: - x = dsample(x) - if downsample or input_channels != out_channels: - x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, - update_collection=update_collection, name='sn_conv3') - if downsample: - x_0 = dsample(x_0) - return x_0 + x - - -def BaselineBlock(x, bottleneck_ratio, name, update_collection=None, - act=tf.nn.relu): - with tf.variable_scope(name): - input_channels = x.shape.as_list()[-1] - out_channels = input_channels // bottleneck_ratio - x_0 = x - x = act(x) - x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv1') - x = act(x) - x = ops.snconv2d(x, input_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv2') - sigma = tf.get_variable( - 'sigma_ratio', [], initializer=tf.constant_initializer(0.0)) - return x_0 + sigma * x - - -def OptimizedBlock(x, out_channels, name, - update_collection=tf.GraphKeys.UPDATE_OPS, act=tf.nn.relu): - with tf.variable_scope(name): - x_0 = x - x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv1') - x = act(x) - x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, - update_collection=update_collection, 
name='sn_conv2') - x = dsample(x) - x_0 = dsample(x_0) - x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, - update_collection=update_collection, name='sn_conv3') - return x + x_0 - - -def discriminator(image, labels, df_dim, number_classes, reuse_vars=False, - update_collection=tf.GraphKeys.UPDATE_OPS, act=tf.nn.relu): - if reuse_vars: - tf.get_variable_scope().reuse_variables() - h0 = OptimizedBlock(image, df_dim, 'd_optimized_block1', - update_collection, act=act) - h0 = BaselineBlock(h0, 8, 'd_baseline_block') - h1 = Block(h0, df_dim * 2, 'd_block2', update_collection, act=act) - h2 = Block(h1, df_dim * 4, 'd_block3', update_collection, act=act) - h3 = Block(h2, df_dim * 8, 'd_block4', update_collection, act=act) - h4 = Block(h3, df_dim * 16, 'd_block5', update_collection, act=act) - h5 = Block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act) - h5_act = act(h5) - h6 = tf.reduce_sum(h5_act, [1, 2]) - output = ops.snlinear(h6, 1, update_collection=update_collection, - name='d_sn_linear') - h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16, - update_collection=update_collection, - name='d_embedding') - output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True) - return output - - -def discriminator_test(image, labels, df_dim, number_classes, reuse_vars=False, - update_collection=tf.GraphKeys.UPDATE_OPS, - act=tf.nn.relu): - if reuse_vars: - tf.get_variable_scope().reuse_variables() - h0 = OptimizedBlock(image, df_dim, 'd_optimized_block1', update_collection, - act=act) - h1 = Block(h0, df_dim * 2, 'd_block2', update_collection, act=act) - h1 = non_local.sn_non_local_block_sim(h1, update_collection, - name='d_non_local') - h2 = Block(h1, df_dim * 4, 'd_block3', update_collection, act=act) - h3 = Block(h2, df_dim * 8, 'd_block4', update_collection, act=act) - h4 = Block(h3, df_dim * 16, 'd_block5', update_collection, act=act) - h5 = Block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act) - h5_act = act(h5) - h6 = 
tf.reduce_sum(h5_act, [1, 2]) - output = ops.snlinear(h6, 1, update_collection=update_collection, - name='d_sn_linear') - h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16, - update_collection=update_collection, - name='d_embedding') - output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True) - return output diff --git a/tmp/generator.py b/tmp/generator.py deleted file mode 100644 index 2e72ab3..0000000 --- a/tmp/generator.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import tensorflow as tf -import non_local -import sn_ops - - -def usample(x): - _, nh, nw, _ = x.get_shape().as_list() - x = tf.image.resize_nearest_neighbor(x, [nh * 2, nw * 2]) - return x - - -def Block(x, labels, out_channels, num_classes, name): - with tf.variable_scope(name): - bn0 = sn_ops.ConditionalBatchNorm(num_classes, name='cbn_0') - bn1 = sn_ops.ConditionalBatchNorm(num_classes, name='cbn_1') - x_0 = x - x = tf.nn.relu(bn0(x, labels)) - x = usample(x) - x = sn_ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv1') - x = tf.nn.relu(bn1(x, labels)) - x = sn_ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv2') - - x_0 = usample(x_0) - x_0 = sn_ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, name='snconv3') - - return x_0 + x - - -def BaselineBlock(x, bottleneck_ratio, name, update_collection=None, - act=tf.nn.relu): - with tf.variable_scope(name): - input_channels = x.shape.as_list()[-1] - out_channels = input_channels // bottleneck_ratio - x_0 = x - x = act(x) - x = sn_ops.snconv2d(x, out_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv1') - x = act(x) - x = sn_ops.snconv2d(x, input_channels, 3, 3, 1, 1, - update_collection=update_collection, name='sn_conv2') - sigma = tf.get_variable( - 'sigma_ratio', [], initializer=tf.constant_initializer(0.0)) - return x_0 + sigma * x - - -def generator(zs, - target_class, - gf_dim, - num_classes, - reuse_vars=False): - - if reuse_vars: - tf.get_variable_scope().reuse_variables() - - act0 = sn_ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0') - act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16]) - - act1 = Block(act0, target_class, gf_dim * 16, num_classes, 'g_block1') - act2 = Block(act1, target_class, gf_dim * 8, num_classes, 'g_block2') - act3 = Block(act2, target_class, gf_dim * 4, num_classes, 'g_block3') - act4 = Block(act3, target_class, gf_dim * 2, num_classes, 'g_block4') - act4 = 
BaselineBlock(act4, 8, 'g_baseline_block') - act5 = Block(act4, target_class, gf_dim, num_classes, 'g_block5') - bn = sn_ops.BatchNorm(name='g_bn') - - act5 = tf.nn.relu(bn(act5)) - act6 = sn_ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last') - out = tf.nn.tanh(act6) - return out - - -def generator_test(zs, - target_class, - gf_dim, - num_classes, - reuse_vars=False): - - if reuse_vars: - tf.get_variable_scope().reuse_variables() - - act0 = sn_ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0') - act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16]) - act1 = Block(act0, target_class, gf_dim * 16, num_classes, 'g_block1') - act2 = Block(act1, target_class, gf_dim * 8, num_classes, 'g_block2') - act3 = Block(act2, target_class, gf_dim * 4, num_classes, 'g_block3') - act3 = non_local.sn_non_local_block_sim(act3, None, name='g_non_local') - act4 = Block(act3, target_class, gf_dim * 2, num_classes, 'g_block4') - act5 = Block(act4, target_class, gf_dim, num_classes, 'g_block5') - bn = sn_ops.BatchNorm(name='g_bn') - - act5 = tf.nn.relu(bn(act5)) - act6 = sn_ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last') - out = tf.nn.tanh(act6) - return out diff --git a/tmp/model.py b/tmp/model.py deleted file mode 100644 index 87996b2..0000000 --- a/tmp/model.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import flags -import tensorflow as tf -import discriminator as disc -import generator as generator_module -import ops -import utils_in as utils - - -tfgan = tf.contrib.gan - -flags.DEFINE_string( - 'data_dir', '/tmp/data', - 'Directory with Imagenet input data') -flags.DEFINE_float('discriminator_learning_rate', 0.0004, - 'Learning rate of for adam. [0.0004]') -flags.DEFINE_float('generator_learning_rate', 0.0004, - 'Learning rate of for adam. [0.0004]') -flags.DEFINE_float('beta1', 0.0, 'Momentum term of adam. [0.5]') -flags.DEFINE_integer('image_size', 128, 'The size of image to use ' - '(will be center cropped) [128]') -flags.DEFINE_integer('image_width', 128, - 'The width of the images presented to the model') -flags.DEFINE_integer('data_parallelism', 64, 'The number of objects to read at' - ' one time when loading input data. [64]') -flags.DEFINE_integer('z_dim', 128, 'Dimensionality of latent code z. [8192]') -flags.DEFINE_integer('gf_dim', 64, 'Dimensionality of gf. [64]') -flags.DEFINE_integer('df_dim', 64, 'Dimensionality of df. 
[64]') -flags.DEFINE_integer('number_classes', 1000, - 'The number of classes in the dataset') -flags.DEFINE_string('generator_type', 'baseline', 'test or baseline') -flags.DEFINE_string('discriminator_type', 'baseline', 'test or baseline') - - -FLAGS = flags.FLAGS - - -def _get_d_real_loss(discriminator_on_data_logits): - loss = tf.nn.relu(1.0 - discriminator_on_data_logits) - return tf.reduce_mean(loss) - - -def _get_d_fake_loss(discriminator_on_generator_logits): - return tf.reduce_mean(tf.nn.relu(1 + discriminator_on_generator_logits)) - - -def _get_g_loss(discriminator_on_generator_logits): - return -tf.reduce_mean(discriminator_on_generator_logits) - - -class SNGAN(object): - - def __init__(self, zs, config=None, global_step=None, devices=None): - self.config = config - self.image_size = FLAGS.image_size - self.image_shape = [FLAGS.image_size, FLAGS.image_size, 3] - self.z_dim = FLAGS.z_dim - self.gf_dim = FLAGS.gf_dim - self.df_dim = FLAGS.df_dim - self.num_classes = FLAGS.number_classes - - self.data_parallelism = FLAGS.data_parallelism - self.zs = zs - - self.c_dim = 3 - self.dataset_name = 'imagenet' - self.devices = devices - self.global_step = global_step - - self.build_model() - - def build_model(self): - config = self.config - self.d_opt = tf.train.AdamOptimizer( - FLAGS.discriminator_learning_rate, beta1=FLAGS.beta1) - self.g_opt = tf.train.AdamOptimizer( - FLAGS.generator_learning_rate, beta1=FLAGS.beta1) - - with tf.variable_scope('model') as model_scope: - if config.num_towers > 1: - all_d_grads = [] - all_g_grads = [] - for idx, device in enumerate(self.devices): - with tf.device('/%s' % device): - with tf.name_scope('device_%s' % idx): - with ops.variables_on_gpu0(): - self.build_model_single_gpu( - gpu_idx=idx, - batch_size=config.batch_size, - num_towers=config.num_towers) - d_grads = self.d_opt.compute_gradients(self.d_losses[-1], - var_list=self.d_vars) - g_grads = self.g_opt.compute_gradients(self.g_losses[-1], - var_list=self.g_vars) - 
all_d_grads.append(d_grads) - all_g_grads.append(g_grads) - model_scope.reuse_variables() - d_grads = ops.avg_grads(all_d_grads) - g_grads = ops.avg_grads(all_g_grads) - else: - self.build_model_single_gpu(batch_size=config.batch_size, - num_towers=config.num_towers) - d_grads = self.d_opt.compute_gradients(self.d_losses[-1], - var_list=self.d_vars) - g_grads = self.g_opt.compute_gradients(self.g_losses[-1], - var_list=self.g_vars) - - d_step = tf.get_variable('d_step', initializer=0, trainable=False) - self.d_optim = self.d_opt.apply_gradients(d_grads, global_step=d_step) - g_step = tf.get_variable('g_step', initializer=0, trainable=False) - self.g_optim = self.g_opt.apply_gradients(g_grads, global_step=g_step) - - def build_model_single_gpu(self, gpu_idx=0, batch_size=1, num_towers=1): - config = self.config - show_num = min(config.batch_size, 64) - - reuse_vars = gpu_idx > 0 - if gpu_idx == 0: - self.increment_global_step = self.global_step.assign_add(1) - self.batches = utils.get_imagenet_batches( - FLAGS.data_dir, batch_size, num_towers, label_offset=0, - cycle_length=config.data_parallelism, - shuffle_buffer_size=config.shuffle_buffer_size) - sample_images, _ = self.batches[0] - vis_images = tf.cast((sample_images + 1.) 
* 127.5, tf.uint8) - tf.summary.image('input_image_grid', - tfgan.eval.image_grid( - vis_images[:show_num], - grid_shape=utils.squarest_grid_size( - show_num), - image_shape=(128, 128))) - - images, sparse_labels = self.batches[gpu_idx] - sparse_labels = tf.squeeze(sparse_labels) - - gen_class_logits = tf.zeros((batch_size, self.num_classes)) - gen_class_ints = tf.multinomial(gen_class_logits, 1) - gen_sparse_class = tf.squeeze(gen_class_ints) - gen_class_ints = tf.squeeze(gen_class_ints) - gen_class_vector = tf.one_hot(gen_class_ints, self.num_classes) - - if FLAGS.generator_type == 'baseline': - generator_fn = generator_module.generator - elif FLAGS.generator_type == 'test': - generator_fn = generator_module.generator_test - - generator = generator_fn( - self.zs[gpu_idx], - gen_sparse_class, - self.gf_dim, - self.num_classes, - reuse_vars=reuse_vars, - ) - - if gpu_idx == 0: - generator_means = tf.reduce_mean(generator, 0, keep_dims=True) - generator_vars = tf.reduce_mean( - tf.squared_difference(generator, generator_means), 0, keep_dims=True) - generator = tf.Print( - generator, - [tf.reduce_mean(generator_means), tf.reduce_mean(generator_vars)], - 'generator mean and average var', first_n=1) - image_means = tf.reduce_mean(images, 0, keep_dims=True) - image_vars = tf.reduce_mean( - tf.squared_difference(images, image_means), 0, keep_dims=True) - images = tf.Print( - images, [tf.reduce_mean(image_means), tf.reduce_mean(image_vars)], - 'image mean and average var', first_n=1) - sparse_labels = tf.Print( - sparse_labels, [sparse_labels, sparse_labels.shape], - 'sparse_labels', first_n=2) - gen_sparse_class = tf.Print( - gen_sparse_class, [gen_sparse_class, gen_sparse_class.shape], - 'gen_sparse_labels', first_n=2) - - self.generators = [] - - self.generators.append(generator) - - if FLAGS.discriminator_type == 'baseline': - discriminator_fn = disc.discriminator - elif FLAGS.discriminator_type == 'test': - discriminator_fn = disc.discriminator_test - else: - raise 
NotImplementedError - discriminator_on_data_logits = discriminator_fn( - images, sparse_labels, self.df_dim, self.num_classes, - reuse_vars=reuse_vars, update_collection=None) - discriminator_on_generator_logits = discriminator_fn( - generator, gen_sparse_class, self.df_dim, self.num_classes, - reuse_vars=True, update_collection='NO_OPS') - - vis_generator = tf.cast((generator + 1.) * 127.5, tf.uint8) - tf.summary.image('generator', vis_generator) - - tf.summary.image('generator_grid', - tfgan.eval.image_grid( - vis_generator[:show_num], - grid_shape=utils.squarest_grid_size(show_num), - image_shape=(128, 128))) - - d_loss_real = _get_d_real_loss( - discriminator_on_data_logits) - d_loss_fake = _get_d_fake_loss(discriminator_on_generator_logits) - g_loss_gan = _get_g_loss(discriminator_on_generator_logits) - - d_loss = d_loss_real + d_loss_fake - g_loss = g_loss_gan - - logit_discriminator_on_data = tf.reduce_mean(discriminator_on_data_logits) - logit_discriminator_on_generator = tf.reduce_mean( - discriminator_on_generator_logits) - - tf.summary.scalar('d_loss', d_loss) - tf.summary.scalar('d_loss_real', d_loss_real) - tf.summary.scalar('d_loss_fake', d_loss_fake) - tf.summary.scalar('g_loss', g_loss) - tf.summary.scalar('logit_real', logit_discriminator_on_data) - tf.summary.scalar('logit_fake', logit_discriminator_on_generator) - - if gpu_idx == 0: - self.d_loss_reals = [] - self.d_loss_fakes = [] - self.d_losses = [] - self.g_losses = [] - self.d_loss_reals.append(d_loss_real) - self.d_loss_fakes.append(d_loss_fake) - self.d_losses.append(d_loss) - self.g_losses.append(g_loss) - - if gpu_idx == 0: - self.get_vars() - for var in self.sigma_ratio_vars: - tf.summary.scalar(var.name, var) - - def get_vars(self): - t_vars = tf.trainable_variables() - self.d_vars = [var for var in t_vars if var.name.startswith('model/d_')] - self.g_vars = [var for var in t_vars if var.name.startswith('model/g_')] - self.sigma_ratio_vars = [var for var in t_vars if 'sigma_ratio' in 
var.name] - self.all_vars = t_vars diff --git a/tmp/ops.py b/tmp/ops.py deleted file mode 100644 index 3fd8897..0000000 --- a/tmp/ops.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from contextlib import contextmanager -import tensorflow as tf - - -@contextmanager -def variables_on_gpu0(): - old_fn = tf.get_variable - def new_fn(*args, **kwargs): - with tf.device('/gpu:0'): - return old_fn(*args, **kwargs) - tf.get_variable = new_fn - yield - tf.get_variable = old_fn - - -def avg_grads(tower_grads): - average_grads = [] - for grad_and_vars in zip(*tower_grads): - grads = [] - for g, _ in grad_and_vars: - expanded_g = tf.expand_dims(g, 0) - - grads.append(expanded_g) - - grad = tf.concat(grads, 0) - grad = tf.reduce_mean(grad, 0) - v = grad_and_vars[0][1] - grad_and_var = (grad, v) - average_grads.append(grad_and_var) - return average_grads diff --git a/tmp/sn_ops.py b/tmp/sn_ops.py deleted file mode 100644 index be94ab7..0000000 --- a/tmp/sn_ops.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import tensorflow as tf - -NO_OPS = 'NO_OPS' - - -def _l2normalize(v, eps=1e-12): - return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) - - -def spectral_normed_weight(weights, u=None, num_iters=1, update_collection=None, - with_sigma=False): - w_shape = weights.shape.as_list() - w_mat = tf.reshape(weights, [-1, w_shape[-1]]) - if u is None: - u = tf.get_variable('u', [1, w_shape[-1]], - initializer=tf.truncated_normal_initializer(), - trainable=False) - _u = u - for _ in range(num_iters): - _v = _l2normalize(tf.matmul(_u, w_mat, transpose_b=True)) - _u = _l2normalize(tf.matmul(_v, w_mat)) - - sigma = tf.squeeze(tf.matmul(tf.matmul(_v, w_mat), _u, transpose_b=True)) - w_mat = w_mat / sigma - if update_collection is None: - with tf.control_dependencies([u.assign(_u)]): - w_bar = tf.reshape(w_mat, w_shape) - else: - w_bar = tf.reshape(w_mat, w_shape) - if update_collection != NO_OPS: - tf.add_to_collection(update_collection, u.assign(_u)) - if with_sigma: - return w_bar, sigma - else: - return w_bar - - -def snconv2d(input_, output_dim, - k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, - sn_iters=1, update_collection=None, name='snconv2d'): - with tf.variable_scope(name): - w = tf.get_variable( - 'w', [k_h, k_w, input_.get_shape()[-1], output_dim], - initializer=tf.contrib.layers.xavier_initializer()) - w_bar = spectral_normed_weight(w, num_iters=sn_iters, - update_collection=update_collection) - conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME') - biases = 
tf.get_variable('biases', [output_dim], - initializer=tf.zeros_initializer()) - conv = tf.nn.bias_add(conv, biases) - return conv - - -def snlinear(x, output_size, mean=0., stddev=0.02, bias_start=0.0, - sn_iters=1, update_collection=None, name='snlinear'): - shape = x.get_shape().as_list() - - with tf.variable_scope(name): - matrix = tf.get_variable( - 'Matrix', [shape[1], output_size], tf.float32, - tf.contrib.layers.xavier_initializer()) - matrix_bar = spectral_normed_weight(matrix, num_iters=sn_iters, - update_collection=update_collection) - bias = tf.get_variable( - 'bias', [output_size], initializer=tf.constant_initializer(bias_start)) - out = tf.matmul(x, matrix_bar) + bias - return out - - -def sn_embedding(x, number_classes, embedding_size, sn_iters=1, - update_collection=None, - name='snembedding'): - with tf.variable_scope(name): - embedding_map = tf.get_variable( - name='embedding_map', - shape=[number_classes, embedding_size], - initializer=tf.contrib.layers.xavier_initializer()) - embedding_map_bar_transpose = spectral_normed_weight( - tf.transpose(embedding_map), num_iters=sn_iters, - update_collection=update_collection) - embedding_map_bar = tf.transpose(embedding_map_bar_transpose) - return tf.nn.embedding_lookup(embedding_map_bar, x) - - -class ConditionalBatchNorm(object): - - def __init__(self, num_categories, name='conditional_batch_norm', center=True, - scale=True): - with tf.variable_scope(name): - self.name = name - self.num_categories = num_categories - self.center = center - self.scale = scale - - def __call__(self, inputs, labels): - inputs = tf.convert_to_tensor(inputs) - inputs_shape = inputs.get_shape() - params_shape = inputs_shape[-1:] - axis = [0, 1, 2] - shape = tf.TensorShape([self.num_categories]).concatenate(params_shape) - - with tf.variable_scope(self.name): - self.gamma = tf.get_variable( - 'gamma', shape, - initializer=tf.ones_initializer()) - self.beta = tf.get_variable( - 'beta', shape, - 
initializer=tf.zeros_initializer()) - beta = tf.gather(self.beta, labels) - beta = tf.expand_dims(tf.expand_dims(beta, 1), 1) - gamma = tf.gather(self.gamma, labels) - gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1) - mean, variance = tf.nn.moments(inputs, axis, keep_dims=True) - variance_epsilon = 1E-5 - outputs = tf.nn.batch_normalization( - inputs, mean, variance, beta, gamma, variance_epsilon) - outputs.set_shape(inputs_shape) - return outputs - - -class BatchNorm(object): - - def __init__(self, name='batch_norm', center=True, - scale=True): - with tf.variable_scope(name): - self.name = name - self.center = center - self.scale = scale - - def __call__(self, inputs): - inputs = tf.convert_to_tensor(inputs) - inputs_shape = inputs.get_shape().as_list() - params_shape = inputs_shape[-1] - axis = [0, 1, 2] - shape = tf.TensorShape([params_shape]) - with tf.variable_scope(self.name): - self.gamma = tf.get_variable( - 'gamma', shape, - initializer=tf.ones_initializer()) - self.beta = tf.get_variable( - 'beta', shape, - initializer=tf.zeros_initializer()) - beta = self.beta - gamma = self.gamma - - mean, variance = tf.nn.moments(inputs, axis, keep_dims=True) - variance_epsilon = 1E-5 - outputs = tf.nn.batch_normalization( - inputs, mean, variance, beta, gamma, variance_epsilon) - outputs.set_shape(inputs_shape) - return outputs diff --git a/tmp/train_imagenet.py b/tmp/train_imagenet.py deleted file mode 100644 index 379fa14..0000000 --- a/tmp/train_imagenet.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -from absl import flags -import tensorflow as tf -import model -import utils_in as utils - -tfgan = tf.contrib.gan - - -flags.DEFINE_string('master', 'local', - 'BNS name of the TensorFlow master to use. [local]') -flags.DEFINE_string('checkpoint_dir', 'checkpoint', - 'Directory name to save the checkpoints. [checkpoint]') -flags.DEFINE_integer('batch_size', 32, 'Number of images in input batch. [64]') -flags.DEFINE_integer('shuffle_buffer_size', 100000, 'Number of records to load ' - 'before shuffling and yielding for consumption. [100000]') -flags.DEFINE_integer('save_summaries_steps', 1, 'Number of seconds between ' - 'saving summary statistics. [1]') -flags.DEFINE_integer('save_checkpoint_secs', 1200, 'Number of seconds between ' - 'saving checkpoints of model. [1200]') -flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]') -flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]') -flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]') -flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of replicas ' - 'to aggregate for synchronous optimization [1]') -flags.DEFINE_boolean('sync_replicas', True, 'Whether to sync replicas. [True]') -flags.DEFINE_integer('num_towers', 1, 'The number of GPUs to use per task. 
[1]') - -FLAGS = flags.FLAGS - - -def main(_): - tf.gfile.MakeDirs(FLAGS.checkpoint_dir) - model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size) - logdir = os.path.join(FLAGS.checkpoint_dir, model_dir) - tf.gfile.MakeDirs(logdir) - graph = tf.Graph() - with graph.as_default(): - global_step = tf.train.create_global_step() - devices = ['/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)] - - noise_tensor = utils.make_z_normal( - FLAGS.num_towers, FLAGS.batch_size, FLAGS.z_dim) - - model_object = model.SNGAN( - noise_tensor=noise_tensor, - config=FLAGS, - global_step=global_step, - devices=devices) - - train_ops = tfgan.GANTrainOps( - generator_train_op=model_object.g_optim, - discriminator_train_op=model_object.d_optim, - global_step_inc_op=model_object.increment_global_step) - - session_config = tf.ConfigProto( - allow_soft_placement=True, log_device_placement=False) - - train_steps = tfgan.GANTrainSteps(1, 1) - tfgan.gan_train( - train_ops, - get_hooks_fn=tfgan.get_sequential_train_hooks( - train_steps=train_steps), - hooks=([tf.train.StopAtStepHook(num_steps=2000000)]), - logdir=logdir, - master=FLAGS.master, - is_chief=(FLAGS.task == 0), - save_summaries_steps=FLAGS.save_summaries_steps, - save_checkpoint_secs=FLAGS.save_checkpoint_secs, - config=session_config) - -if __name__ == '__main__': - tf.app.run() diff --git a/tmp/utils_in.py b/tmp/utils_in.py deleted file mode 100644 index 65dea64..0000000 --- a/tmp/utils_in.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math -import os -import sympy -import numpy as np -import tensorflow as tf - - -def make_z_normal(num_batches, batch_size, z_dim): - shape = [num_batches, batch_size, z_dim] - z = tf.random_normal(shape, name='z0', dtype=tf.float32) - return z - - -def get_imagenet_batches(data_dir, - batch_size, - num_batches, - label_offset=0, - cycle_length=1, - shuffle_buffer_size=100000): - filenames = tf.gfile.Glob(os.path.join(data_dir, '*_train_*-*-of-*')) - filename_dataset = tf.data.Dataset.from_tensor_slices(filenames) - filename_dataset = filename_dataset.shuffle(len(filenames)) - prefetch = max(int((batch_size * num_batches) / cycle_length), 1) - dataset = filename_dataset.interleave( - lambda fn: tf.data.RecordIODataset(fn).prefetch(prefetch), - cycle_length=cycle_length) - - dataset = dataset.shuffle(shuffle_buffer_size) - image_size = 128 - - def _extract_image_and_label(record): - features = tf.parse_single_example( - record, - features={ - 'image_raw': tf.FixedLenFeature([], tf.string), - 'label': tf.FixedLenFeature([], tf.int64) - }) - - image = tf.decode_raw(features['image_raw'], tf.uint8) - image.set_shape(image_size * image_size * 3) - image = tf.reshape(image, [image_size, image_size, 3]) - - image = tf.cast(image, tf.float32) * (2. / 255) - 1. 
- - label = tf.cast(features['label'], tf.int32) - label += label_offset - - return image, label - - dataset = dataset.map( - _extract_image_and_label, - num_parallel_calls=16).prefetch(batch_size * num_batches) - dataset = dataset.repeat() - dataset = dataset.batch(batch_size) - dataset = dataset.batch(num_batches) - - iterator = dataset.make_one_shot_iterator() - images, labels = iterator.get_next() - - batches = [] - for i in range(num_batches): - im = images[i, ...] - im.set_shape([batch_size, image_size, image_size, 3]) - - lb = labels[i, ...] - lb.set_shape((batch_size,)) - - batches.append((im, tf.expand_dims(lb, 1))) - - return batches - - -def squarest_grid_size(num_images): - divisors = sympy.divisors(num_images) - square_root = math.sqrt(num_images) - width = 1 - for d in divisors: - if d > square_root: - break - width = d - return (num_images // width, width) diff --git a/train_imagenet.py b/train_imagenet.py index 49b31bf..f353f35 100644 --- a/train_imagenet.py +++ b/train_imagenet.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + """Train Imagenet.""" from __future__ import absolute_import from __future__ import division @@ -38,7 +39,7 @@ # 'Directory name to save the checkpoints. [checkpoint]') flags.DEFINE_string('checkpoint_dir', 'checkpoint', 'Directory name to save the checkpoints. [checkpoint]') -flags.DEFINE_integer('batch_size', 32, 'Number of images in input batch. [64]') # ori 16 +flags.DEFINE_integer('batch_size', 64, 'Number of images in input batch. [64]') # ori 16 flags.DEFINE_integer('shuffle_buffer_size', 100000, 'Number of records to load ' 'before shuffling and yielding for consumption. 
[100000]') flags.DEFINE_integer('save_summaries_steps', 200, 'Number of seconds between ' @@ -68,6 +69,7 @@ def main(_, is_test=False): print('d_learning_rate', FLAGS.discriminator_learning_rate) print('g_learning_rate', FLAGS.generator_learning_rate) + print('data_dir', FLAGS.data_dir) print(FLAGS.loss_type, FLAGS.batch_size, FLAGS.beta1) print('gf_df_dim', FLAGS.gf_dim, FLAGS.df_dim) print('Starting the program..') @@ -128,6 +130,7 @@ def main(_, is_test=False): print("G step: ", FLAGS.g_step) print("D_step: ", FLAGS.d_step) train_steps = tfgan.GANTrainSteps(FLAGS.g_step, FLAGS.d_step) + tfgan.gan_train( train_ops, get_hooks_fn=tfgan.get_sequential_train_hooks( @@ -135,6 +138,7 @@ def main(_, is_test=False): hooks=([tf.train.StopAtStepHook(num_steps=2000000)] + sync_hooks), logdir=logdir, # master=FLAGS.master, + # scaffold=scaffold, # load from google checkpoint is_chief=(FLAGS.task == 0), save_summaries_steps=FLAGS.save_summaries_steps, save_checkpoint_secs=FLAGS.save_checkpoint_secs, diff --git a/utils.py b/utils.py deleted file mode 100644 index 0a49853..0000000 --- a/utils.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Utility functions used in the sngan.""" -import numpy as np -import tensorflow as tf -import math -import os -import numpy as np -import scipy.misc -import sympy -import tensorflow as tf - -tfgan = tf.contrib.gan - - - - -def get_inception_scores(images, batch_size, num_inception_images): - """Gets Inception score for some images. - - Args: - images: Image minibatch. Shape [batch size, width, height, channels]. Values - are in [-1, 1]. - batch_size: Python integer. Batch dimension. - num_inception_images: Number of images to run through Inception at once. - - Returns: - Inception scores. Tensor shape is [batch size]. - - Raises: - ValueError: If `batch_size` is incompatible with the first dimension of - `images`. - ValueError: If `batch_size` isn't divisible by `num_inception_images`. - """ - # Validate inputs. - images.shape[0:1].assert_is_compatible_with([batch_size]) - if batch_size % num_inception_images != 0: - raise ValueError( - "`batch_size` must be divisible by `num_inception_images`.") - - # Resize images. - size = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE - resized_images = tf.image.resize_bilinear(images, [size, size]) - - # Run images through Inception. - num_batches = batch_size // num_inception_images - inc_score = tfgan.eval.inception_score( - resized_images, num_batches=num_batches) - - return inc_score - - -def get_frechet_inception_distance(real_images, generated_images, batch_size, - num_inception_images): - """Gets Frechet Inception Distance between real and generated images. - - Args: - real_images: Real images minibatch. Shape [batch size, width, height, - channels. Values are in [-1, 1]. - generated_images: Generated images minibatch. Shape [batch size, width, - height, channels]. Values are in [-1, 1]. - batch_size: Python integer. Batch dimension. - num_inception_images: Number of images to run through Inception at once. 
- - Returns: - Frechet Inception distance. A floating-point scalar. - - Raises: - ValueError: If the minibatch size is known at graph construction time, and - doesn't batch `batch_size`. - """ - # Validate input dimensions. - real_images.shape[0:1].assert_is_compatible_with([batch_size]) - generated_images.shape[0:1].assert_is_compatible_with([batch_size]) - - # Resize input images. - size = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE - resized_real_images = tf.image.resize_bilinear(real_images, [size, size]) - resized_generated_images = tf.image.resize_bilinear( - generated_images, [size, size]) - - # Compute Frechet Inception Distance. - num_batches = batch_size // num_inception_images - fid = tfgan.eval.frechet_inception_distance( - resized_real_images, resized_generated_images, num_batches=num_batches) - - return fid - - -def simple_summary(summary_writer, summary_dict, step): - """Takes a dictionary w/ scalar values. Prints it and logs to tensorboard. - - Args: - summary_writer: The tensorflow summary writer. - summary_dict: The dictionary contains all the key-value pairs of data. - step: The current logging step. - """ - print(("Step: %s\t" % step) + "\t".join( - ["%s: %s" % (k, v) for k, v in summary_dict.items()] - )) - summary = tf.Summary(value=[ - tf.Summary.Value(tag=k, simple_value=v) for k, v in summary_dict.items() - ]) - summary_writer.add_summary(summary, global_step=step) - - -def calculate_inception_score(images, batch_size, num_inception_images): - """The function wraps the inception score calculation. - - Args: - images: The numpy array of images, pixel value is from [-1, 1]. - batch_size: The total number of images for inception score calculation. - num_inception_images: Number of images to run through Inception at once. - Returns: - Inception Score. A floating-point scalar. 
- """ - with tf.Graph().as_default(), tf.Session() as sess: - images_op = tf.convert_to_tensor( - np.array(images) - ) - score_op = get_inception_scores(images_op, batch_size, num_inception_images) - return sess.run(score_op) - - -def calculate_fid(real_images, generated_images, batch_size, - num_inception_images): - """The function wraps to the FID calculation. - - Args: - real_images: The numpy array of real images, pixel value is from [-1, 1]. - generated_images: The numpy array of generated images, pixel value is - from [-1, 1]. - batch_size: The total number of images for inception score calculation - num_inception_images: Number of images to run through Inception at once. - Returns: - Frechet Inception distance. A floating-point scalar. - """ - with tf.Graph().as_default(), tf.Session() as sess: - real_images_op = tf.convert_to_tensor( - np.array(real_images) - ) - generated_images_op = tf.convert_to_tensor( - np.array(generated_images) - ) - fid = get_frechet_inception_distance(real_images_op, generated_images_op, - batch_size, num_inception_images) - return sess.run(fid) diff --git a/utils_ori.py b/utils_ori.py index 88b6395..aa4d7f4 100644 --- a/utils_ori.py +++ b/utils_ori.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + """Some codes from https://github.com/Newmu/dcgan_code.""" from __future__ import absolute_import @@ -50,7 +51,8 @@ def run_custom_inception( images, output_tensor, graph_def=None, - image_size=classifier_metrics.INCEPTION_DEFAULT_IMAGE_SIZE): + # image_size=classifier_metrics.INCEPTION_DEFAULT_IMAGE_SIZE): + image_size=299): # input_tensor=classifier_metrics.INCEPTION_V1_INPUT): """Run images through a pretrained Inception classifier.