From 5c885dd063569758103368f28087442e22b73b71 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Tue, 16 Jan 2018 12:56:51 +0800
Subject: [PATCH 1/3] Add do_validation.

---
 tensorflow/machine_translation.py | 101 ++++++++++++++++++------------
 1 file changed, 62 insertions(+), 39 deletions(-)

diff --git a/tensorflow/machine_translation.py b/tensorflow/machine_translation.py
index 040db63..3bc0b6c 100644
--- a/tensorflow/machine_translation.py
+++ b/tensorflow/machine_translation.py
@@ -436,6 +436,34 @@ def restore(sess, path, var_list=None):
     print('model restored from %s' % path)
 
 
+def adapt_batch_data(data):
+    src_seq = map(lambda x: x[0], data)
+    trg_seq = map(lambda x: x[1], data)
+    lbl_seq = map(lambda x: x[2], data)
+
+    src_sequence_length = np.array(
+        [len(seq) for seq in src_seq]).astype('int32')
+    src_seq_maxlen = np.max(src_sequence_length)
+
+    trg_sequence_length = np.array(
+        [len(seq) for seq in trg_seq]).astype('int32')
+    trg_seq_maxlen = np.max(trg_sequence_length)
+
+    src_seq = np.array(
+        [padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
+         for seq in src_seq]).astype('int32')
+
+    trg_seq = np.array(
+        [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
+         for seq in trg_seq]).astype('int32')
+
+    lbl_seq = np.array(
+        [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
+         for seq in lbl_seq]).astype('int32')
+
+    return src_seq, src_sequence_length, trg_seq, trg_sequence_length, lbl_seq
+
+
 def train():
     train_feed_list, loss = seq_to_seq_net(
         word_vector_dim=args.word_vector_dim,
@@ -459,11 +487,29 @@ def train():
         zip(gradients, trainable_params), global_step=global_step)
 
     src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
+
     train_batch_generator = paddle.batch(
         paddle.reader.shuffle(
             paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
         batch_size=args.batch_size)
 
+    test_batch_generator = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
+        batch_size=args.batch_size)
+
+    def do_validation():
+        total_loss = 0.0
+        count = 0
+        for batch_id, data in enumerate(test_batch_generator()):
+            adapted_batch_data = adapt_batch_data(data)
+            outputs = sess.run(
+                [loss],
+                feed_dict=dict(zip(train_feed_list, adapted_batch_data)))
+            total_loss += outputs[0]
+            count += 1
+        return total_loss / count
+
     config = tf.ConfigProto(
         intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
     with tf.Session(config=config) as sess:
@@ -472,47 +518,24 @@ def train():
         sess.run(init_l)
         sess.run(init_g)
         for pass_id in xrange(args.pass_number):
+            pass_start_time = time.time()
+            words_seen = 0
             for batch_id, data in enumerate(train_batch_generator()):
-                src_seq = map(lambda x: x[0], data)
-                trg_seq = map(lambda x: x[1], data)
-                lbl_seq = map(lambda x: x[2], data)
-
-                src_sequence_length = np.array(
-                    [len(seq) for seq in src_seq]).astype('int32')
-                src_seq_maxlen = np.max(src_sequence_length)
-                trg_sequence_length = np.array(
-                    [len(seq) for seq in trg_seq]).astype('int32')
-                trg_seq_maxlen = np.max(trg_sequence_length)
-
-                src_seq = np.array([
-                    padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
-                    for seq in src_seq
-                ]).astype('int32')
-                trg_seq = np.array([
-                    padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
-                    for seq in trg_seq
-                ]).astype('int32')
-                lbl_seq = np.array([
-                    padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
-                    for seq in lbl_seq
-                ]).astype('int32')
-
-                outputs = sess.run([updates, loss],
-                                   feed_dict={
-                                       train_feed_list[0]: src_seq,
-                                       train_feed_list[1]: src_sequence_length,
-                                       train_feed_list[2]: trg_seq,
-                                       train_feed_list[3]: trg_sequence_length,
-                                       train_feed_list[4]: lbl_seq
-                                   })
-
-                print("pass_id=%d, batch_id=%d, loss=%f" %
+                adapted_batch_data = adapt_batch_data(data)
+                words_seen += np.sum(adapted_batch_data[1])
+                words_seen += np.sum(adapted_batch_data[3])
+                outputs = sess.run(
+                    [updates, loss],
+                    feed_dict=dict(
+                        zip(train_feed_list, adapted_batch_data)))
+                print("pass_id=%d, batch_id=%d, train_loss: %f" %
                       (pass_id, batch_id, outputs[1]))
-
-                if global_step.eval() % args.save_freq == 0:
-                    print('Saving model..')
-                    checkpoint_path = os.path.join(args.model_dir, 'tf_seq2seq')
-                    save(sess, checkpoint_path, global_step=global_step)
+            pass_end_time = time.time()
+            test_loss = do_validation()
+            time_consumed = pass_end_time - pass_start_time
+            words_per_sec = words_seen / time_consumed
+            print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
+                  (pass_id, test_loss, words_per_sec, time_consumed))
 
 
 def infer():
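Note on the adapt_batch_data helper added above: it factors the per-batch padding out of the training loop. Every source, target, and label sequence in a mini-batch is right-padded with END_TOKEN_IDX to that batch's maximum length, and the true lengths are returned alongside the padded arrays so the sequence-length-aware RNN ops can mask the padding. A minimal standalone sketch of the same padding idea (pad_batch and eos are illustrative names, not from the patch):

    import numpy as np

    def pad_batch(seqs, eos):
        # True lengths before padding; the model masks time steps past these.
        lengths = np.array([len(s) for s in seqs], dtype='int32')
        maxlen = lengths.max()
        # Right-pad every sequence with the end token up to the batch maximum.
        padded = np.array(
            [list(s) + [eos] * (maxlen - len(s)) for s in seqs],
            dtype='int32')
        return padded, lengths

    padded, lengths = pad_batch([[4, 5], [6, 7, 8, 9]], eos=1)
    # padded  -> [[4, 5, 1, 1],
    #             [6, 7, 8, 9]]
    # lengths -> [2, 4]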
From 44701e8c59dfa9d528bfd7542adc836b514ff898 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Tue, 16 Jan 2018 17:48:08 +0800
Subject: [PATCH 2/3] Refine the model.

---
 tensorflow/machine_translation.py | 51 +++++++++++++------------------
 1 file changed, 21 insertions(+), 30 deletions(-)

diff --git a/tensorflow/machine_translation.py b/tensorflow/machine_translation.py
index 3bc0b6c..882a995 100644
--- a/tensorflow/machine_translation.py
+++ b/tensorflow/machine_translation.py
@@ -25,7 +25,7 @@
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    "--word_vector_dim",
+    "--embedding_dim",
    type=int,
     default=512,
     help="The dimension of embedding table. (default: %(default)d)")
@@ -43,7 +43,7 @@
     "--batch_size",
     type=int,
     default=128,
-    help="The sequence number of a batch data. (default: %(default)d)")
+    help="The number of sequences in a mini-batch. (default: %(default)d)")
 parser.add_argument(
     "--dict_size",
     type=int,
@@ -56,7 +56,7 @@
     default=81,
     help="Max number of time steps for sequence. (default: %(default)d)")
 parser.add_argument(
-    "--pass_number",
+    "--pass_num",
     type=int,
     default=10,
     help="The pass number to train. (default: %(default)d)")
@@ -66,11 +66,7 @@
     default=0.0002,
     help="Learning rate used to train the model. (default: %(default)f)")
 parser.add_argument(
-    "--mode",
-    type=str,
-    default='train',
-    choices=['train', 'infer'],
-    help="Do training or inference. (default: %(default)s)")
+    "--infer_only", action='store_true', help="If set, run the forward pass only.")
 parser.add_argument(
     "--beam_size",
     type=int,
@@ -80,7 +76,7 @@
     "--max_generation_length",
     type=int,
     default=250,
-    help="The max length of sequence when doing generation. "
+    help="The maximum length of sequence when doing generation. "
     "(default: %(default)d)")
 parser.add_argument(
     "--save_freq",
" "(default: %(default)d)") parser.add_argument( "--save_freq", @@ -201,7 +197,7 @@ def _simple_attention(self, encoder_vec, encoder_proj, decoder_state): inputs=tf.reshape( concated, shape=[-1, self._num_units * 2]), num_outputs=1, - activation_fn=None, + activation_fn=tf.nn.tanh, biases_initializer=None) attention_weights_reshaped = tf.reshape( attention_weights, shape=[tf.shape(encoder_vec)[0], -1, 1]) @@ -267,19 +263,14 @@ def _maybe_mask(m, seq_len_mask): memory) -def seq_to_seq_net(word_vector_dim, - encoder_size, - decoder_size, - source_dict_dim, - target_dict_dim, - is_generating=False, - beam_size=3, - max_generation_length=250): +def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim, + target_dict_dim, is_generating, beam_size, + max_generation_length): src_word_idx = tf.placeholder(tf.int32, shape=[None, None]) src_sequence_length = tf.placeholder(tf.int32, shape=[None, ]) src_embedding_weights = tf.get_variable("source_word_embeddings", - [source_dict_dim, word_vector_dim]) + [source_dict_dim, embedding_dim]) src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx) src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size) @@ -298,7 +289,7 @@ def seq_to_seq_net(word_vector_dim, # project the encoder outputs to size of decoder lstm encoded_proj = tf.contrib.layers.fully_connected( inputs=tf.reshape( - encoded_vec, shape=[-1, word_vector_dim * 2]), + encoded_vec, shape=[-1, embedding_dim * 2]), num_outputs=decoder_size, activation_fn=None, biases_initializer=None) @@ -309,9 +300,9 @@ def seq_to_seq_net(word_vector_dim, backword_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1]) decoder_boot = tf.contrib.layers.fully_connected( inputs=tf.reshape( - backword_first, shape=[-1, word_vector_dim]), + backword_first, shape=[-1, embedding_dim]), num_outputs=decoder_size, - activation_fn=None, + activation_fn=tf.nn.tanh, biases_initializer=None) # prepare the initial state for decoder lstm @@ -335,7 +326,7 @@ def seq_to_seq_net(word_vector_dim, trg_word_idx = tf.placeholder(tf.int32, shape=[None, None]) trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ]) trg_embedding_weights = tf.get_variable( - "target_word_embeddings", [target_dict_dim, word_vector_dim]) + "target_word_embeddings", [target_dict_dim, embedding_dim]) trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights, trg_word_idx) @@ -388,7 +379,7 @@ def seq_to_seq_net(word_vector_dim, tf.int32) * START_TOKEN_IDX # share the same embedding weights with target word trg_embedding_weights = tf.get_variable( - "target_word_embeddings", [target_dict_dim, word_vector_dim]) + "target_word_embeddings", [target_dict_dim, embedding_dim]) inference_decoder = beam_search_decoder.BeamSearchDecoder( cell=decoder_cell, @@ -466,7 +457,7 @@ def adapt_batch_data(data): def train(): train_feed_list, loss = seq_to_seq_net( - word_vector_dim=args.word_vector_dim, + embedding_dim=args.embedding_dim, encoder_size=args.encoder_size, decoder_size=args.decoder_size, source_dict_dim=args.dict_size, @@ -517,7 +508,7 @@ def do_validataion(): init_l = tf.local_variables_initializer() sess.run(init_l) sess.run(init_g) - for pass_id in xrange(args.pass_number): + for pass_id in xrange(args.pass_num): pass_start_time = time.time() words_seen = 0 for batch_id, data in enumerate(train_batch_generator()): @@ -540,7 +531,7 @@ def do_validataion(): def infer(): infer_feed_list, predicted_ids = seq_to_seq_net( - word_vector_dim=args.word_vector_dim, + embedding_dim=args.embedding_dim, 
From 4ee86b7259790bd3244a33d1750e21521cff3726 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Wed, 17 Jan 2018 15:16:24 +0800
Subject: [PATCH 3/3] Refine the data feeding part.

---
 tensorflow/machine_translation.py | 55 +++++++++++++++++++++----------
 1 file changed, 37 insertions(+), 18 deletions(-)

diff --git a/tensorflow/machine_translation.py b/tensorflow/machine_translation.py
index 882a995..6753a47 100644
--- a/tensorflow/machine_translation.py
+++ b/tensorflow/machine_translation.py
@@ -372,8 +372,13 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
             average_across_batch=True)
 
         # return feeding list and loss operator
-        return (src_word_idx, src_sequence_length, trg_word_idx,
-                trg_sequence_length, lbl_word_idx), loss
+        return {
+            'src_word_idx': src_word_idx,
+            'src_sequence_length': src_sequence_length,
+            'trg_word_idx': trg_word_idx,
+            'trg_sequence_length': trg_sequence_length,
+            'lbl_word_idx': lbl_word_idx
+        }, loss
     else:
         start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
                                tf.int32) * START_TOKEN_IDX
@@ -398,8 +403,12 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
             #impute_finished=True,# error occurs
             maximum_iterations=max_generation_length)
 
-        return (src_word_idx,
-                src_sequence_length), decoder_outputs_decode.predicted_ids
+        predicted_ids = decoder_outputs_decode.predicted_ids
+
+        return {
+            'src_word_idx': src_word_idx,
+            'src_sequence_length': src_sequence_length
+        }, predicted_ids
 
 
 def print_arguments(args):
@@ -452,11 +461,17 @@ def adapt_batch_data(data):
         [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
          for seq in lbl_seq]).astype('int32')
 
-    return src_seq, src_sequence_length, trg_seq, trg_sequence_length, lbl_seq
+    return {
+        'src_word_idx': src_seq,
+        'src_sequence_length': src_sequence_length,
+        'trg_word_idx': trg_seq,
+        'trg_sequence_length': trg_sequence_length,
+        'lbl_word_idx': lbl_seq
+    }
 
 
 def train():
-    train_feed_list, loss = seq_to_seq_net(
+    feeding_dict, loss = seq_to_seq_net(
         embedding_dim=args.embedding_dim,
         encoder_size=args.encoder_size,
         decoder_size=args.decoder_size,
         source_dict_dim=args.dict_size,
@@ -494,9 +509,11 @@ def do_validation():
         total_loss = 0.0
         count = 0
         for batch_id, data in enumerate(test_batch_generator()):
             adapted_batch_data = adapt_batch_data(data)
-            outputs = sess.run(
-                [loss],
-                feed_dict=dict(zip(train_feed_list, adapted_batch_data)))
+            outputs = sess.run([loss],
+                               feed_dict={
+                                   item[1]: adapted_batch_data[item[0]]
+                                   for item in feeding_dict.items()
+                               })
             total_loss += outputs[0]
             count += 1
         return total_loss / count
@@ -513,12 +530,13 @@ def do_validation():
             words_seen = 0
             for batch_id, data in enumerate(train_batch_generator()):
                 adapted_batch_data = adapt_batch_data(data)
-                words_seen += np.sum(adapted_batch_data[1])
-                words_seen += np.sum(adapted_batch_data[3])
-                outputs = sess.run(
-                    [updates, loss],
-                    feed_dict=dict(
-                        zip(train_feed_list, adapted_batch_data)))
+                words_seen += np.sum(adapted_batch_data['src_sequence_length'])
+                words_seen += np.sum(adapted_batch_data['trg_sequence_length'])
+                outputs = sess.run([updates, loss],
+                                   feed_dict={
+                                       item[1]: adapted_batch_data[item[0]]
+                                       for item in feeding_dict.items()
+                                   })
                 print("pass_id=%d, batch_id=%d, train_loss: %f" %
                       (pass_id, batch_id, outputs[1]))
             pass_end_time = time.time()
@@ -530,7 +548,7 @@ def do_validation():
 
 
 def infer():
-    infer_feed_list, predicted_ids = seq_to_seq_net(
+    feeding_dict, predicted_ids = seq_to_seq_net(
         embedding_dim=args.embedding_dim,
         encoder_size=args.encoder_size,
         decoder_size=args.decoder_size,
         source_dict_dim=args.dict_size,
@@ -567,8 +585,9 @@ def infer():
 
         outputs = sess.run([predicted_ids],
                            feed_dict={
-                               infer_feed_list[0].name: src_seq,
-                               infer_feed_list[1].name: src_sequence_length
+                               feeding_dict['src_word_idx']: src_seq,
+                               feeding_dict['src_sequence_length']:
+                               src_sequence_length
                            })
 
         print("\nDecoder result comparison: ")
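Note on the feeding refactor in the patch above: seq_to_seq_net and adapt_batch_data now both return dicts keyed by the same field names, so the feed_dict is built by joining the two dicts on their keys instead of zipping two tuples positionally. A minimal sketch of the pairing logic, with plain strings standing in for placeholders (toy values, illustrative only):

    # Placeholders keyed by field name, as returned by seq_to_seq_net.
    feeding_dict = {'src_word_idx': 'ph_words',
                    'src_sequence_length': 'ph_lens'}
    # Batch arrays keyed by the same names, as returned by adapt_batch_data.
    adapted_batch_data = {'src_word_idx': [[4, 5]],
                          'src_sequence_length': [2]}

    # Join on keys: each placeholder gets the array that shares its name.
    feed_dict = {
        placeholder: adapted_batch_data[name]
        for name, placeholder in feeding_dict.items()
    }
    assert feed_dict == {'ph_words': [[4, 5]], 'ph_lens': [2]}

Keying both sides by name makes the pairing order-independent and lets train() and infer() share the same construction, where the earlier tuple version depended on the two sequences staying in the same order.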