Skip to content

Commit

Permalink
Fix get_data_patterns for directories with the word data twice (#6309)
Browse files Browse the repository at this point in the history
* Test get_data_patterns from directory with the word data twice

* Fix get_data_patterns

* Use glob_pattern_to_regex in entire xjoin

* Fix test by passing base_path as posix

* Use slash instead of xjoin for data files patterns

* Fix slash sep
  • Loading branch information
albertvillanova committed Oct 24, 2023
1 parent fdc29db commit 79a1526
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_
return len(hidden_directories_in_path) != len(hidden_directories_in_pattern)


def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Dict[str, List[str]]:
def _get_data_files_patterns(
pattern_resolver: Callable[[str], List[str]], base_path: str = ""
) -> Dict[str, List[str]]:
"""
Get the default pattern from a directory or repository by testing all the supported patterns.
The first patterns to return a non-empty list of data files is returned.
Expand All @@ -242,7 +244,8 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Di
except FileNotFoundError:
continue
if len(data_files) > 0:
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(split_pattern))["split"] for p in data_files}
pattern = base_path + ("/" if base_path else "") + split_pattern
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files}
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
splits - set(DEFAULT_SPLITS)
)
Expand Down Expand Up @@ -462,7 +465,7 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]
"""
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
try:
return _get_data_files_patterns(resolver)
return _get_data_files_patterns(resolver, base_path=base_path)
except FileNotFoundError:
raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None

Expand Down
11 changes: 11 additions & 0 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_get_metadata_files_patterns,
_is_inside_unrequested_special_dir,
_is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir,
get_data_patterns,
resolve_pattern,
)
from datasets.fingerprint import Hasher
Expand Down Expand Up @@ -634,3 +635,13 @@ def resolver(pattern):
patterns = _get_metadata_files_patterns(resolver)
matched = [file_path for pattern in patterns for file_path in resolver(pattern)]
assert sorted(matched) == sorted(metadata_files)


def test_get_data_patterns_from_directory_with_the_word_data_twice(tmp_path):
repo_dir = tmp_path / "directory-name-ending-with-the-word-data" # parent directory contains the word "data/"
data_dir = repo_dir / "data"
data_dir.mkdir(parents=True)
data_file = data_dir / "train-00001-of-00009.parquet"
data_file.touch()
data_file_patterns = get_data_patterns(repo_dir.as_posix())
assert data_file_patterns == {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]}

0 comments on commit 79a1526

Please sign in to comment.