Fixed on_disk_cache issues
ghstack-source-id: b560715160f296b4d5872928992a7a630893914f
Pull Request resolved: #1942
VitalyFedyunin committed Oct 12, 2022
1 parent 4d88d4e commit a38f37e
Showing 4 changed files with 29 additions and 15 deletions.
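
The four fixes below all revolve around torchdata's on_disk_cache/end_caching pair: a filepath_fn maps each incoming item to the file(s) it should produce on disk, the datapipes between the two calls only run on a cache miss, and end_caching writes the results out under those paths. A minimal sketch of the pattern, assuming the torchdata datapipe API these datasets already use (the URL and _filepath_fn here are placeholders, not taken from this commit):

    import os
    from functools import partial

    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    URL = "https://example.com/archive.tar.gz"  # placeholder URL


    def _filepath_fn(root: str, _=None):
        # Cache target: where the downloaded archive should live on disk.
        return os.path.join(root, os.path.basename(URL))


    def download_with_cache(root: str = ".data"):
        url_dp = IterableWrapper([URL])
        # on_disk_cache checks whether the target path already exists;
        # only on a miss does the HttpReader below actually download.
        cache_dp = url_dp.on_disk_cache(filepath_fn=partial(_filepath_fn, root))
        # end_caching writes the bytes to that same path and yields the local file name.
        return HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)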
3 changes: 1 addition & 2 deletions test/torchtext_unittest/datasets/test_cnndm.py
@@ -87,8 +87,7 @@ def test_cnndm(self, split):
         dataset = CNNDM(root=self.root_dir, split=split)
         samples = list(dataset)
         expected_samples = self.samples[split]
-        for sample, expected_sample in zip_equal(samples, expected_samples):
-            self.assertEqual(sample, expected_sample)
+        self.assertEqual(expected_samples, samples)
 
     @parameterized.expand(["train", "val", "test"])
     @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list)
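
Comparing the whole lists in a single assertEqual also covers what zip_equal enforced (equal lengths) while producing one readable diff on failure. A standalone illustration with made-up samples:

    import unittest


    class ListAssertionDemo(unittest.TestCase):
        def test_whole_list_comparison(self):
            expected = [{"text": "a"}, {"text": "b"}]
            actual = [{"text": "a"}, {"text": "b"}]
            # One assertion over the full lists; unittest reports both
            # length mismatches and per-element differences in its message.
            self.assertEqual(expected, actual)


    if __name__ == "__main__":
        unittest.main()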
20 changes: 9 additions & 11 deletions torchtext/datasets/cnndm.py
@@ -1,8 +1,7 @@
 import hashlib
 import os
-from collections import defaultdict
 from functools import partial
-from typing import Union, Tuple
+from typing import Union, Set, Tuple
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
@@ -52,28 +51,24 @@
     "test": 11490,
 }
 
-story_fnames = defaultdict(set)
-
 
 def _filepath_fn(root: str, source: str, _=None):
     return os.path.join(root, PATH_LIST[source])
 
 
 # called once per tar file, therefore no duplicate processing
 def _extracted_folder_fn(root: str, source: str, split: str, _=None):
-    global story_fnames
     key = source + "_" + split
-    story_fnames[key] = set(_get_split_list(source, split))
-    filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]]
-    return filepaths
+    filepath = os.path.join(root, key)
+    return filepath
 
 
 def _extracted_filepath_fn(root: str, source: str, x: str):
     return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x))
 
 
-def _filter_fn(source: str, split: str, x: tuple):
-    return os.path.basename(x[0]) in story_fnames[source + "_" + split]
+def _filter_fn(split_list: Set[str], x: tuple):
+    return os.path.basename(x[0]) in split_list
 
 
 def _hash_urls(s: tuple):
@@ -96,6 +91,9 @@ def _get_split_list(source: str, split: str):
 
 
 def _load_stories(root: str, source: str, split: str):
+
+    split_list = set(_get_split_list(source, split))
+
     story_dp = IterableWrapper([URL[source]])
     cache_compressed_dp = story_dp.on_disk_cache(
         filepath_fn=partial(_filepath_fn, root, source),
@@ -108,7 +106,7 @@ def _load_stories(root: str, source: str, split: str):
         filepath_fn=partial(_extracted_folder_fn, root, source, split)
     )
     cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split))
+        FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split_list))
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(
         mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)
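
The heart of the cnndm.py fix is removing hidden module state: _extracted_folder_fn is an on_disk_cache callback, so side effects it used to perform (populating the global story_fnames) are not guaranteed to have happened by the time the filter runs, for instance on a warm cache or in a different worker process. Computing split_list once inside _load_stories and binding it into _filter_fn with functools.partial makes the filter self-contained and picklable. A standalone sketch of that pattern, with hypothetical stand-ins for the split list and tar entries:

    import os
    from functools import partial
    from typing import Set, Tuple


    def get_split_list(source: str, split: str):
        # Stand-in for torchtext's _get_split_list; the real one reads split manifests.
        return ["0001.story", "0002.story"]


    def filter_fn(split_list: Set[str], x: Tuple) -> bool:
        # x is a (path, stream) tuple as yielded by load_from_tar();
        # keep it only if its basename belongs to the requested split.
        return os.path.basename(x[0]) in split_list


    split_list = set(get_split_list("cnn", "train"))
    keep = partial(filter_fn, split_list)  # no module-level state to go stale

    print(keep(("stories/0001.story", None)))  # True
    print(keep(("stories/9999.story", None)))  # False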
2 changes: 1 addition & 1 deletion torchtext/datasets/imdb.py
@@ -32,7 +32,7 @@ def _filepath_fn(root, _=None):
 
 
 def _decompressed_filepath_fn(root, decompressed_folder, split, labels, _=None):
-    return [os.path.join(root, decompressed_folder, split, label) for label in labels]
+    return os.path.join(root, decompressed_folder, split)
 
 
 def _filter_fn(filter_imdb_data, split, t):
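
The imdb.py change makes the decompression cache key on a single directory per split instead of a list with one entry per label, so the cache check has one stable target. An illustrative before/after (the folder and label names mirror IMDB's pos/neg layout but are only examples):

    import os

    root, decompressed_folder, split = ".data", "aclImdb", "train"
    labels = ["pos", "neg"]

    # Before: one cache target per label directory.
    before = [os.path.join(root, decompressed_folder, split, label) for label in labels]
    # After: a single cache target for the whole split.
    after = os.path.join(root, decompressed_folder, split)

    print(before)  # ['.data/aclImdb/train/pos', '.data/aclImdb/train/neg']
    print(after)   # .data/aclImdb/train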
19 changes: 18 additions & 1 deletion torchtext/datasets/iwslt2017.py
@@ -240,7 +240,24 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
         filepath_fn=partial(_inner_iwslt_tar_filepath_fn, inner_iwslt_tar)
     )
     cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar()
-    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+    # Because file names inside the archive were duplicated, stray files could end up cached as .tgz
+
+    def extracted_file_name(inner_iwslt_tar, inner_tar_name):
+        name = os.path.basename(inner_tar_name)
+        path = os.path.dirname(inner_iwslt_tar)
+        return os.path.join(path, name)
+
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", filepath_fn=partial(extracted_file_name, inner_iwslt_tar)
+    )
+    # Now that the paths are corrected, keep only .tgz files and drop dot files
+
+    def leave_only_tgz(file_name):
+        name = os.path.basename(file_name)
+        _, file_extension = os.path.splitext(file_name)
+        return file_extension == ".tgz" and name[0] != "."
+
+    cache_decompressed_dp = cache_decompressed_dp.filter(leave_only_tgz)
     cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2)
 
     src_filename = file_path_by_lang_and_split[src_language][split]
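
The new leave_only_tgz predicate is plain os.path string handling, so its behavior is easy to verify in isolation; the paths below are illustrative, and the dot-file case mirrors the macOS ._* metadata entries that tar archives often carry:

    import os


    def leave_only_tgz(file_name: str) -> bool:
        name = os.path.basename(file_name)
        _, file_extension = os.path.splitext(file_name)
        return file_extension == ".tgz" and name[0] != "."


    print(leave_only_tgz("texts/DeEnItNlRo-DeEnItNlRo.tgz"))    # True
    print(leave_only_tgz("texts/._DeEnItNlRo-DeEnItNlRo.tgz"))  # False: AppleDouble dot file
    print(leave_only_tgz("texts/README"))                       # False: not a .tgz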
