From 79a1526f820b44d9fd6c4bd13c2298c5d9d809d1 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:50:34 +0200 Subject: [PATCH] Fix get_data_patterns for directories with the word data twice (#6309) * Test get_data_patterns from directory with the word data twice * Fix get_data_patterns * Use glob_pattern_to_regex in entire xjoin * Fix test by passing base_path as posix * Use slash instead of xjoin for data files patterns * Fix slash sep --- src/datasets/data_files.py | 9 ++++++--- tests/test_data_files.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 448adc77a0b..25665f7011e 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -227,7 +227,9 @@ def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_ return len(hidden_directories_in_path) != len(hidden_directories_in_pattern) -def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Dict[str, List[str]]: +def _get_data_files_patterns( + pattern_resolver: Callable[[str], List[str]], base_path: str = "" +) -> Dict[str, List[str]]: """ Get the default pattern from a directory or repository by testing all the supported patterns. The first patterns to return a non-empty list of data files is returned. @@ -242,7 +244,8 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Di except FileNotFoundError: continue if len(data_files) > 0: - splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(split_pattern))["split"] for p in data_files} + pattern = base_path + ("/" if base_path else "") + split_pattern + splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files} sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted( splits - set(DEFAULT_SPLITS) ) @@ -462,7 +465,7 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig] """ resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config) try: - return _get_data_files_patterns(resolver) + return _get_data_files_patterns(resolver, base_path=base_path) except FileNotFoundError: raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None diff --git a/tests/test_data_files.py b/tests/test_data_files.py index 6fd51595e82..42cd80970da 100644 --- a/tests/test_data_files.py +++ b/tests/test_data_files.py @@ -16,6 +16,7 @@ _get_metadata_files_patterns, _is_inside_unrequested_special_dir, _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, + get_data_patterns, resolve_pattern, ) from datasets.fingerprint import Hasher @@ -634,3 +635,13 @@ def resolver(pattern): patterns = _get_metadata_files_patterns(resolver) matched = [file_path for pattern in patterns for file_path in resolver(pattern)] assert sorted(matched) == sorted(metadata_files) + + +def test_get_data_patterns_from_directory_with_the_word_data_twice(tmp_path): + repo_dir = tmp_path / "directory-name-ending-with-the-word-data" # parent directory contains the word "data/" + data_dir = repo_dir / "data" + data_dir.mkdir(parents=True) + data_file = data_dir / "train-00001-of-00009.parquet" + data_file.touch() + data_file_patterns = get_data_patterns(repo_dir.as_posix()) + assert data_file_patterns == {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]}