From 7b5fc585fcaf77b92839e82d0ce2c2fbf0d9ea95 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Fri, 22 Dec 2023 12:36:13 +0100
Subject: [PATCH] Preserve order of configs and splits when using Parquet exports (#6526)

* Preserve order of configs and splits

* Add test

* Rephrase

---
 src/datasets/utils/metadata.py | 41 +++++++++++------
 tests/test_metadata_util.py    | 83 ++++++++++++++++++++++++++++++++++
 2 files changed, 109 insertions(+), 15 deletions(-)

diff --git a/src/datasets/utils/metadata.py b/src/datasets/utils/metadata.py
index 4d2d6b059c7..d85866d8ba9 100644
--- a/src/datasets/utils/metadata.py
+++ b/src/datasets/utils/metadata.py
@@ -179,26 +179,37 @@ def _from_exported_parquet_files_and_dataset_infos(
         exported_parquet_files: List[Dict[str, Any]],
         dataset_infos: DatasetInfosDict,
     ) -> "MetadataConfigs":
-        return cls(
-            {
+        metadata_configs = {
+            config_name: {
+                "data_files": [
+                    {
+                        "split": split_name,
+                        "path": [
+                            parquet_file["url"].replace("refs%2Fconvert%2Fparquet", revision)
+                            for parquet_file in parquet_files_for_split
+                        ],
+                    }
+                    for split_name, parquet_files_for_split in groupby(parquet_files_for_config, itemgetter("split"))
+                ],
+                "version": str(dataset_infos.get(config_name, DatasetInfo()).version or "0.0.0"),
+            }
+            for config_name, parquet_files_for_config in groupby(exported_parquet_files, itemgetter("config"))
+        }
+        if dataset_infos:
+            # Preserve order of configs and splits
+            metadata_configs = {
                 config_name: {
                     "data_files": [
-                        {
-                            "split": split_name,
-                            "path": [
-                                parquet_file["url"].replace("refs%2Fconvert%2Fparquet", revision)
-                                for parquet_file in parquet_files_for_split
-                            ],
-                        }
-                        for split_name, parquet_files_for_split in groupby(
-                            parquet_files_for_config, itemgetter("split")
-                        )
+                        data_file
+                        for split_name in dataset_info.splits
+                        for data_file in metadata_configs[config_name]["data_files"]
+                        if data_file["split"] == split_name
                     ],
-                    "version": str(dataset_infos.get(config_name, DatasetInfo()).version or "0.0.0"),
+                    "version": metadata_configs[config_name]["version"],
                 }
-                for config_name, parquet_files_for_config in groupby(exported_parquet_files, itemgetter("config"))
+                for config_name, dataset_info in dataset_infos.items()
             }
-        )
+        return cls(metadata_configs)
 
     @classmethod
     def from_dataset_card_data(cls, dataset_card_data: DatasetCardData) -> "MetadataConfigs":
diff --git a/tests/test_metadata_util.py b/tests/test_metadata_util.py
index 7c487fb11f8..d2d82903bed 100644
--- a/tests/test_metadata_util.py
+++ b/tests/test_metadata_util.py
@@ -9,6 +9,7 @@
 from huggingface_hub import DatasetCard, DatasetCardData
 
 from datasets.config import METADATA_CONFIGS_FIELD
+from datasets.info import DatasetInfo
 from datasets.utils.metadata import MetadataConfigs
 
 
@@ -249,3 +250,85 @@ def test_metadata_configs_incorrect_yaml():
     dataset_card_data = DatasetCard.load(path).data
     with pytest.raises(ValueError):
         _ = MetadataConfigs.from_dataset_card_data(dataset_card_data)
+
+
+def test_split_order_in_metadata_configs_from_exported_parquet_files_and_dataset_infos():
+    exported_parquet_files = [
+        {
+            "dataset": "beans",
+            "config": "default",
+            "split": "test",
+            "url": "https://huggingface.co/datasets/beans/resolve/refs%2Fconvert%2Fparquet/default/test/0000.parquet",
+            "filename": "0000.parquet",
+            "size": 17707203,
+        },
+        {
+            "dataset": "beans",
+            "config": "default",
+            "split": "train",
+            "url": "https://huggingface.co/datasets/beans/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet",
+            "filename": "0000.parquet",
+            "size": 143780164,
+        },
+        {
+            "dataset": "beans",
+            "config": "default",
+            "split": "validation",
+            "url": "https://huggingface.co/datasets/beans/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet",
+            "filename": "0000.parquet",
+            "size": 18500862,
+        },
+    ]
+    dataset_infos = {
+        "default": DatasetInfo(
+            dataset_name="beans",
+            config_name="default",
+            version="0.0.0",
+            splits={
+                "train": {
+                    "name": "train",
+                    "num_bytes": 143996486,
+                    "num_examples": 1034,
+                    "shard_lengths": None,
+                    "dataset_name": "beans",
+                },
+                "validation": {
+                    "name": "validation",
+                    "num_bytes": 18525985,
+                    "num_examples": 133,
+                    "shard_lengths": None,
+                    "dataset_name": "beans",
+                },
+                "test": {
+                    "name": "test",
+                    "num_bytes": 17730506,
+                    "num_examples": 128,
+                    "shard_lengths": None,
+                    "dataset_name": "beans",
+                },
+            },
+            download_checksums={
+                "https://huggingface.co/datasets/beans/resolve/main/data/train.zip": {
+                    "num_bytes": 143812152,
+                    "checksum": None,
+                },
+                "https://huggingface.co/datasets/beans/resolve/main/data/validation.zip": {
+                    "num_bytes": 18504213,
+                    "checksum": None,
+                },
+                "https://huggingface.co/datasets/beans/resolve/main/data/test.zip": {
+                    "num_bytes": 17708541,
+                    "checksum": None,
+                },
+            },
+            download_size=180024906,
+            post_processing_size=None,
+            dataset_size=180252977,
+            size_in_bytes=360277883,
+        )
+    }
+    metadata_configs = MetadataConfigs._from_exported_parquet_files_and_dataset_infos(
+        "123", exported_parquet_files, dataset_infos
+    )
+    split_names = [data_file["split"] for data_file in metadata_configs["default"]["data_files"]]
+    assert split_names == ["train", "validation", "test"]