Merge pull request #58 from pkuyym/fix-57
Add do_validation.
pkuyym authored Jan 17, 2018
2 parents 315b20f + 4ee86b7 commit 6465ea0
Showing 1 changed file with 107 additions and 74 deletions.
181 changes: 107 additions & 74 deletions tensorflow/machine_translation.py
@@ -25,7 +25,7 @@

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--word_vector_dim",
"--embedding_dim",
type=int,
default=512,
help="The dimension of embedding table. (default: %(default)d)")
@@ -43,7 +43,7 @@
"--batch_size",
type=int,
default=128,
help="The sequence number of a batch data. (default: %(default)d)")
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--dict_size",
type=int,
@@ -56,7 +56,7 @@
default=81,
help="Max number of time steps for sequence. (default: %(default)d)")
parser.add_argument(
"--pass_number",
"--pass_num",
type=int,
default=10,
help="The pass number to train. (default: %(default)d)")
@@ -66,11 +66,7 @@
default=0.0002,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--mode",
type=str,
default='train',
choices=['train', 'infer'],
help="Do training or inference. (default: %(default)s)")
"--infer_only", action='store_true', help="If set, run forward only.")
parser.add_argument(
"--beam_size",
type=int,
@@ -80,7 +76,7 @@
"--max_generation_length",
type=int,
default=250,
help="The max length of sequence when doing generation. "
help="The maximum length of sequence when doing generation. "
"(default: %(default)d)")
parser.add_argument(
"--save_freq",
@@ -201,7 +197,7 @@ def _simple_attention(self, encoder_vec, encoder_proj, decoder_state):
inputs=tf.reshape(
concated, shape=[-1, self._num_units * 2]),
num_outputs=1,
activation_fn=None,
activation_fn=tf.nn.tanh,
biases_initializer=None)
attention_weights_reshaped = tf.reshape(
attention_weights, shape=[tf.shape(encoder_vec)[0], -1, 1])
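
Note the activation change above: with activation_fn=tf.nn.tanh, the one-unit scoring layer computes tanh(w·[encoder_proj; decoder_state]) and so bounds each pre-softmax attention score to (-1, 1). Because the tanh is applied after the projection to a scalar, this is not quite the textbook additive form v^T tanh(Wx). A minimal standalone sketch of the scoring op as changed (shapes assumed; num_units hypothetical):

import tensorflow as tf

num_units = 512  # assumed to match self._num_units
concated = tf.placeholder(tf.float32, [None, None, num_units * 2])
attention_weights = tf.contrib.layers.fully_connected(
    inputs=tf.reshape(concated, shape=[-1, num_units * 2]),
    num_outputs=1,
    activation_fn=tf.nn.tanh,  # squash the scalar score into (-1, 1)
    biases_initializer=None)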
@@ -267,19 +263,14 @@ def _maybe_mask(m, seq_len_mask):
memory)


def seq_to_seq_net(word_vector_dim,
encoder_size,
decoder_size,
source_dict_dim,
target_dict_dim,
is_generating=False,
beam_size=3,
max_generation_length=250):
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
target_dict_dim, is_generating, beam_size,
max_generation_length):
src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

src_embedding_weights = tf.get_variable("source_word_embeddings",
[source_dict_dim, word_vector_dim])
[source_dict_dim, embedding_dim])
src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
@@ -298,7 +289,7 @@ def seq_to_seq_net(word_vector_dim,
# project the encoder outputs to size of decoder lstm
encoded_proj = tf.contrib.layers.fully_connected(
inputs=tf.reshape(
encoded_vec, shape=[-1, word_vector_dim * 2]),
encoded_vec, shape=[-1, embedding_dim * 2]),
num_outputs=decoder_size,
activation_fn=None,
biases_initializer=None)
@@ -309,9 +300,9 @@ def seq_to_seq_net(word_vector_dim,
backword_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
decoder_boot = tf.contrib.layers.fully_connected(
inputs=tf.reshape(
backword_first, shape=[-1, word_vector_dim]),
backword_first, shape=[-1, embedding_dim]),
num_outputs=decoder_size,
activation_fn=None,
activation_fn=tf.nn.tanh,
biases_initializer=None)

# prepare the initial state for decoder lstm
@@ -335,7 +326,7 @@ def seq_to_seq_net(word_vector_dim,
trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
trg_embedding_weights = tf.get_variable(
"target_word_embeddings", [target_dict_dim, word_vector_dim])
"target_word_embeddings", [target_dict_dim, embedding_dim])
trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
trg_word_idx)

@@ -381,14 +372,19 @@ def seq_to_seq_net(word_vector_dim,
average_across_batch=True)

# return feeding list and loss operator
return (src_word_idx, src_sequence_length, trg_word_idx,
trg_sequence_length, lbl_word_idx), loss
return {
'src_word_idx': src_word_idx,
'src_sequence_length': src_sequence_length,
'trg_word_idx': trg_word_idx,
'trg_sequence_length': trg_sequence_length,
'lbl_word_idx': lbl_word_idx
}, loss
else:
start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
tf.int32) * START_TOKEN_IDX
# share the same embedding weights with target word
trg_embedding_weights = tf.get_variable(
"target_word_embeddings", [target_dict_dim, word_vector_dim])
"target_word_embeddings", [target_dict_dim, embedding_dim])

inference_decoder = beam_search_decoder.BeamSearchDecoder(
cell=decoder_cell,
Expand All @@ -407,8 +403,12 @@ def seq_to_seq_net(word_vector_dim,
# impute_finished=True,  # disabled: raises an error with BeamSearchDecoder
maximum_iterations=max_generation_length)

return (src_word_idx,
src_sequence_length), decoder_outputs_decode.predicted_ids
predicted_ids = decoder_outputs_decode.predicted_ids

return {
'src_word_idx': src_word_idx,
'src_sequence_length': src_sequence_length
}, predicted_ids
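
Returning named placeholders instead of a positional tuple lets callers pair graph inputs with batch arrays by key rather than by position, as train() below does with the adapt_batch_data helper. A minimal sketch of the pattern (sess and data assumed in scope):

feeding_dict, loss = seq_to_seq_net(...)  # training branch
batch = adapt_batch_data(data)            # keys mirror feeding_dict's
feed = {placeholder: batch[name]
        for name, placeholder in feeding_dict.items()}
loss_value = sess.run(loss, feed_dict=feed)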


def print_arguments(args):
@@ -436,9 +436,43 @@ def restore(sess, path, var_list=None):
print('model restored from %s' % path)


def adapt_batch_data(data):
src_seq = map(lambda x: x[0], data)
trg_seq = map(lambda x: x[1], data)
lbl_seq = map(lambda x: x[2], data)

src_sequence_length = np.array(
[len(seq) for seq in src_seq]).astype('int32')
src_seq_maxlen = np.max(src_sequence_length)

trg_sequence_length = np.array(
[len(seq) for seq in trg_seq]).astype('int32')
trg_seq_maxlen = np.max(trg_sequence_length)

src_seq = np.array(
[padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
for seq in src_seq]).astype('int32')

trg_seq = np.array(
[padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in trg_seq]).astype('int32')

lbl_seq = np.array(
[padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in lbl_seq]).astype('int32')

return {
'src_word_idx': src_seq,
'src_sequence_length': src_sequence_length,
'trg_word_idx': trg_seq,
'trg_sequence_length': trg_sequence_length,
'lbl_word_idx': lbl_seq
}
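
A quick usage sketch for the new helper: given a mini-batch of (source, target, label) index tuples, as yielded by the paddle.dataset.wmt14 readers, it pads every sequence to the batch maximum with END_TOKEN_IDX and records the original lengths (the toy indices below are illustrative):

batch = [([1, 2, 3], [4, 5], [5, 6]),
         ([7, 8], [9, 10, 11], [10, 11, 12])]
feed = adapt_batch_data(batch)
print(feed['src_word_idx'].shape)    # (2, 3): padded to the longest source
print(feed['src_sequence_length'])   # [3 2]: true lengths before padding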


def train():
train_feed_list, loss = seq_to_seq_net(
word_vector_dim=args.word_vector_dim,
feeding_dict, loss = seq_to_seq_net(
embedding_dim=args.embedding_dim,
encoder_size=args.encoder_size,
decoder_size=args.decoder_size,
source_dict_dim=args.dict_size,
@@ -459,65 +493,63 @@ def train():
zip(gradients, trainable_params), global_step=global_step)

src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)

train_batch_generator = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
batch_size=args.batch_size)

test_batch_generator = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
batch_size=args.batch_size)

def do_validation():
total_loss = 0.0
count = 0
for batch_id, data in enumerate(test_batch_generator()):
adapted_batch_data = adapt_batch_data(data)
outputs = sess.run([loss],
feed_dict={
item[1]: adapted_batch_data[item[0]]
for item in feeding_dict.items()
})
total_loss += outputs[0]
count += 1
return total_loss / count

config = tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess.run(init_l)
sess.run(init_g)
for pass_id in xrange(args.pass_number):
for pass_id in xrange(args.pass_num):
pass_start_time = time.time()
words_seen = 0
for batch_id, data in enumerate(train_batch_generator()):
src_seq = map(lambda x: x[0], data)
trg_seq = map(lambda x: x[1], data)
lbl_seq = map(lambda x: x[2], data)

src_sequence_length = np.array(
[len(seq) for seq in src_seq]).astype('int32')
src_seq_maxlen = np.max(src_sequence_length)
trg_sequence_length = np.array(
[len(seq) for seq in trg_seq]).astype('int32')
trg_seq_maxlen = np.max(trg_sequence_length)

src_seq = np.array([
padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
for seq in src_seq
]).astype('int32')
trg_seq = np.array([
padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in trg_seq
]).astype('int32')
lbl_seq = np.array([
padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
for seq in lbl_seq
]).astype('int32')

adapted_batch_data = adapt_batch_data(data)
words_seen += np.sum(adapted_batch_data['src_sequence_length'])
words_seen += np.sum(adapted_batch_data['trg_sequence_length'])
outputs = sess.run([updates, loss],
feed_dict={
train_feed_list[0]: src_seq,
train_feed_list[1]: src_sequence_length,
train_feed_list[2]: trg_seq,
train_feed_list[3]: trg_sequence_length,
train_feed_list[4]: lbl_seq
item[1]: adapted_batch_data[item[0]]
for item in feeding_dict.items()
})

print("pass_id=%d, batch_id=%d, loss=%f" %
print("pass_id=%d, batch_id=%d, train_loss: %f" %
(pass_id, batch_id, outputs[1]))

if global_step.eval() % args.save_freq == 0:
print('Saving model..')
checkpoint_path = os.path.join(args.model_dir, 'tf_seq2seq')
save(sess, checkpoint_path, global_step=global_step)
pass_end_time = time.time()
test_loss = do_validation()
time_consumed = pass_end_time - pass_start_time
words_per_sec = words_seen / time_consumed
print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
(pass_id, test_loss, words_per_sec, time_consumed))


def infer():
infer_feed_list, predicted_ids = seq_to_seq_net(
word_vector_dim=args.word_vector_dim,
feeding_dict, predicted_ids = seq_to_seq_net(
embedding_dim=args.embedding_dim,
encoder_size=args.encoder_size,
decoder_size=args.decoder_size,
source_dict_dim=args.dict_size,
@@ -553,8 +585,9 @@ def infer():

outputs = sess.run([predicted_ids],
feed_dict={
infer_feed_list[0].name: src_seq,
infer_feed_list[1].name: src_sequence_length
feeding_dict['src_word_idx']: src_seq,
feeding_dict['src_sequence_length']:
src_sequence_length
})

print("\nDecoder result comparison: ")
@@ -571,7 +604,7 @@
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
if args.mode == 'train':
train()
else:
if args.infer_only:
infer()
else:
train()
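
For reference, the predicted_ids fetched in infer() come back from the beam-search branch with shape [batch, time, beam_width]. A hedged sketch of rendering the top beam as words (helper name hypothetical; assumes trg_dict maps ids to words, and filters padding/end ids):

def decode_top_beam(predicted_ids_batch, trg_dict):
    for sample in predicted_ids_batch:      # [time, beam_width]
        top = sample[:, 0]                  # best-scoring beam
        words = [trg_dict[idx] for idx in top
                 if idx >= 0 and idx != END_TOKEN_IDX]
        print(' '.join(words))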
