Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WMT16 into dataset. #7661

Merged
merged 2 commits into from
Jan 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions python/paddle/v2/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,23 @@
import uci_housing
import sentiment
import wmt14
import wmt16
import mq2007
import flowers
import voc2012

__all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
'mnist',
'imikolov',
'imdb',
'cifar',
'movielens',
'conll05',
'sentiment'
'uci_housing',
'wmt14',
'wmt16',
'mq2007',
'flowers',
'voc2012',
]
21 changes: 15 additions & 6 deletions python/paddle/v2/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
import cPickle as pickle

__all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
'convert'
'DATA_HOME',
'download',
'md5file',
'split',
'cluster_files_reader',
'convert',
]

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
Expand Down Expand Up @@ -58,12 +62,15 @@ def md5file(fname):
return hash_md5.hexdigest()


def download(url, module_name, md5sum):
def download(url, module_name, md5sum, save_name=None):
dirname = os.path.join(DATA_HOME, module_name)
if not os.path.exists(dirname):
os.makedirs(dirname)

filename = os.path.join(dirname, url.split('/')[-1])
filename = os.path.join(dirname,
url.split('/')[-1]
if save_name is None else save_name)

retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
Expand Down Expand Up @@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix):
Convert data from reader to recordio format files.

:param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances.
:param reader: a data reader, from which the convert program will read
data instances.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
:param max_lines_to_shuffle: the max lines numbers to shuffle before
writing.
"""

assert line_count >= 1
Expand Down
66 changes: 66 additions & 0 deletions python/paddle/v2/dataset/tests/wmt16_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.wmt16
import unittest


class TestWMT16(unittest.TestCase):
def checkout_one_sample(self, sample):
# train data has 3 field: source language word indices,
# target language word indices, and target next word indices.
self.assertEqual(len(sample), 3)

# test start mark and end mark in source word indices.
self.assertEqual(sample[0][0], 0)
self.assertEqual(sample[0][-1], 1)

# test start mask in target word indices
self.assertEqual(sample[1][0], 0)

# test en mask in target next word indices
self.assertEqual(sample[2][-1], 1)

def test_train(self):
for idx, sample in enumerate(
paddle.v2.dataset.wmt16.train(
src_dict_size=100000, trg_dict_size=100000)()):
if idx >= 10: break
self.checkout_one_sample(sample)

def test_test(self):
for idx, sample in enumerate(
paddle.v2.dataset.wmt16.test(
src_dict_size=1000, trg_dict_size=1000)()):
if idx >= 10: break
self.checkout_one_sample(sample)

def test_val(self):
for idx, sample in enumerate(
paddle.v2.dataset.wmt16.validation(
src_dict_size=1000, trg_dict_size=1000)()):
if idx >= 10: break
self.checkout_one_sample(sample)

def test_get_dict(self):
dict_size = 1000
word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
self.assertEqual(len(word_dict), dict_size)
self.assertEqual(word_dict[0], "<s>")
self.assertEqual(word_dict[1], "<e>")
self.assertEqual(word_dict[2], "<unk>")


if __name__ == "__main__":
unittest.main()
30 changes: 19 additions & 11 deletions python/paddle/v2/dataset/wmt14.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,20 @@
import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters

__all__ = ['train', 'test', 'build_dict', 'convert']

URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
__all__ = [
'train',
'test',
'get_dict',
'convert',
]

URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
'cslm_joint_paper/data/dev+test.tgz')
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
# this is a small set of data for test. The original data is too large and
# will be add later.
URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
Expand All @@ -42,8 +50,8 @@
UNK_IDX = 2


def __read_to_dict__(tar_file, dict_size):
def __to_dict__(fd, size):
def __read_to_dict(tar_file, dict_size):
def __to_dict(fd, size):
out_dict = dict()
for line_count, line in enumerate(fd):
if line_count < size:
Expand All @@ -58,19 +66,19 @@ def __to_dict__(fd, size):
if each_item.name.endswith("src.dict")
]
assert len(names) == 1
src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
src_dict = __to_dict(f.extractfile(names[0]), dict_size)
names = [
each_item.name for each_item in f
if each_item.name.endswith("trg.dict")
]
assert len(names) == 1
trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
return src_dict, trg_dict


def reader_creator(tar_file, file_name, dict_size):
def reader():
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
with tarfile.open(tar_file, mode='r') as f:
names = [
each_item.name for each_item in f
Expand Down Expand Up @@ -152,7 +160,7 @@ def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()}
Expand Down
Loading