From ca6234ff2d18c0460d377f09081f875dd760d43a Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Fri, 31 May 2024 19:10:37 +0200
Subject: [PATCH] Fix NonMatchingSplitsSizesError/ExpectedMoreSplits in
 no-code Hub datasets when passing data_dir/data_files (#6925)

* Do not use exported dataset infos in some cases

* Add regression tests
---
 src/datasets/load.py |  7 ++++++-
 tests/test_load.py   | 15 +++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index fd7aa401094..824817843fd 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1235,7 +1235,12 @@ def get_module(self) -> DatasetModule:
             pass
         metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data)
         dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
-        if config.USE_PARQUET_EXPORT:  # maybe don't use the infos from the parquet export
+        # Use the infos from the parquet export except in some cases:
+        if self.data_dir or self.data_files or (self.revision and self.revision != "main"):
+            use_exported_dataset_infos = False
+        else:
+            use_exported_dataset_infos = True
+        if config.USE_PARQUET_EXPORT and use_exported_dataset_infos:
             try:
                 exported_dataset_infos = _dataset_viewer.get_exported_dataset_infos(
                     dataset=self.name, revision=self.revision, token=self.download_config.token

diff --git a/tests/test_load.py b/tests/test_load.py
index 4b2b9cbf58c..c7c413ae10b 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -1267,6 +1267,21 @@ def test_load_dataset_cached_local_script(dataset_loading_script_dir, data_dir,
     assert f"Dataset '{SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST}' doesn't exist on the Hub" in str(exc_info.value)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "kwargs, expected_train_num_rows, expected_test_num_rows",
+    [
+        ({}, 2, 2),
+        ({"data_dir": "data1"}, 1, 1),  # GH-6918: NonMatchingSplitsSizesError
+        ({"data_files": "data1/train.txt"}, 1, None),  # GH-6939: ExpectedMoreSplits
+    ],
+)
+def test_load_dataset_without_script_from_hub(kwargs, expected_train_num_rows, expected_test_num_rows):
+    dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER3, **kwargs)
+    assert dataset["train"].num_rows == expected_train_num_rows
+    assert (dataset["test"].num_rows == expected_test_num_rows) if expected_test_num_rows else ("test" not in dataset)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("stream_from_cache, ", [False, True])
 def test_load_dataset_cached_from_hub(stream_from_cache, caplog):
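
Note: a minimal sketch of the failure mode this patch fixes, assuming a
hypothetical no-code Hub dataset "user/dataset" whose repository contains a
data1/ subdirectory (the repo name and layout are illustrative; the regression
test above exercises the same scenario via SAMPLE_DATASET_IDENTIFIER3):

    from datasets import load_dataset

    # Loading the full dataset matches the split sizes recorded in the
    # parquet-export dataset infos, so split verification passes.
    ds = load_dataset("user/dataset")

    # Before this patch, passing data_dir/data_files still reused the exported
    # infos computed over the full data, so verification of the smaller subset
    # failed:
    ds = load_dataset("user/dataset", data_dir="data1")  # raised NonMatchingSplitsSizesError (GH-6918)
    ds = load_dataset("user/dataset", data_files="data1/train.txt")  # raised ExpectedMoreSplits (GH-6939)

    # With this patch, the exported dataset infos are skipped whenever
    # data_dir, data_files, or a non-default revision is passed, so the infos
    # are derived from the files actually being loaded.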