pix2pix_network.py
import numpy as np
import tensorflow as tf
from utils import *


class pix2pix_network(object):
    def __init__(self, image_A_batch, image_B_batch, batch_size, dropout_rate, weights_path=''):
        # Parse input arguments into class variables
        self.image_A = image_A_batch
        self.image_B = image_B_batch
        self.batch_size = batch_size
        self.dropout_rate = dropout_rate
        self.WEIGHTS_PATH = weights_path
        self.l1_Weight = 100.0  # lambda weighting the L1 reconstruction term of the generator loss
    def generator_output(self, image_A_input):
        # NOTE! the order of operations (as per the author's original code - https://github.com/phillipi/pix2pix) is:
        # non-linearity (if needed) -> conv -> batchnorm (if needed) for each layer
        # Encoder: eight 4x4 stride-2 convolutions
        scope_name = 'gen_e1'
        self.gen_e1 = apply_batchnorm(conv(image_A_input, 4, 4, 64, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e2'
        self.gen_e2 = apply_batchnorm(conv(lrelu(self.gen_e1, lrelu_alpha=0.2, name=scope_name), 4, 4, 128, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e3'
        self.gen_e3 = apply_batchnorm(conv(lrelu(self.gen_e2, lrelu_alpha=0.2, name=scope_name), 4, 4, 256, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e4'
        self.gen_e4 = apply_batchnorm(conv(lrelu(self.gen_e3, lrelu_alpha=0.2, name=scope_name), 4, 4, 512, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e5'
        self.gen_e5 = apply_batchnorm(conv(lrelu(self.gen_e4, lrelu_alpha=0.2, name=scope_name), 4, 4, 512, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e6'
        self.gen_e6 = apply_batchnorm(conv(lrelu(self.gen_e5, lrelu_alpha=0.2, name=scope_name), 4, 4, 512, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e7'
        self.gen_e7 = apply_batchnorm(conv(lrelu(self.gen_e6, lrelu_alpha=0.2, name=scope_name), 4, 4, 512, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_e8'
        self.gen_e8 = conv(lrelu(self.gen_e7, lrelu_alpha=0.2, name=scope_name), 4, 4, 512, 2, 2, padding='SAME', name=scope_name)
        # Decoder: eight 4x4 stride-2 transposed convolutions with U-Net skip connections
        # (each decoder layer is concatenated with the mirrored encoder layer);
        # dropout is applied to the first three decoder layers
        scope_name = 'gen_d1'
        self.gen_d1 = apply_batchnorm(deconv(lrelu(self.gen_e8, lrelu_alpha=0.0, name=scope_name), 4, 4, 512, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        self.gen_d1_dropout = dropout(self.gen_d1, self.dropout_rate)
        scope_name = 'gen_d2'
        self.gen_d2 = apply_batchnorm(deconv(lrelu(tf.concat([self.gen_d1_dropout, self.gen_e7], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 512, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        self.gen_d2_dropout = dropout(self.gen_d2, self.dropout_rate)
        scope_name = 'gen_d3'
        self.gen_d3 = apply_batchnorm(deconv(lrelu(tf.concat([self.gen_d2_dropout, self.gen_e6], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 512, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        self.gen_d3_dropout = dropout(self.gen_d3, self.dropout_rate)
        scope_name = 'gen_d4'
        self.gen_d4 = apply_batchnorm(deconv(lrelu(tf.concat([self.gen_d3_dropout, self.gen_e5], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 512, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_d5'
        self.gen_d5 = apply_batchnorm(deconv(lrelu(tf.concat([self.gen_d4, self.gen_e4], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 256, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_d6'
        self.gen_d6 = apply_batchnorm(deconv(lrelu(tf.concat([self.gen_d5, self.gen_e3], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 128, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_d7'
        self.gen_d7 = apply_batchnorm(deconv(lrelu(tf.concat([self.gen_d6, self.gen_e2], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 64, self.batch_size, 2, 2, padding='SAME', name=scope_name), name=scope_name)
        scope_name = 'gen_d8'
        self.gen_d8 = deconv(lrelu(tf.concat([self.gen_d7, self.gen_e1], 3), lrelu_alpha=0.0, name=scope_name), 4, 4, 3, self.batch_size, 2, 2, padding='SAME', name=scope_name)
        # generated output
        self.fake_B = tf.nn.tanh(self.gen_d8)
        return self.fake_B

    def discriminator_output(self, B_input):  # 70x70 discriminator
        # PatchGAN: the discriminator is conditioned on image_A by channel-wise concatenation
        discrim_input = tf.concat([self.image_A, B_input], 3)
        scope_name = 'dis_conv1'
        self.dis_conv1 = lrelu(conv(discrim_input, 4, 4, 64, 2, 2, padding='SAME', name=scope_name), lrelu_alpha=0.2, name=scope_name)
        scope_name = 'dis_conv2'
        self.dis_conv2 = lrelu(apply_batchnorm(conv(self.dis_conv1, 4, 4, 128, 2, 2, padding='SAME', name=scope_name), name=scope_name), lrelu_alpha=0.2, name=scope_name)
        scope_name = 'dis_conv3'
        self.dis_conv3 = lrelu(apply_batchnorm(conv(self.dis_conv2, 4, 4, 256, 2, 2, padding='SAME', name=scope_name), name=scope_name), lrelu_alpha=0.2, name=scope_name)
        scope_name = 'dis_conv4'
        self.dis_conv4 = lrelu(apply_batchnorm(conv(self.dis_conv3, 4, 4, 512, 1, 1, padding='SAME', name=scope_name), name=scope_name), lrelu_alpha=0.2, name=scope_name)
        scope_name = 'dis_conv5'
        self.dis_conv5 = conv(self.dis_conv4, 4, 4, 1, 1, 1, padding='SAME', name=scope_name)
        # one sigmoid score per receptive-field patch
        self.dis_out_per_patch = tf.reshape(self.dis_conv5, [self.batch_size, -1])
        return tf.sigmoid(self.dis_out_per_patch)
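
    # The losses below follow the original pix2pix objective (A: input image, B: target image):
    #   D loss: -E[log D(A, B)] - E[log(1 - D(A, G(A)))]
    #   G loss: -E[log D(A, G(A))] + l1_Weight * E[|G(A) - B|]
    # i.e. a non-saturating GAN loss for the generator plus an L1 reconstruction term.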
    def compute_loss(self):
        eps = 1e-12
        fake_B = self.generator_output(self.image_A)
        fake_output_D = self.discriminator_output(fake_B)
        real_output_D = self.discriminator_output(self.image_B)
        self.d_loss = tf.reduce_mean(-(tf.log(real_output_D + eps) + tf.log(1 - fake_output_D + eps)))
        self.g_loss_l1 = self.l1_Weight * tf.reduce_mean(tf.abs(fake_B - self.image_B))
        self.g_loss_gan = tf.reduce_mean(-tf.log(fake_output_D + eps))
        return self.d_loss, self.g_loss_l1 + self.g_loss_gan, self.g_loss_l1, self.g_loss_gan
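
    # The loader below expects WEIGHTS_PATH to point at a numpy file holding a nested dict of
    # 4-D arrays keyed by variable scope and variable name. A minimal sketch of how such a file
    # could be produced (the 'weights' key inside each scope is an illustrative assumption,
    # not a name taken from this repository):
    #   init = {'gen_e1': {'weights': np.random.randn(4, 4, 3, 64).astype(np.float32)}}
    #   np.save('pix2pix_initial_weights.npy', init)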
    def load_initial_weights(self, session):
        # Load the weights into memory: this approach is adopted rather than standard random
        # initialization to allow the flexibility to load weights from a numpy file or other files.
        if self.WEIGHTS_PATH:
            print('loading initial weights from ' + self.WEIGHTS_PATH)
            # allow_pickle is required on newer numpy versions to load a pickled dict
            weights_dict = np.load(self.WEIGHTS_PATH, encoding='bytes', allow_pickle=True).item()
            # else:
            #     print('loading random weights')
            #     weights_dict = get_random_weight_dictionary('pix2pix_initial_weights')
            # Loop over all layer names stored in the weights dict
            for op_name in weights_dict:
                print(op_name)
                with tf.variable_scope(op_name):
                    for sub_op_name in weights_dict[op_name]:
                        data = weights_dict[op_name][sub_op_name]
                        # op_name (the variable scope) replaces the undefined 'name' used here originally
                        var = get_scope_variable(op_name, sub_op_name,
                                                 shape=[data.shape[0], data.shape[1], data.shape[2], data.shape[3]])
                        session.run(var.assign(data))
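

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original network code):
# builds the graph, splits the trainable variables between generator and
# discriminator by their scope-name prefixes ('gen_*' / 'dis_*'), and creates
# one optimizer for each. The 256x256 input size, the dropout default of 0.5
# and the Adam hyper-parameters are assumptions, not values taken from this
# repository.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size = 1
    image_A = tf.placeholder(tf.float32, [batch_size, 256, 256, 3], name='image_A')
    image_B = tf.placeholder(tf.float32, [batch_size, 256, 256, 3], name='image_B')
    dropout_rate = tf.placeholder_with_default(0.5, shape=[], name='dropout_rate')

    net = pix2pix_network(image_A, image_B, batch_size, dropout_rate)
    d_loss, g_loss, g_loss_l1, g_loss_gan = net.compute_loss()

    # Update the discriminator and the generator with separate optimizers,
    # each restricted to its own variables.
    g_vars = [v for v in tf.trainable_variables() if v.name.startswith('gen_')]
    d_vars = [v for v in tf.trainable_variables() if v.name.startswith('dis_')]
    d_train_op = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(d_loss, var_list=d_vars)
    g_train_op = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(g_loss, var_list=g_vars)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        net.load_initial_weights(sess)
        # Feed (A, B) image pairs scaled to [-1, 1] and alternate d_train_op /
        # g_train_op steps, as in standard pix2pix training.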