Add do_validation. #58

Merged 3 commits on Jan 17, 2018
Changes from all commits
181 changes: 107 additions & 74 deletions tensorflow/machine_translation.py
@@ -25,7 +25,7 @@

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--word_vector_dim",
"--embedding_dim",
type=int,
default=512,
help="The dimension of embedding table. (default: %(default)d)")
@@ -43,7 +43,7 @@
"--batch_size",
type=int,
default=128,
help="The sequence number of a batch data. (default: %(default)d)")
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--dict_size",
type=int,
@@ -56,7 +56,7 @@
default=81,
help="Max number of time steps for sequence. (default: %(default)d)")
parser.add_argument(
"--pass_number",
"--pass_num",
type=int,
default=10,
help="The pass number to train. (default: %(default)d)")
@@ -66,11 +66,7 @@
default=0.0002,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--mode",
type=str,
default='train',
choices=['train', 'infer'],
help="Do training or inference. (default: %(default)s)")
"--infer_only", action='store_true', help="If set, run forward only.")
parser.add_argument(
"--beam_size",
type=int,
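
The old `--mode` choice becomes a boolean switch in this hunk. For context, `action='store_true'` makes the flag default to `False`, so a bare invocation still trains; a minimal, self-contained sketch of the dispatch (the stub `train()`/`infer()` bodies stand in for the real functions later in this file):

```python
import argparse

def train():
    print("training")   # placeholder for the real train() below

def infer():
    print("inference")  # placeholder for the real infer() below

parser = argparse.ArgumentParser()
parser.add_argument(
    "--infer_only", action='store_true', help="If set, run forward only.")

args = parser.parse_args()
if args.infer_only:     # passing --infer_only flips the default False to True
    infer()
else:
    train()
```
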
@@ -80,7 +76,7 @@
"--max_generation_length",
type=int,
default=250,
help="The max length of sequence when doing generation. "
help="The maximum length of sequence when doing generation. "
"(default: %(default)d)")
parser.add_argument(
"--save_freq",
@@ -201,7 +197,7 @@ def _simple_attention(self, encoder_vec, encoder_proj, decoder_state):
inputs=tf.reshape(
concated, shape=[-1, self._num_units * 2]),
num_outputs=1,
activation_fn=None,
activation_fn=tf.nn.tanh,
biases_initializer=None)
attention_weights_reshaped = tf.reshape(
attention_weights, shape=[tf.shape(encoder_vec)[0], -1, 1])
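
This hunk adds a `tanh` around the per-step attention score in `_simple_attention`, bounding each unnormalized score to (-1, 1) before normalization. A small NumPy sketch of the resulting scoring rule; the trailing softmax is an assumption, since the normalization code sits outside this diff:

```python
import numpy as np

def attention_weights(encoder_proj, decoder_state, w):
    # encoder_proj: (T, d) projected encoder states; decoder_state: (d,)
    T = encoder_proj.shape[0]
    concated = np.concatenate(
        [encoder_proj, np.tile(decoder_state, (T, 1))], axis=1)  # (T, 2d)
    scores = np.tanh(concated @ w)        # the nonlinearity this hunk adds
    exp = np.exp(scores - scores.max())   # numerically stable softmax
    return exp / exp.sum()

np.random.seed(0)
alpha = attention_weights(np.random.randn(5, 8),
                          np.random.randn(8),
                          np.random.randn(16))
print(alpha, alpha.sum())  # five weights over time steps, summing to 1.0
```
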
@@ -267,19 +263,14 @@ def _maybe_mask(m, seq_len_mask):
memory)


def seq_to_seq_net(word_vector_dim,
encoder_size,
decoder_size,
source_dict_dim,
target_dict_dim,
is_generating=False,
beam_size=3,
max_generation_length=250):
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
target_dict_dim, is_generating, beam_size,
max_generation_length):
src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

src_embedding_weights = tf.get_variable("source_word_embeddings",
[source_dict_dim, word_vector_dim])
[source_dict_dim, embedding_dim])
src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
@@ -298,7 +289,7 @@ def seq_to_seq_net(word_vector_dim,
# project the encoder outputs to size of decoder lstm
encoded_proj = tf.contrib.layers.fully_connected(
inputs=tf.reshape(
encoded_vec, shape=[-1, word_vector_dim * 2]),
encoded_vec, shape=[-1, embedding_dim * 2]),
num_outputs=decoder_size,
activation_fn=None,
biases_initializer=None)
@@ -309,9 +300,9 @@
backword_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
decoder_boot = tf.contrib.layers.fully_connected(
inputs=tf.reshape(
backword_first, shape=[-1, word_vector_dim]),
backword_first, shape=[-1, embedding_dim]),
num_outputs=decoder_size,
activation_fn=None,
activation_fn=tf.nn.tanh,
biases_initializer=None)

# prepare the initial state for decoder lstm
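
Context for the `decoder_boot` change above: the decoder's initial state is built from the backward encoder's output at t=0, which, reading right-to-left, has already seen the whole source sentence, and it is now projected through a `tanh` activation instead of a linear one. The slice itself, as a NumPy stand-in:

```python
import numpy as np

# (batch, time, hidden) outputs of the backward encoder RNN
backward_outputs = np.arange(24).reshape(2, 3, 4)

# equivalent of tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1]):
# keep the full batch and hidden dims, take only the first time step
backward_first = backward_outputs[:, 0:1, :]
print(backward_first.shape)  # (2, 1, 4)
```
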
@@ -335,7 +326,7 @@
trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
trg_embedding_weights = tf.get_variable(
"target_word_embeddings", [target_dict_dim, word_vector_dim])
"target_word_embeddings", [target_dict_dim, embedding_dim])
trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
trg_word_idx)

@@ -381,14 +372,19 @@ def seq_to_seq_net(word_vector_dim,
average_across_batch=True)

# return the feeding dict and the loss operator
return (src_word_idx, src_sequence_length, trg_word_idx,
trg_sequence_length, lbl_word_idx), loss
return {
'src_word_idx': src_word_idx,
'src_sequence_length': src_sequence_length,
'trg_word_idx': trg_word_idx,
'trg_sequence_length': trg_sequence_length,
'lbl_word_idx': lbl_word_idx
}, loss
else:
start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
tf.int32) * START_TOKEN_IDX
# share the same embedding weights with target word
trg_embedding_weights = tf.get_variable(
"target_word_embeddings", [target_dict_dim, word_vector_dim])
"target_word_embeddings", [target_dict_dim, embedding_dim])

inference_decoder = beam_search_decoder.BeamSearchDecoder(
cell=decoder_cell,
@@ -407,8 +403,12 @@
# impute_finished=True,  # error occurs
maximum_iterations=max_generation_length)

return (src_word_idx,
src_sequence_length), decoder_outputs_decode.predicted_ids
predicted_ids = decoder_outputs_decode.predicted_ids

return {
'src_word_idx': src_word_idx,
'src_sequence_length': src_sequence_length
}, predicted_ids
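
Both branches now return their placeholders keyed by name instead of positionally, so callers can pair placeholders with batch arrays without depending on order; `train()` below does exactly this with a dict comprehension. A toy sketch of the join (plain strings and lists stand in for TF placeholders and NumPy arrays):

```python
# names here mirror the keys used in this file
feeding_dict = {'src_word_idx': 'src_placeholder',
                'src_sequence_length': 'len_placeholder'}
adapted_batch_data = {'src_word_idx': [[4, 5, 6]],
                      'src_sequence_length': [3]}

# the same comprehension train() uses to build sess.run's feed_dict
feed_dict = {item[1]: adapted_batch_data[item[0]]
             for item in feeding_dict.items()}
print(feed_dict)
# {'src_placeholder': [[4, 5, 6]], 'len_placeholder': [3]}
```
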


def print_arguments(args):
@@ -436,9 +436,43 @@ def restore(sess, path, var_list=None):
print('model restored from %s' % path)


def adapt_batch_data(data):
src_seq = map(lambda x: x[0], data)
trg_seq = map(lambda x: x[1], data)
lbl_seq = map(lambda x: x[2], data)

src_sequence_length = np.array(
[len(seq) for seq in src_seq]).astype('int32')
src_seq_maxlen = np.max(src_sequence_length)

trg_sequence_length = np.array(
[len(seq) for seq in trg_seq]).astype('int32')
trg_seq_maxlen = np.max(trg_sequence_length)

src_seq = np.array(
[padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
for seq in src_seq]).astype('int32')

trg_seq = np.array(
[padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in trg_seq]).astype('int32')

lbl_seq = np.array(
[padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in lbl_seq]).astype('int32')

return {
'src_word_idx': src_seq,
'src_sequence_length': src_sequence_length,
'trg_word_idx': trg_seq,
'trg_sequence_length': trg_sequence_length,
'lbl_word_idx': lbl_seq
}
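
`adapt_batch_data` leans on a `padding_data` helper that is defined elsewhere in this file and not visible in the diff. A hypothetical stand-in, assuming it right-pads each sequence to the batch maximum length with `END_TOKEN_IDX` (both the semantics and the token value here are assumptions):

```python
import numpy as np

END_TOKEN_IDX = 1  # illustrative value only

def padding_data(seq, maxlen, pad_idx):
    # hypothetical stand-in for the helper used by adapt_batch_data
    return list(seq) + [pad_idx] * (maxlen - len(seq))

batch = [[4, 5, 6], [7, 8]]
maxlen = max(len(s) for s in batch)
padded = np.array([padding_data(s, maxlen, END_TOKEN_IDX)
                   for s in batch]).astype('int32')
print(padded)
# [[4 5 6]
#  [7 8 1]]
```
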


def train():
train_feed_list, loss = seq_to_seq_net(
word_vector_dim=args.word_vector_dim,
feeding_dict, loss = seq_to_seq_net(
embedding_dim=args.embedding_dim,
encoder_size=args.encoder_size,
decoder_size=args.decoder_size,
source_dict_dim=args.dict_size,
@@ -459,65 +493,63 @@ def train():
zip(gradients, trainable_params), global_step=global_step)

src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)

train_batch_generator = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
batch_size=args.batch_size)

test_batch_generator = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
batch_size=args.batch_size)
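
The train and test readers mirror each other. For readers unfamiliar with the paddle pipeline: `paddle.reader.shuffle` shuffles within a bounded buffer and `paddle.batch` groups samples into lists; both wrap the result as a reader creator, hence the `test_batch_generator()` call below. An approximate plain-Python sketch of those semantics (plain generators here instead of reader creators):

```python
import random

def shuffle_reader(samples, buf_size):
    # shuffle within a bounded buffer rather than over the whole dataset
    buf = []
    for s in samples:
        buf.append(s)
        if len(buf) >= buf_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    random.shuffle(buf)
    for x in buf:
        yield x

def batch_reader(reader, batch_size):
    # group samples into fixed-size lists; the tail batch may be smaller
    batch = []
    for x in reader:
        batch.append(x)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

for b in batch_reader(shuffle_reader(range(6), buf_size=4), batch_size=2):
    print(b)
```
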

def do_validation():
total_loss = 0.0
count = 0
for batch_id, data in enumerate(test_batch_generator()):
adapted_batch_data = adapt_batch_data(data)
outputs = sess.run([loss],
feed_dict={
item[1]: adapted_batch_data[item[0]]
for item in feeding_dict.items()
})
total_loss += outputs[0]
count += 1
return total_loss / count
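
One caveat on `do_validation`: it averages per-batch mean losses, which over-weights a short final batch. If exactness matters, a count-weighted variant is a small change (a sketch only, not part of this PR):

```python
def weighted_mean_loss(batch_losses_and_sizes):
    # each item is (mean_loss_for_batch, number_of_sequences_in_batch)
    total = sum(loss * n for loss, n in batch_losses_and_sizes)
    count = sum(n for _, n in batch_losses_and_sizes)
    return total / count

# the short last batch barely moves the weighted mean
print(weighted_mean_loss([(0.5, 128), (0.3, 128), (0.9, 16)]))  # ~0.429
```
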

config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_l)
sess.run(init_g)
for pass_id in xrange(args.pass_number):
for pass_id in xrange(args.pass_num):
pass_start_time = time.time()
words_seen = 0
for batch_id, data in enumerate(train_batch_generator()):
src_seq = map(lambda x: x[0], data)
trg_seq = map(lambda x: x[1], data)
lbl_seq = map(lambda x: x[2], data)

src_sequence_length = np.array(
[len(seq) for seq in src_seq]).astype('int32')
src_seq_maxlen = np.max(src_sequence_length)
trg_sequence_length = np.array(
[len(seq) for seq in trg_seq]).astype('int32')
trg_seq_maxlen = np.max(trg_sequence_length)

src_seq = np.array([
padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
for seq in src_seq
]).astype('int32')
trg_seq = np.array([
padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in trg_seq
]).astype('int32')
lbl_seq = np.array([
padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in lbl_seq
]).astype('int32')

adapted_batch_data = adapt_batch_data(data)
words_seen += np.sum(adapted_batch_data['src_sequence_length'])
words_seen += np.sum(adapted_batch_data['trg_sequence_length'])
outputs = sess.run([updates, loss],
feed_dict={
train_feed_list[0]: src_seq,
train_feed_list[1]: src_sequence_length,
train_feed_list[2]: trg_seq,
train_feed_list[3]: trg_sequence_length,
train_feed_list[4]: lbl_seq
item[1]: adapted_batch_data[item[0]]
for item in feeding_dict.items()
})

print("pass_id=%d, batch_id=%d, loss=%f" %
print("pass_id=%d, batch_id=%d, train_loss: %f" %
(pass_id, batch_id, outputs[1]))

if global_step.eval() % args.save_freq == 0:
print('Saving model..')
checkpoint_path = os.path.join(args.model_dir, 'tf_seq2seq')
save(sess, checkpoint_path, global_step=global_step)
pass_end_time = time.time()
test_loss = do_validation()
time_consumed = pass_end_time - pass_start_time
words_per_sec = words_seen / time_consumed
print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
(pass_id, test_loss, words_per_sec, time_consumed))


def infer():
infer_feed_list, predicted_ids = seq_to_seq_net(
word_vector_dim=args.word_vector_dim,
feeding_dict, predicted_ids = seq_to_seq_net(
embedding_dim=args.embedding_dim,
encoder_size=args.encoder_size,
decoder_size=args.decoder_size,
source_dict_dim=args.dict_size,
@@ -553,8 +585,9 @@ def infer():

outputs = sess.run([predicted_ids],
feed_dict={
infer_feed_list[0].name: src_seq,
infer_feed_list[1].name: src_sequence_length
feeding_dict['src_word_idx']: src_seq,
feeding_dict['src_sequence_length']:
src_sequence_length
})

print("\nDecoder result comparison: ")
@@ -571,7 +604,7 @@
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
if args.mode == 'train':
train()
else:
if args.infer_only:
infer()
else:
train()