Skip to content

Commit

Permalink
keep more info in DatasetInfo.from_merge #6585 (#6586)
Browse files Browse the repository at this point in the history
* try not to merge DatasetInfos if they're equal

* fixes losing DatasetInfo during parallel Dataset.map

Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
JochenSiegWork and lhoestq authored Jan 26, 2024
1 parent d627fb8 commit ca76ca1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/datasets/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ def _dump_license(self, file):
@classmethod
def from_merge(cls, dataset_infos: List["DatasetInfo"]):
dataset_infos = [dset_info.copy() for dset_info in dataset_infos if dset_info is not None]

if len(dataset_infos) > 0 and all(dataset_infos[0] == dset_info for dset_info in dataset_infos):
# if all dataset_infos are equal we don't need to merge. Just return the first.
return dataset_infos[0]

description = "\n\n".join(unique_values(info.description for info in dataset_infos)).strip()
citation = "\n\n".join(unique_values(info.citation for info in dataset_infos)).strip()
homepage = "\n\n".join(unique_values(info.homepage for info in dataset_infos)).strip()
Expand Down
30 changes: 30 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,33 @@ def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: Datase

if dataset_infos_dict:
assert os.path.exists(os.path.join(tmp_path, "README.md"))


@pytest.mark.parametrize(
    "dataset_info",
    [
        # No info at all: every list entry handed to from_merge is None.
        None,
        # Default-constructed info.
        DatasetInfo(),
        # Info with several fields populated.
        DatasetInfo(
            description="foo",
            features=Features({"a": Value("int32")}),
            builder_name="builder",
            config_name="config",
            version="1.0.0",
            splits=[{"name": "train"}],
            download_size=42,
            dataset_name="dataset_name",
        ),
    ],
)
def test_from_merge_same_dataset_infos(dataset_info):
    """Check that merging several identical infos yields an equal info.

    Three entries are passed to ``DatasetInfo.from_merge``: independent
    copies of ``dataset_info``, or three ``None`` values when the parameter
    is ``None``. The merged result must compare equal to ``dataset_info``,
    or to a default ``DatasetInfo()`` in the all-``None`` case.
    """
    repetitions = 3
    has_info = dataset_info is not None

    if has_info:
        # Use separate copies so the merge never sees the same object twice.
        infos_to_merge = [dataset_info.copy() for _ in range(repetitions)]
    else:
        infos_to_merge = [None] * repetitions

    merged = DatasetInfo.from_merge(infos_to_merge)

    # Build the reference value only after merging, then compare once.
    expected = dataset_info if has_info else DatasetInfo()
    assert expected == merged

0 comments on commit ca76ca1

Please sign in to comment.