diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py
index a878ada21..86b7f46cd 100644
--- a/lhotse/cut/set.py
+++ b/lhotse/cut/set.py
@@ -2632,6 +2632,53 @@
     def to_huggingface_dataset(self):
        return export_cuts_to_hf(self)
 
+    @staticmethod
+    def from_huggingface_dataset(
+        *dataset_args,
+        audio_key: str = "audio",
+        text_key: str = "sentence",
+        lang_key: str = "language",
+        gender_key: str = "gender",
+        **dataset_kwargs,
+    ):
+        """
+        Initializes a Lhotse CutSet from an existing HF dataset,
+        or from args/kwargs passed on to ``datasets.load_dataset()``.
+
+        Use the ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` arguments to indicate which keys in the dict examples
+        returned from the HF dataset should be looked up for audio, transcript, language, and gender, respectively.
+        The remaining keys in the HF dataset examples will be stored inside the ``cut.custom`` dictionary.
+
+        Example with an existing HF dataset::
+
+            >>> import datasets
+            ... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
+            ... dataset = dataset.map(some_transform)
+            ... cuts = CutSet.from_huggingface_dataset(dataset)
+            ... for cut in cuts:
+            ...     pass
+
+        Example providing HF dataset init args/kwargs::
+
+            >>> import datasets
+            ... cuts = CutSet.from_huggingface_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
+            ... for cut in cuts:
+            ...     pass
+
+        """
+        from lhotse.hf import LazyHFDatasetIterator
+
+        return CutSet(
+            LazyHFDatasetIterator(
+                *dataset_args,
+                audio_key=audio_key,
+                text_key=text_key,
+                lang_key=lang_key,
+                gender_key=gender_key,
+                **dataset_kwargs,
+            )
+        )
+
     def __repr__(self) -> str:
         try:
             len_val = len(self)
diff --git a/lhotse/hf.py b/lhotse/hf.py
index 493087ad9..1cfdb4bad 100644
--- a/lhotse/hf.py
+++ b/lhotse/hf.py
@@ -1,12 +1,14 @@
 """
-╔══════════════════════════════════════╗
-║ Export CutSet to HuggingFace Dataset ║
-╚══════════════════════════════════════╝
+╔═════════════════════════════════════════════╗
+║ Export/Import CutSet to HuggingFace Dataset ║
+╚═════════════════════════════════════════════╝
 """
+from hashlib import md5
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
+from lhotse import Recording, SupervisionSegment
 from lhotse.cut import CutSet, MonoCut
-from lhotse.utils import is_module_available
+from lhotse.utils import Pathlike, is_module_available
 
 
 def contains_only_mono_cuts(cutset: CutSet) -> bool:
@@ -301,3 +303,78 @@ def export_cuts_to_hf(cutset: CutSet):
     )
 
     return Dataset.from_dict(dataset_dict, features=dataset_info)
+
+
+class LazyHFDatasetIterator:
+    """
+    Thin wrapper on top of HF datasets objects that allows interacting with them through a Lhotse CutSet.
+    It can be initialized with an existing HF dataset, or with args/kwargs passed on to ``datasets.load_dataset()``.
+
+    Use the ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` arguments to indicate which keys in the dict examples
+    returned from the HF dataset should be looked up for audio, transcript, language, and gender, respectively.
+    The remaining keys in the HF dataset examples will be stored inside the ``cut.custom`` dictionary.
+
+    Example with an existing HF dataset::
+
+        >>> import datasets
+        ... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
+        ... dataset = dataset.map(some_transform)
+        ... cuts_it = LazyHFDatasetIterator(dataset)
+        ... for cut in cuts_it:
+        ...     pass
+
+    Example providing HF dataset init args/kwargs::
+
+        >>> import datasets
+        ... cuts_it = LazyHFDatasetIterator("mozilla-foundation/common_voice_11_0", "hi", split="test")
+        ... for cut in cuts_it:
+        ...     pass
+
+    """
+
+    def __init__(
+        self,
+        *dataset_args,
+        audio_key: str = "audio",
+        text_key: str = "sentence",
+        lang_key: str = "language",
+        gender_key: str = "gender",
+        **dataset_kwargs
+    ):
+        assert is_module_available("datasets")
+        self.audio_key = audio_key
+        self.text_key = text_key
+        self.lang_key = lang_key
+        self.gender_key = gender_key
+        self.dataset_args = dataset_args
+        self.dataset_kwargs = dataset_kwargs
+
+    def __iter__(self):
+        from datasets import Audio, Dataset, IterableDataset, load_dataset
+
+        if len(self.dataset_args) == 1 and isinstance(
+            self.dataset_args[0], (Dataset, IterableDataset)
+        ):
+            dataset = self.dataset_args[0]
+        else:
+            dataset = load_dataset(*self.dataset_args, **self.dataset_kwargs)
+
+        dataset = dataset.cast_column(self.audio_key, Audio(decode=False))
+        for item in dataset:
+            audio_data = item.pop(self.audio_key)
+            recording = Recording.from_bytes(
+                audio_data["bytes"], recording_id=md5(audio_data["bytes"]).hexdigest()
+            )
+            supervision = SupervisionSegment(
+                id=recording.id,
+                recording_id=recording.id,
+                start=0.0,
+                duration=recording.duration,
+                text=item.pop(self.text_key, None),
+                language=item.pop(self.lang_key, None),
+                gender=item.pop(self.gender_key, None),
+            )
+            cut = recording.to_cut()
+            cut.supervisions = [supervision]
+            cut.custom = item
+            yield cut
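
A minimal usage sketch (not part of the patch) for trying the new ``CutSet.from_huggingface_dataset`` API introduced above; it assumes the ``datasets`` package is installed and that the Common Voice split used in the docstring examples is downloadable, and the printed fields are just for illustration:

    from lhotse import CutSet

    # Lazily build a CutSet straight from datasets.load_dataset() args
    # (mirrors the second docstring example; Common Voice stores the
    # transcript under the default "sentence" key).
    cuts = CutSet.from_huggingface_dataset(
        "mozilla-foundation/common_voice_11_0", "hi", split="test"
    )

    for cut in cuts:
        sup = cut.supervisions[0]
        # Audio, text, language and gender are mapped into a Recording plus
        # a SupervisionSegment; all remaining HF example fields land in cut.custom.
        print(cut.id, round(cut.duration, 2), sup.text, sup.language, sup.gender)
        print(sorted(cut.custom.keys()))
        break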