From 555b2dfdadc9e9be758f228afe72c047a59ea112 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 14 Apr 2017 11:28:10 +0800 Subject: [PATCH] add seqtext_print for seqToseq demo --- demo/seqToseq/api_train_v2.py | 39 +++++++++++++++++++++++++++---- python/paddle/v2/dataset/wmt14.py | 16 ++++++++++--- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index ac2665b5b35bd..3072c375123a2 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -126,7 +126,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word): def main(): paddle.init(use_gpu=False, trainer_count=1) - is_generating = True + is_generating = False # source and target dict dim. dict_size = 30000 @@ -167,16 +167,47 @@ def event_handler(event): # generate a english sequence to french else: - gen_creator = paddle.dataset.wmt14.test(dict_size) + # use the first 3 samples for generation + gen_creator = paddle.dataset.wmt14.gen(dict_size) gen_data = [] + gen_num = 3 for item in gen_creator(): gen_data.append((item[0], )) - if len(gen_data) == 3: + if len(gen_data) == gen_num: break beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating) + # get the pretrained model, whose bleu = 26.92 parameters = paddle.dataset.wmt14.model() - trg_dict = paddle.dataset.wmt14.trg_dict(dict_size) + # prob is the prediction probabilities, and id is the prediction word. + beam_result = paddle.infer( + output_layer=beam_gen, + parameters=parameters, + input=gen_data, + field=['prob', 'id']) + + # get the dictionary + src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size) + + # the delimited element of generated sequences is -1, + # the first element of each generated sequence is the sequence length + seq_list = [] + seq = [] + for w in beam_result[1]: + if w != -1: + seq.append(w) + else: + seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]])) + seq = [] + + prob = beam_result[0] + beam_size = 3 + for i in xrange(gen_num): + print "\n*******************************************************\n" + print "src:", ' '.join( + [src_dict.get(w) for w in gen_data[i][0]]), "\n" + for j in xrange(beam_size): + print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j] if __name__ == '__main__': diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index ad9f5f18d674d..23ca8036281b1 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -26,7 +26,7 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' # this is a small set of data for test. The original data is too large and will be add later. URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' -MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' +MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c' # this is the pretrained model, whose bleu = 26.92 URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4' @@ -108,6 +108,11 @@ def test(dict_size): download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size) +def gen(dict_size): + return reader_creator( + download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size) + + def model(): tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL) with gzip.open(tar_file, 'r') as f: @@ -115,10 +120,15 @@ def model(): return parameters -def trg_dict(dict_size): +def get_dict(dict_size, reverse=True): + # if reverse = False, return dict = {'a':'001', 'b':'002', ...} + # else reverse = true, return dict = {'001':'a', '002':'b', ...} tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN) src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) - return trg_dict + if reverse: + src_dict = {v: k for k, v in src_dict.items()} + trg_dict = {v: k for k, v in trg_dict.items()} + return src_dict, trg_dict def fetch():