Skip to content

Commit

Permalink
Fix naming convention.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Jan 22, 2018
1 parent 9a97c7f commit 2f344e7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
12 changes: 6 additions & 6 deletions python/paddle/v2/dataset/wmt14.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
UNK_IDX = 2


def __read_to_dict__(tar_file, dict_size):
def __to_dict__(fd, size):
def __read_to_dict(tar_file, dict_size):
def __to_dict(fd, size):
out_dict = dict()
for line_count, line in enumerate(fd):
if line_count < size:
Expand All @@ -66,19 +66,19 @@ def __to_dict__(fd, size):
if each_item.name.endswith("src.dict")
]
assert len(names) == 1
src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
src_dict = __to_dict(f.extractfile(names[0]), dict_size)
names = [
each_item.name for each_item in f
if each_item.name.endswith("trg.dict")
]
assert len(names) == 1
trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
return src_dict, trg_dict


def reader_creator(tar_file, file_name, dict_size):
def reader():
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
with tarfile.open(tar_file, mode='r') as f:
names = [
each_item.name for each_item in f
Expand Down Expand Up @@ -160,7 +160,7 @@ def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()}
Expand Down
32 changes: 16 additions & 16 deletions python/paddle/v2/dataset/wmt16.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ACL2016 Multimodal Machine Translation. Please see this websit for more details:
http://www.statmt.org/wmt16/multimodal-task.html#task1
ACL2016 Multimodal Machine Translation. Please see this website for more
details: http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions.
Expand Down Expand Up @@ -56,7 +56,7 @@
UNK_MARK = "<unk>"


def __build_dict__(tar_file, dict_size, save_path, lang):
def __build_dict(tar_file, dict_size, save_path, lang):
word_dict = defaultdict(int)
with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile("wmt16/train"):
Expand All @@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang):
fout.write("%s\n" % (word[0]))


def __load_dict__(tar_file, dict_size, lang, reverse=False):
def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or (
len(open(dict_path, "r").readlines()) != dict_size):
__build_dict__(tar_file, dict_size, dict_path, lang)
__build_dict(tar_file, dict_size, dict_path, lang)

word_dict = {}
with open(dict_path, "r") as fdict:
Expand All @@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False):
return word_dict


def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
TOTAL_DE_WORDS))
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
Expand All @@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):

def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
def reader():
src_dict = __load_dict__(tar_file, src_dict_size, src_lang)
trg_dict = __load_dict__(tar_file, trg_dict_size,
("de" if src_lang == "en" else "en"))
src_dict = __load_dict(tar_file, src_dict_size, src_lang)
trg_dict = __load_dict(tar_file, trg_dict_size,
("de" if src_lang == "en" else "en"))

        # the indices for the start mark, end mark, and unk are the same in the source
# language and target language. Here uses the source language
Expand Down Expand Up @@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):

assert (src_lang in ["en", "de"], ("An error language type. Only support: "
"en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
trg_dict_size, src_lang)
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)

return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
Expand Down Expand Up @@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
("An error language type. "
"Only support: en (for English); de(for Germany)"))

src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
trg_dict_size, src_lang)
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)

return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
Expand Down Expand Up @@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
assert (src_lang in ["en", "de"],
("An error language type. "
"Only support: en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
trg_dict_size, src_lang)
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)

return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
Expand Down Expand Up @@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False):
"Please invoke paddle.dataset.wmt16.train/test/validation "
"first to build the dictionary.")
tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
return __load_dict__(tar_file, dict_size, lang, reverse)
return __load_dict(tar_file, dict_size, lang, reverse)


def fetch():
Expand Down

0 comments on commit 2f344e7

Please sign in to comment.