forked from martin-gorner/tensorflow-rnn-shakespeare
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rnn_train.py
213 lines (186 loc) · 9.96 KB
/
rnn_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# encoding: UTF-8
# Copyright 2017 Google.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import rnn # rnn stuff temporarily in contrib, moving back to code in TF 1.1
import os
import time
import math
import numpy as np
import my_txtutils as txt
# Fix the graph-level random seed so that runs are repeatable.
tf.set_random_seed(0)
# model parameters
#
# Usage:
#   Training only:
#         Leave all the parameters as they are
#         Disable validation to run a bit faster (set validation=False below)
#         You can follow progress in Tensorboard: tensorboard --logdir=log
#   Training and experimentation (default):
#         Keep validation enabled
#         You can now play with the parameters and follow the effects in Tensorboard
#         A good choice of parameters ensures that the training and validation curves stay close
#         To see the curves drift apart ("overfitting") try to use an insufficient amount of
#         training data (shakedir = "shakespeare/t*.txt" for example)
#
# SEQLEN: number of characters per training sequence (the RNN is unrolled this many steps)
# BATCHSIZE: number of sequences per training minibatch
# ALPHASIZE: size of the character alphabet, taken from my_txtutils
# INTERNALSIZE: number of units in each GRU cell
# NLAYERS: number of stacked GRU layers
SEQLEN = 30
BATCHSIZE = 200
ALPHASIZE = txt.ALPHASIZE
INTERNALSIZE = 512
NLAYERS = 3
learning_rate = 0.001 # fixed learning rate
dropout_pkeep = 0.8 # some dropout
# load data, either shakespeare, or the Python source of Tensorflow itself
# NOTE(review): this copy points at "papadiamantis/*.txt" and passes validation=False;
# presumably that makes valitext empty, in which case the validation step of the
# training loop (guarded by len(valitext) > 0) never runs -- confirm in my_txtutils.
shakedir = "papadiamantis/*.txt"
# shakedir = "../tensorflow/**/*.py"
codetext, valitext, bookranges = txt.read_data_files(shakedir, validation=False)
# display some stats on the data
# epoch_size is the number of full [BATCHSIZE x SEQLEN] minibatches that fit in the training text
epoch_size = len(codetext) // (BATCHSIZE * SEQLEN)
txt.print_data_stats(len(codetext), len(valitext), epoch_size)
#
# the model (see FAQ in README.md)
#
# Builds the TensorFlow graph: placeholders, a stack of GRU cells with dropout,
# a linear + softmax readout shared across time steps, the training op and the
# statistics used for display and Tensorboard.
lr = tf.placeholder(tf.float32, name='lr') # learning rate
pkeep = tf.placeholder(tf.float32, name='pkeep') # dropout parameter
# The batch size is fed explicitly because the same graph is run with BATCHSIZE
# sequences for training, a different count for validation and 1 for text generation.
batchsize = tf.placeholder(tf.int32, name='batchsize')
# inputs
X = tf.placeholder(tf.uint8, [None, None], name='X') # [ BATCHSIZE, SEQLEN ]
Xo = tf.one_hot(X, ALPHASIZE, 1.0, 0.0) # [ BATCHSIZE, SEQLEN, ALPHASIZE ]
# expected outputs = same sequence shifted by 1 since we are trying to predict the next character
Y_ = tf.placeholder(tf.uint8, [None, None], name='Y_') # [ BATCHSIZE, SEQLEN ]
Yo_ = tf.one_hot(Y_, ALPHASIZE, 1.0, 0.0) # [ BATCHSIZE, SEQLEN, ALPHASIZE ]
# input state
Hin = tf.placeholder(tf.float32, [None, INTERNALSIZE*NLAYERS], name='Hin') # [ BATCHSIZE, INTERNALSIZE * NLAYERS]
# using a NLAYERS=3 layers of GRU cells, unrolled SEQLEN=30 times
# dynamic_rnn infers SEQLEN from the size of the inputs Xo
# How to properly apply dropout in RNNs: see README.md
cells = [rnn.GRUCell(INTERNALSIZE) for _ in range(NLAYERS)]
# "naive dropout" implementation
dropcells = [rnn.DropoutWrapper(cell,input_keep_prob=pkeep) for cell in cells]
# state_is_tuple=False: the state of all layers is handled as one concatenated
# tensor, matching the [ BATCHSIZE, INTERNALSIZE*NLAYERS ] shape of Hin
multicell = rnn.MultiRNNCell(dropcells, state_is_tuple=False)
multicell = rnn.DropoutWrapper(multicell, output_keep_prob=pkeep) # dropout for the softmax layer
Yr, H = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=Hin)
# Yr: [ BATCHSIZE, SEQLEN, INTERNALSIZE ]
# H: [ BATCHSIZE, INTERNALSIZE*NLAYERS ] # this is the last state in the sequence
H = tf.identity(H, name='H') # just to give it a name
# Softmax layer implementation:
# Flatten the first two dimensions of the output [ BATCHSIZE, SEQLEN, INTERNALSIZE ] => [ BATCHSIZE x SEQLEN, INTERNALSIZE ]
# then apply softmax readout layer. This way, the weights and biases are shared across unrolled time steps.
# From the readout point of view, a value coming from a sequence time step or a minibatch item is the same thing.
Yflat = tf.reshape(Yr, [-1, INTERNALSIZE]) # [ BATCHSIZE x SEQLEN, INTERNALSIZE ]
Ylogits = layers.linear(Yflat, ALPHASIZE) # [ BATCHSIZE x SEQLEN, ALPHASIZE ]
Yflat_ = tf.reshape(Yo_, [-1, ALPHASIZE]) # [ BATCHSIZE x SEQLEN, ALPHASIZE ]
loss = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits, labels=Yflat_) # [ BATCHSIZE x SEQLEN ]
loss = tf.reshape(loss, [batchsize, -1]) # [ BATCHSIZE, SEQLEN ]
Yo = tf.nn.softmax(Ylogits, name='Yo') # [ BATCHSIZE x SEQLEN, ALPHASIZE ]
Y = tf.argmax(Yo, 1) # [ BATCHSIZE x SEQLEN ]
Y = tf.reshape(Y, [batchsize, -1], name="Y") # [ BATCHSIZE, SEQLEN ]
# NOTE(review): `loss` is a [ BATCHSIZE, SEQLEN ] tensor here, not a scalar; presumably
# the optimizer ends up minimizing the sum of its elements -- confirm against the TF docs.
train_step = tf.train.AdamOptimizer(lr).minimize(loss)
# stats for display
seqloss = tf.reduce_mean(loss, 1) # mean loss of each sequence: [ BATCHSIZE ]
batchloss = tf.reduce_mean(seqloss) # scalar: mean loss over the whole minibatch
# fraction of characters predicted correctly; Y is cast to uint8 so it can be compared with Y_
accuracy = tf.reduce_mean(tf.cast(tf.equal(Y_, tf.cast(Y, tf.uint8)), tf.float32))
loss_summary = tf.summary.scalar("batch_loss", batchloss)
acc_summary = tf.summary.scalar("batch_accuracy", accuracy)
summaries = tf.summary.merge([loss_summary, acc_summary])
# Init Tensorboard stuff. This will save Tensorboard information into two different
# folders at each run, named 'log/<timestamp>-training' and 'log/<timestamp>-validation',
# so that you can compare training and validation curves visually in Tensorboard.
# One timestamp per run: it names both Tensorboard log folders and the checkpoint files.
timestamp = str(math.trunc(time.time()))
summary_writer = tf.summary.FileWriter("log/" + timestamp + "-training")
validation_writer = tf.summary.FileWriter("log/" + timestamp + "-validation")

# Init for saving models. They will be saved into a directory named 'checkpoints'.
# Up to 1000 of the most recent checkpoints are kept (max_to_keep below).
# makedirs(exist_ok=True) creates the directory only when it is missing and, unlike
# an exists()-then-mkdir() pair, cannot fail if the directory appears in between.
os.makedirs("checkpoints", exist_ok=True)
saver = tf.train.Saver(max_to_keep=1000)

# for display: init the progress bar
DISPLAY_FREQ = 50
# `step` counts characters, not batches, so "every 50 batches" is expressed as
# a number of characters: 50 batches of BATCHSIZE sequences of SEQLEN characters.
_50_BATCHES = DISPLAY_FREQ * BATCHSIZE * SEQLEN
progress = txt.Progress(DISPLAY_FREQ, size=111+2, msg="Training on next "+str(DISPLAY_FREQ)+" batches")

# init
istate = np.zeros([BATCHSIZE, INTERNALSIZE*NLAYERS])  # initial zero input state
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
step = 0
# training loop
# Sequence length for validation. State will be wrong at the start of each sequence.
VALI_SEQLEN = 1*1024

# The loop is wrapped in try/finally so that the Tensorboard event files are flushed
# and the session is released even when training is interrupted (e.g. Ctrl-C) or fails.
try:
    for x, y_, epoch in txt.rnn_minibatch_sequencer(codetext, BATCHSIZE, SEQLEN, nb_epochs=10):

        # train on one minibatch
        feed_dict = {X: x, Y_: y_, Hin: istate, lr: learning_rate, pkeep: dropout_pkeep, batchsize: BATCHSIZE}
        _, y, ostate = sess.run([train_step, Y, H], feed_dict=feed_dict)

        # log training data for Tensorboard display a mini-batch of sequences (every 50 batches)
        if step % _50_BATCHES == 0:
            feed_dict = {X: x, Y_: y_, Hin: istate, pkeep: 1.0, batchsize: BATCHSIZE}  # no dropout for validation
            y, l, bl, acc, smm = sess.run([Y, seqloss, batchloss, accuracy, summaries], feed_dict=feed_dict)
            txt.print_learning_learned_comparison(x, y, l, bookranges, bl, acc, epoch_size, step, epoch)
            summary_writer.add_summary(smm, step)

        # run a validation step every 50 batches
        # The validation text should be a single sequence but that's too slow (1s per 1024 chars!),
        # so we cut it up and batch the pieces (slightly inaccurate)
        # tested: validating with 5K sequences instead of 1K is only slightly more accurate, but a lot slower.
        # Validation is skipped when the validation text is shorter than one validation
        # sequence: bsize would be 0, i.e. an empty batch.
        if step % _50_BATCHES == 0 and len(valitext) >= VALI_SEQLEN:
            bsize = len(valitext) // VALI_SEQLEN
            txt.print_validation_header(len(codetext), bookranges)
            vali_x, vali_y, _ = next(txt.rnn_minibatch_sequencer(valitext, bsize, VALI_SEQLEN, 1))  # all data in 1 batch
            vali_nullstate = np.zeros([bsize, INTERNALSIZE*NLAYERS])
            feed_dict = {X: vali_x, Y_: vali_y, Hin: vali_nullstate, pkeep: 1.0,  # no dropout for validation
                         batchsize: bsize}
            ls, acc, smm = sess.run([batchloss, accuracy, summaries], feed_dict=feed_dict)
            txt.print_validation_stats(ls, acc)
            # save validation data for Tensorboard
            validation_writer.add_summary(smm, step)

        # display a short text generated with the current weights and biases (every 150 batches)
        if step // 3 % _50_BATCHES == 0:
            txt.print_text_generation_header()
            # generation starts from the character "K" and a zero state, one character at a time
            ry = np.array([[txt.convert_from_alphabet(ord("K"))]])
            rh = np.zeros([1, INTERNALSIZE * NLAYERS])
            for _ in range(1000):
                ryo, rh = sess.run([Yo, H], feed_dict={X: ry, pkeep: 1.0, Hin: rh, batchsize: 1})
                # sample among the 10 most probable characters during the first epochs, then among the top 2
                rc = txt.sample_from_probabilities(ryo, topn=10 if epoch <= 1 else 2)
                print(chr(txt.convert_to_alphabet(rc)), end="")
                # the sampled character becomes the next input
                ry = np.array([[rc]])
            txt.print_text_generation_footer()

        # save a checkpoint (every 500 batches)
        if step // 10 % _50_BATCHES == 0:
            saved_file = saver.save(sess, 'checkpoints/rnn_train_' + timestamp, global_step=step)
            print("Saved file: " + saved_file)

        # display progress bar
        progress.step(reset=step % _50_BATCHES == 0)

        # loop state around
        istate = ostate
        # step is counted in characters seen, not in batches
        step += BATCHSIZE * SEQLEN
finally:
    # close() flushes pending events to disk before releasing the files
    summary_writer.close()
    validation_writer.close()
    sess.close()
# all runs: SEQLEN = 30, BATCHSIZE = 100, ALPHASIZE = 98, INTERNALSIZE = 512, NLAYERS = 3
# run 1477669632 decaying learning rate 0.001-0.0001-1e7 dropout 0.5: not good
# run 1477670023 lr=0.001 no dropout: very good
# Tensorflow runs:
# 1485434262
# trained on shakespeare/t*.txt only. Validation on 1K sequences
# validation loss goes up from step 5M (overfitting because of small dataset)
# 1485436038
# trained on shakespeare/t*.txt only. Validation on 5K sequences
# On 5K sequences validation accuracy is slightly higher and loss slightly lower
# => sequence breaks do introduce inaccuracies but the effect is small
# 1485437956
# Trained on shakespeare/*.txt. Validation on 1K sequences
# On this much larger dataset, validation loss still decreasing after 6 epochs (step 35M)
# 1495447371
# Trained on shakespeare/*.txt no dropout, 30 epochs
# Validation loss starts going up after 10 epochs (overfitting)
# 1495440473
# Trained on shakespeare/*.txt "naive dropout" pkeep=0.8, 30 epochs
# Dropout brings the validation loss under control, preventing it from
# going up but the effect is small.