add seqtext_print for seqToseq demo #1784

Merged
merged 1 commit on Apr 14, 2017
demo/seqToseq/api_train_v2.py (35 additions, 4 deletions)
@@ -126,7 +126,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):

def main():
paddle.init(use_gpu=False, trainer_count=1)
- is_generating = True
+ is_generating = False

# source and target dict dim.
dict_size = 30000
@@ -167,16 +167,47 @@ def event_handler(event):

# generate an English sequence to French
else:
- gen_creator = paddle.dataset.wmt14.test(dict_size)
+ # use the first 3 samples for generation
+ gen_creator = paddle.dataset.wmt14.gen(dict_size)
gen_data = []
gen_num = 3
for item in gen_creator():
gen_data.append((item[0], ))
- if len(gen_data) == 3:
+ if len(gen_data) == gen_num:
break

beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
# get the pretrained model, whose bleu = 26.92
parameters = paddle.dataset.wmt14.model()
- trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
# prob is the prediction probabilities, and id is the predicted word id.
beam_result = paddle.infer(
output_layer=beam_gen,
parameters=parameters,
input=gen_data,
field=['prob', 'id'])

# get the dictionary
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)

# the delimiter element of generated sequences is -1,
# and the first element of each generated sequence is the sequence length
seq_list = []
seq = []
for w in beam_result[1]:
if w != -1:
seq.append(w)
else:
seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]]))
seq = []

prob = beam_result[0]
beam_size = 3
for i in xrange(gen_num):
print "\n*******************************************************\n"
print "src:", ' '.join(
[src_dict.get(w) for w in gen_data[i][0]]), "\n"
for j in xrange(beam_size):
print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]


if __name__ == '__main__':
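The parsing loop added above can be hard to follow out of context, so here is a standalone sketch (no PaddlePaddle required) of the same id-parsing logic. It assumes, as the comments in the diff state, that `beam_result[1]` is one flat list of word ids in which each generated sequence starts with its length and ends with a -1 delimiter, and that `beam_result[0]` holds one row of beam probabilities per source sentence. The toy dictionary and numbers below are made up purely for illustration.

```python
# Toy id -> word dictionary and fake beam-search output for two source
# sentences with beam_size = 2 (so 4 generated sequences in total).
toy_trg_dict = {0: '<s>', 1: '<e>', 2: 'les', 3: 'chats', 4: 'dorment'}

beam_size = 2

beam_result = (
    [[-0.2, -1.5], [-0.4, -2.1]],            # made-up log-probabilities
    [3, 2, 3, 4, -1,   3, 3, 4, 2, -1,       # two beams for sentence 0
     2, 2, 3, -1,      2, 4, 3, -1],         # two beams for sentence 1
)

# Split the flat id stream on the -1 delimiter and drop the leading length.
seq_list = []
seq = []
for w in beam_result[1]:
    if w != -1:
        seq.append(w)
    else:
        seq_list.append(' '.join(toy_trg_dict.get(w, '<unk>') for w in seq[1:]))
        seq = []

# Print each source sentence's beams next to their probabilities, mirroring
# the demo: beam j of sentence i sits at index i * beam_size + j.
prob = beam_result[0]
for i in range(len(prob)):
    print("source sentence %d" % i)
    for j in range(beam_size):
        print("  prob = %f: %s" % (prob[i][j], seq_list[i * beam_size + j]))
```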
python/paddle/v2/dataset/wmt14.py (13 additions, 3 deletions)
@@ -26,7 +26,7 @@
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for testing. The original data is too large and will be added later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
- MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
+ MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
@@ -108,17 +108,27 @@ def test(dict_size):
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)


def gen(dict_size):
return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size)


def model():
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
with gzip.open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f)
return parameters


- def trg_dict(dict_size):
+ def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# if reverse = True, return dict = {'001':'a', '002':'b', ...}
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
- return trg_dict
+ if reverse:
+     src_dict = {v: k for k, v in src_dict.items()}
+     trg_dict = {v: k for k, v in trg_dict.items()}
+ return src_dict, trg_dict


def fetch():
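A small usage sketch of the two wmt14 helpers introduced in this PR, assuming paddle.v2 is installed and the dataset download succeeds. As in the demo above, each sample from the `gen` reader is assumed to be a tuple whose first element is the list of source word ids; the `'<unk>'` fallback for ids missing from the dictionary is an illustrative assumption, not part of the PR.

```python
import paddle.v2 as paddle

dict_size = 30000

# reverse=True (the default) returns id -> word dictionaries for both sides
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)

# iterate over the generation split added by this PR
for item in paddle.dataset.wmt14.gen(dict_size)():
    src_ids = item[0]
    # map source word ids back to words for inspection
    print(' '.join(src_dict.get(w, '<unk>') for w in src_ids))
    break  # look at the first sample only
```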