Skip to content

Commit

Permalink
CutSet.from_huggingface_dataset() for importing HF datasets (#1433)
Browse files Browse the repository at this point in the history
* `CutSet.from_huggingface_dataset()` for importing HF datasets

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko authored Dec 13, 2024
1 parent bd50182 commit 40b25ec
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 4 deletions.
47 changes: 47 additions & 0 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2632,6 +2632,53 @@ def to_huggingface_dataset(self):

return export_cuts_to_hf(self)

@staticmethod
def from_huggingface_dataset(
*dataset_args,
audio_key: str = "audio",
text_key: str = "sentence",
lang_key: str = "language",
gender_key: str = "gender",
**dataset_kwargs,
):
"""
Initializes a Lhotse CutSet from an existing HF dataset,
or args/kwargs passed on to ``datasets.load_dataset()``.
Use ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` options to indicate which keys in dict examples
returned from HF Dataset should be looked up for audio, transcript, language, and gender respectively.
The remaining keys in HF dataset examples will be stored inside ``cut.custom`` dictionary.
Example with existing HF dataset::
>>> import datasets
... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
... dataset = dataset.map(some_transform)
... cuts = CutSet.from_huggingface_dataset(dataset)
... for cut in cuts:
... pass
Example providing HF dataset init args/kwargs::
>>> import datasets
... cuts = CutSet.from_huggingface_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
... for cut in cuts:
... pass
"""
from lhotse.hf import LazyHFDatasetIterator

return CutSet(
LazyHFDatasetIterator(
*dataset_args,
audio_key=audio_key,
text_key=text_key,
lang_key=lang_key,
gender_key=gender_key,
**dataset_kwargs,
)
)

def __repr__(self) -> str:
try:
len_val = len(self)
Expand Down
85 changes: 81 additions & 4 deletions lhotse/hf.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
╔══════════════════════════════════════╗
║ Export CutSet to HuggingFace Dataset ║
╚══════════════════════════════════════╝
╔═════════════════════════════════════════════
║ Export/Import CutSet to HuggingFace Dataset ║
╚═════════════════════════════════════════════
"""
from hashlib import md5
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from lhotse import Recording, SupervisionSegment
from lhotse.cut import CutSet, MonoCut
from lhotse.utils import is_module_available
from lhotse.utils import Pathlike, is_module_available


def contains_only_mono_cuts(cutset: CutSet) -> bool:
Expand Down Expand Up @@ -301,3 +303,78 @@ def export_cuts_to_hf(cutset: CutSet):
)

return Dataset.from_dict(dataset_dict, features=dataset_info)


class LazyHFDatasetIterator:
"""
Thin wrapper on top of HF datasets objects that allows to interact with them through a Lhotse CutSet.
It can be initialized with an existing HF dataset, or args/kwargs passed on to ``datasets.load_dataset()``.
Use ``audio_key``, ``text_key``, ``lang_key`` and ``gender_key`` options to indicate which keys in dict examples
returned from HF Dataset should be looked up for audio, transcript, language, and gender respectively.
The remaining keys in HF dataset examples will be stored inside ``cut.custom`` dictionary.
Example with existing HF dataset::
>>> import datasets
... dataset = datasets.load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")
... dataset = dataset.map(some_transform)
... cuts_it = LazyHFDatasetIterator(dataset)
... for cut in cuts_it:
... pass
Example providing HF dataset init args/kwargs::
>>> import datasets
... cuts_it = LazyHFDatasetIterator("mozilla-foundation/common_voice_11_0", "hi", split="test")
... for cut in cuts_it:
... pass
"""

def __init__(
self,
*dataset_args,
audio_key: str = "audio",
text_key: str = "sentence",
lang_key: str = "language",
gender_key: str = "gender",
**dataset_kwargs
):
assert is_module_available("datasets")
self.audio_key = audio_key
self.text_key = text_key
self.lang_key = lang_key
self.gender_key = gender_key
self.dataset_args = dataset_args
self.dataset_kwargs = dataset_kwargs

def __iter__(self):
from datasets import Audio, Dataset, IterableDataset, load_dataset

if len(self.dataset_args) == 1 and isinstance(
self.dataset_args[0], (Dataset, IterableDataset)
):
dataset = self.dataset_args[0]
else:
dataset = load_dataset(*self.dataset_args, **self.dataset_kwargs)

dataset = dataset.cast_column(self.audio_key, Audio(decode=False))
for item in dataset:
audio_data = item.pop(self.audio_key)
recording = Recording.from_bytes(
audio_data["bytes"], recording_id=md5(audio_data["bytes"]).hexdigest()
)
supervision = SupervisionSegment(
id=recording.id,
recording_id=recording.id,
start=0.0,
duration=recording.duration,
text=item.pop(self.text_key, None),
language=item.pop(self.lang_key, None),
gender=item.pop(self.gender_key, None),
)
cut = recording.to_cut()
cut.supervisions = [supervision]
cut.custom = item
yield cut

0 comments on commit 40b25ec

Please sign in to comment.