Fix offline mode with single config #6741

Merged · 4 commits · Mar 25, 2024
59 changes: 38 additions & 21 deletions src/datasets/packaged_modules/cache/cache.py
@@ -1,4 +1,5 @@
import glob
import json
import os
import shutil
import time
@@ -22,43 +23,62 @@ def _get_modification_time(cached_directory_path):


def _find_hash_in_cache(
dataset_name: str, config_name: Optional[str], cache_dir: Optional[str]
dataset_name: str,
config_name: Optional[str],
cache_dir: Optional[str],
config_kwargs: dict,
custom_features: Optional[datasets.Features],
) -> Tuple[str, str, str]:
if config_name or config_kwargs or custom_features:
config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
config_kwargs=config_kwargs, custom_features=custom_features
)
else:
config_id = None
cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
cached_datasets_directory_path_root = os.path.join(cache_dir, dataset_name.replace("/", "___"))
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(
os.path.join(cached_datasets_directory_path_root, config_name or "*", "*", "*")
os.path.join(cached_datasets_directory_path_root, config_id or "*", "*", "*")
)
if os.path.isdir(cached_directory_path)
and (
config_kwargs
or custom_features
or json.loads(Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
== Path(cached_directory_path).parts[-3] # no extra params => config_id == config_name
)
]
if not cached_directory_paths:
if config_name is not None:
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(
os.path.join(cached_datasets_directory_path_root, "*", "*", "*")
)
if os.path.isdir(cached_directory_path)
]
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
if os.path.isdir(cached_directory_path)
]
available_configs = sorted(
{Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths}
)
raise ValueError(
f"Couldn't find cache for {dataset_name}"
+ (f" for config '{config_name}'" if config_name else "")
+ (f" for config '{config_id}'" if config_id else "")
+ (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
)
# get most recent
cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
version, hash = cached_directory_path.parts[-2:]
other_configs = [
Path(cached_directory_path).parts[-3]
for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
if os.path.isdir(cached_directory_path)
Path(_cached_directory_path).parts[-3]
for _cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
if os.path.isdir(_cached_directory_path)
and (
config_kwargs
or custom_features
or json.loads(Path(_cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
== Path(_cached_directory_path).parts[-3] # no extra params => config_id == config_name
)
]
if not config_name and len(other_configs) > 1:
if not config_id and len(other_configs) > 1:
raise ValueError(
f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
f"\nPlease specify which configuration to reload from the cache, e.g."
@@ -114,15 +134,12 @@ def __init__(
if data_dir is not None:
config_kwargs["data_dir"] = data_dir
if hash == "auto" and version == "auto":
# First we try to find a folder that takes the config_kwargs into account
# e.g. with "default-data_dir=data%2Ffortran" as config_id
config_id = self.BUILDER_CONFIG_CLASS(config_name or "default").create_config_id(
config_kwargs=config_kwargs, custom_features=features
)
config_name, version, hash = _find_hash_in_cache(
dataset_name=repo_id or dataset_name,
config_name=config_id,
config_name=config_name,
cache_dir=cache_dir,
config_kwargs=config_kwargs,
custom_features=features,
)
elif hash == "auto" or version == "auto":
raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
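
Context for the change above: `_find_hash_in_cache` now receives the raw `config_name` together with `config_kwargs` and `custom_features` and derives the config id itself, so cached folders created with extra builder kwargs can still be found. A minimal sketch of how such a config id is produced with the public `datasets` API (the `data_dir` value is only an illustration, echoing the comment removed from `__init__`):

import datasets

# Hypothetical builder kwargs, e.g. a data_dir passed to load_dataset().
config_kwargs = {"data_dir": "data/fortran"}

# With extra kwargs the config id gets a suffix, e.g. "default-data_dir=data%2Ffortran",
# so the cached directory name is not just the plain config name.
config_id = datasets.BuilderConfig("default").create_config_id(
    config_kwargs=config_kwargs, custom_features=None
)
print(config_id)

# Without kwargs or custom features the id is simply the config name, which is
# why the lookup above falls back to comparing the directory name against the
# "config_name" stored in dataset_info.json in that case.
print(datasets.BuilderConfig("default").create_config_id(config_kwargs={}, custom_features=None))
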
107 changes: 83 additions & 24 deletions tests/packaged_modules/test_cache.py
@@ -6,40 +6,47 @@
from datasets.packaged_modules.cache.cache import Cache


SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"


def test_cache(text_dir: Path):
ds = load_dataset(str(text_dir))
def test_cache(text_dir: Path, tmp_path: Path):
cache_dir = tmp_path / "test_cache"
ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
cache = Cache(dataset_name=text_dir.name, hash=hash)
cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash=hash)
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert list(ds["train"]) == list(reloaded["train"])


def test_cache_streaming(text_dir: Path):
ds = load_dataset(str(text_dir))
def test_cache_streaming(text_dir: Path, tmp_path: Path):
cache_dir = tmp_path / "test_cache_streaming"
ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
cache = Cache(dataset_name=text_dir.name, hash=hash)
cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash=hash)
reloaded = cache.as_streaming_dataset()
assert list(ds) == list(reloaded)
assert list(ds["train"]) == list(reloaded["train"])


def test_cache_auto_hash(text_dir: Path):
ds = load_dataset(str(text_dir))
cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
def test_cache_auto_hash(text_dir: Path, tmp_path: Path):
cache_dir = tmp_path / "test_cache_auto_hash"
ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto")
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert list(ds["train"]) == list(reloaded["train"])


def test_cache_auto_hash_with_custom_config(text_dir: Path):
ds = load_dataset(str(text_dir), sample_by="paragraph")
another_ds = load_dataset(str(text_dir))
cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph")
another_cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
def test_cache_auto_hash_with_custom_config(text_dir: Path, tmp_path: Path):
cache_dir = tmp_path / "test_cache_auto_hash_with_custom_config"
ds = load_dataset(str(text_dir), sample_by="paragraph", cache_dir=str(cache_dir))
another_ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
cache = Cache(
cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph"
)
another_cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto")
assert cache.config_id.endswith("paragraph")
assert not another_cache.config_id.endswith("paragraph")
reloaded = cache.as_dataset()
@@ -50,27 +57,79 @@ def test_cache_auto_hash_with_custom_config(text_dir: Path):
assert list(another_ds["train"]) == list(another_reloaded["train"])


def test_cache_missing(text_dir: Path):
load_dataset(str(text_dir))
Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
def test_cache_missing(text_dir: Path, tmp_path: Path):
cache_dir = tmp_path / "test_cache_missing"
load_dataset(str(text_dir), cache_dir=str(cache_dir))
Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
with pytest.raises(ValueError):
Cache(dataset_name="missing", version="auto", hash="auto").download_and_prepare()
Cache(cache_dir=str(cache_dir), dataset_name="missing", version="auto", hash="auto").download_and_prepare()
with pytest.raises(ValueError):
Cache(dataset_name=text_dir.name, hash="missing").download_and_prepare()
Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash="missing").download_and_prepare()
with pytest.raises(ValueError):
Cache(dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto").download_and_prepare()
Cache(
cache_dir=str(cache_dir), dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto"
).download_and_prepare()


@pytest.mark.integration
def test_cache_multi_configs():
def test_cache_multi_configs(tmp_path: Path):
cache_dir = tmp_path / "test_cache_multi_configs"
repo_id = SAMPLE_DATASET_TWO_CONFIG_IN_METADATA
dataset_name = repo_id.split("/")[-1]
config_name = "v1"
ds = load_dataset(repo_id, config_name)
cache = Cache(dataset_name=dataset_name, repo_id=repo_id, config_name=config_name, version="auto", hash="auto")
ds = load_dataset(repo_id, config_name, cache_dir=str(cache_dir))
cache = Cache(
cache_dir=str(cache_dir),
dataset_name=dataset_name,
repo_id=repo_id,
config_name=config_name,
version="auto",
hash="auto",
)
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert len(ds["train"]) == len(reloaded["train"])
with pytest.raises(ValueError) as excinfo:
Cache(dataset_name=dataset_name, repo_id=repo_id, config_name="missing", version="auto", hash="auto")
Cache(
cache_dir=str(cache_dir),
dataset_name=dataset_name,
repo_id=repo_id,
config_name="missing",
version="auto",
hash="auto",
)
assert config_name in str(excinfo.value)


@pytest.mark.integration
def test_cache_single_config(tmp_path: Path):
cache_dir = tmp_path / "test_cache_single_config"
repo_id = SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA
dataset_name = repo_id.split("/")[-1]
config_name = "custom"
ds = load_dataset(repo_id, cache_dir=str(cache_dir))
cache = Cache(cache_dir=str(cache_dir), dataset_name=dataset_name, repo_id=repo_id, version="auto", hash="auto")
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert len(ds["train"]) == len(reloaded["train"])
cache = Cache(
cache_dir=str(cache_dir),
dataset_name=dataset_name,
config_name=config_name,
repo_id=repo_id,
version="auto",
hash="auto",
)
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert len(ds["train"]) == len(reloaded["train"])
with pytest.raises(ValueError) as excinfo:
Cache(
cache_dir=str(cache_dir),
dataset_name=dataset_name,
repo_id=repo_id,
config_name="missing",
version="auto",
hash="auto",
)
assert config_name in str(excinfo.value)
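
For reference, the new `test_cache_single_config` mirrors the scenario the PR title refers to: a dataset whose only config has a custom name can be reloaded from the cache without naming that config. A hedged end-to-end sketch of the offline flow, reusing the same testing repo as above (the offline run must be launched with the environment flag already set, since it is read when `datasets` is imported):

from datasets import load_dataset

repo_id = "hf-internal-testing/audiofolder_single_config_in_metadata"

# Run 1 (online): downloads and caches the dataset; its single config is named "custom".
# Run 2 (offline, e.g. launched with HF_DATASETS_OFFLINE=1): load_dataset falls back
# to the cache module, which with this fix resolves the single cached config even
# though no config name, version, or hash is passed.
ds = load_dataset(repo_id)
print(ds)
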