diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 6cdf44e223..b3f1b3c725 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -4,11 +4,11 @@ torchaudio.datasets
 All datasets are subclasses of :class:`torch.utils.data.Dataset`
 i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
 Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
-which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
+which can load multiple samples parallelly using ``torch.multiprocessing`` workers. 
 For example: ::
-
+ 
     yesno_data = torchaudio.datasets.YESNO('.', download=True)
-    data_loader = torch.utils.data.DataLoader(yesno_data,
+    data_loader = torch.utils.data.DataLoader(yesno_data, 
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.nThreads)
@@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
 ``transform`` and ``target_transform`` to transform the input and target respectively.
 
 
-.. currentmodule:: torchaudio.datasets
+.. currentmodule:: torchaudio.datasets 
 
 
 CMUARCTIC
@@ -81,6 +81,13 @@ SPEECHCOMMANDS
   :special-members:
 
 
+TEDLIUM
+~~~~~~~~~~~~~~
+
+.. autoclass:: TEDLIUM
+  :members: __getitem__
+  :special-members: get_phoneme_dict
+
 VCTK
 ~~~~
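For context, using the new dataset would mirror the YESNO example in the documentation above. A minimal sketch — the root path, keyword values, and variable names here are illustrative and not part of the patch: ::

    import torch
    import torchaudio

    # Each item is (waveform, sample_rate, transcript, talk_id, speaker_id, identifier).
    tedlium_data = torchaudio.datasets.TEDLIUM("./data", release="release1", subset="train", download=True)
    waveform, sample_rate, transcript, talk_id, speaker_id, identifier = tedlium_data[0]

    # batch_size=1 avoids having to pad variable-length waveforms in a custom collate function.
    data_loader = torch.utils.data.DataLoader(tedlium_data, batch_size=1, shuffle=True)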
diff --git a/test/torchaudio_unittest/datasets/tedlium_test.py b/test/torchaudio_unittest/datasets/tedlium_test.py
new file mode 100644
index 0000000000..c19984cac2
--- /dev/null
+++ b/test/torchaudio_unittest/datasets/tedlium_test.py
@@ -0,0 +1,153 @@
+import os
+
+from torchaudio.datasets import tedlium
+
+from torchaudio_unittest.common_utils import (
+    TempDirMixin,
+    TorchaudioTestCase,
+    get_whitenoise,
+    save_wav,
+    normalize_wav,
+)
+
+# Used to generate a unique utterance for each dummy audio file.
+# Each line follows the STM layout parsed by the dataset:
+# <talk_id> <channel> <speaker_id> <start_time> <end_time> <identifier> <transcript>
+UTTERANCES = [
+    "AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
+    "AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
+    "AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
+    "AaronHuey_2010X 1 AaronHuey_2010X 6.0 8.0 <o,f0,female> script4\n",
+    "AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
+]
+
+PHONEME = [
+    "a AH",
+    "a(2) EY",
+    "aachen AA K AH N",
+    "aad AE D",
+    "aaden EY D AH N",
+    "aadmi AE D M IY",
+    "aae EY EY",
+]
+
+
+class TestTedlium(TempDirMixin, TorchaudioTestCase):
+    backend = "default"
+
+    root_dir = None
+    samples = {}
+
+    @classmethod
+    def setUpClass(cls):
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium")
+        os.makedirs(dataset_dir, exist_ok=True)
+        sample_rate = 16000  # 16kHz
+        seed = 0
+
+        for release in ["release1", "release2", "release3"]:
+            data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)
+            if release in ["release1", "release2"]:
+                release_dir = os.path.join(
+                    dataset_dir,
+                    tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
+                    tedlium._RELEASE_CONFIGS[release]["subset"],
+                )
+            else:
+                release_dir = os.path.join(
+                    dataset_dir,
+                    tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
+                    tedlium._RELEASE_CONFIGS[release]["data_path"],
+                )
+            os.makedirs(release_dir, exist_ok=True)
+            os.makedirs(os.path.join(release_dir, "stm"), exist_ok=True)  # Subfolder for transcripts
+            os.makedirs(os.path.join(release_dir, "sph"), exist_ok=True)  # Subfolder for audio files
+            filename = f"{release}.sph"
+            path = os.path.join(release_dir, "sph", filename)
+            save_wav(path, data, sample_rate)
+
+            trans_filename = f"{release}.stm"
+            trans_path = os.path.join(release_dir, "stm", trans_filename)
+            with open(trans_path, "w") as f:
+                f.write("".join(UTTERANCES))
+
+            dict_filename = f"{release}.dic"
+            dict_path = os.path.join(release_dir, dict_filename)
+            with open(dict_path, "w") as f:
+                f.write("\n".join(PHONEME))
+
+            # Create a samples list to compare against the dataset output
+            cls.samples[release] = []
+            for utterance in UTTERANCES:
+                talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
+                start_time = int(float(start_time)) * sample_rate
+                end_time = int(float(end_time)) * sample_rate
+                sample = (
+                    data[:, start_time:end_time],
+                    sample_rate,
+                    transcript,
+                    talk_id,
+                    speaker_id,
+                    identifier,
+                )
+                cls.samples[release].append(sample)
+            seed += 1
+
+    def _assert_release(self, release):
+        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
+        num_samples = 0
+        for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
+            self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
+            assert sample_rate == self.samples[release][i][1]
+            assert transcript == self.samples[release][i][2]
+            assert talk_id == self.samples[release][i][3]
+            assert speaker_id == self.samples[release][i][4]
+            assert identifier == self.samples[release][i][5]
+            num_samples += 1
+
+        assert num_samples == len(self.samples[release])
+
+        dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
+        phoneme_dict = dataset.phoneme_dict
+        phonemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
+        assert phonemes == PHONEME
+
+    def test_tedlium_release1(self):
+        self._assert_release("release1")
+
+    def test_tedlium_release2(self):
+        self._assert_release("release2")
+
+    def test_tedlium_release3(self):
+        self._assert_release("release3")
diff --git a/torchaudio/datasets/__init__.py b/torchaudio/datasets/__init__.py
index fb9e0db506..5e5b9c1118 100644
--- a/torchaudio/datasets/__init__.py
+++ b/torchaudio/datasets/__init__.py
@@ -8,6 +8,7 @@ from .ljspeech import LJSPEECH
 from .cmuarctic import CMUARCTIC
 from .libritts import LIBRITTS
+from .tedlium import TEDLIUM
 
 
 __all__ = (
     "COMMONVOICE",
@@ -19,7 +20,8 @@
     "LJSPEECH",
     "GTZAN",
     "CMUARCTIC",
-    "LIBRITTS"
+    "LIBRITTS",
     "diskcache_iterator",
     "bg_iterator",
+    "TEDLIUM",
 )
diff --git a/torchaudio/datasets/tedlium.py b/torchaudio/datasets/tedlium.py
new file mode 100644
index 0000000000..4ae1ddeb0f
--- /dev/null
+++ b/torchaudio/datasets/tedlium.py
@@ -0,0 +1,208 @@
+import os
+from typing import Optional, Tuple
+
+import torchaudio
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import (
+    download_url,
+    extract_archive,
+)
+
+
+_RELEASE_CONFIGS = {
+    "release1": {
+        "folder_in_archive": "TEDLIUM_release1",
+        "url": "http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz",
+        "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
+        "data_path": "",
+        "subset": "train",
+        "supported_subsets": ["train", "test", "dev"],
+        "dict": "TEDLIUM.150K.dic",
+    },
+    "release2": {
+        "folder_in_archive": "TEDLIUM_release2",
+        "url": "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz",
+        "checksum": "93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58",
+        "data_path": "",
+        "subset": "train",
+        "supported_subsets": ["train", "test", "dev"],
+        "dict": "TEDLIUM.152k.dic",
+    },
+    "release3": {
+        "folder_in_archive": "TEDLIUM_release-3",
+        "url": "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz",
+        "checksum": "ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb",
+        "data_path": "data/",
+        "subset": None,
+        "supported_subsets": [None],
+        "dict": "TEDLIUM.152k.dic",
+    },
+}
+
+
+class TEDLIUM(Dataset):
+    """Create a Dataset for TEDLIUM. Releases 1, 2 and 3 are supported; each item is a tuple of the form
+    ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``.
+
+    Args:
+        root (str): Path to the directory where the dataset is found or downloaded.
+        release (str, optional): TEDLIUM release identifier, one of ``"release1"``, ``"release2"``
+            or ``"release3"``. (default: ``"release1"``)
+        subset (str, optional): ``"train"``, ``"dev"`` or ``"test"`` for releases 1 and 2,
+            ``None`` for release 3. (default: ``"train"`` for releases 1 and 2, ``None`` for release 3)
+        download (bool, optional): Whether to download the dataset if it is not found at the root path.
+            (default: ``False``)
+        audio_ext (str, optional): Extension of the audio files when loading items. (default: ``".sph"``)
+
+    Notable members:
+
+        _load_tedlium_item: Loads a TEDLIUM dataset sample given a file name and the index of the
+        sentence within that file.
+
+        _load_audio: Default load function used by the TEDLIUM dataset; it can be overridden to customize
+        how individual sentences are loaded from a full TED talk audio file.
+
+        phoneme_dict: Returns the phoneme dictionary of a TEDLIUM release.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        release: str = "release1",
+        subset: Optional[str] = None,
+        download: bool = False,
+        audio_ext: str = ".sph",
+    ) -> None:
+        """Constructor for the TEDLIUM dataset.
+
+        Args:
+            root (str): Path to the directory where the dataset is found or downloaded.
+            release (str, optional): TEDLIUM release identifier (``"release1"``, ``"release2"`` or
+                ``"release3"``). (default: ``"release1"``)
+            subset (str, optional): ``"train"``, ``"dev"`` or ``"test"`` for releases 1 and 2,
+                ``None`` for release 3. (default: ``"train"`` for releases 1 and 2, ``None`` for release 3)
+            download (bool, optional): Whether to download the dataset if it is not found at the root path.
+                (default: ``False``)
+            audio_ext (str, optional): Extension of the audio files when loading items. (default: ``".sph"``)
+
+        Raises:
+            RuntimeError: If the release identifier does not match any supported release, or if the
+                subset is not supported for the selected release.
+        """
+        self._ext_audio = audio_ext
+        if release in _RELEASE_CONFIGS.keys():
+            folder_in_archive = _RELEASE_CONFIGS[release]["folder_in_archive"]
+            url = _RELEASE_CONFIGS[release]["url"]
+            subset = subset if subset else _RELEASE_CONFIGS[release]["subset"]
+        else:
+            raise RuntimeError(
+                "The release {} does not match any of the supported TEDLIUM releases: {}".format(
+                    release, _RELEASE_CONFIGS.keys(),
+                )
+            )
+        if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]:
+            raise RuntimeError(
+                "The subset {} does not match any of the supported TEDLIUM subsets: {}".format(
+                    subset, _RELEASE_CONFIGS[release]["supported_subsets"],
+                )
+            )
+
+        basename = os.path.basename(url)
+        archive = os.path.join(root, basename)
+
+        basename = basename.split(".")[0]
+
+        self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"])
+        if subset in ["train", "dev", "test"]:
+            self._path = os.path.join(self._path, subset)
+
+        if download:
+            if not os.path.isdir(self._path):
+                if not os.path.isfile(archive):
+                    checksum = _RELEASE_CONFIGS[release]["checksum"]
+                    download_url(url, root, hash_value=checksum)
+                extract_archive(archive)
+
+        # Create the list of all samples: one (file id, line index) pair per STM transcript line
+        self._filelist = []
+        stm_dir = os.path.join(self._path, "stm")
+        for file in sorted(os.listdir(stm_dir)):
+            if file.endswith(".stm"):
+                stm_path = os.path.join(stm_dir, file)
+                with open(stm_path) as f:
+                    num_lines = len(f.readlines())
+                file = file.replace(".stm", "")
+                self._filelist.extend((file, line) for line in range(num_lines))
+        # Store the dictionary path for later, lazy reading
+        self._dict_path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["dict"])
+        self._phoneme_dict = None
+
+    def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, str, str, str]:
+        """Loads a TEDLIUM dataset sample given a file name and the index of the sentence within that file.
+
+        Args:
+            fileid (str): File id that identifies both the text and audio files of the sample
+            line (int): Line index of the sample inside the text file
+            path (str): Dataset root path
+
+        Returns:
+            Tuple of ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
+        """
+        transcript_path = os.path.join(path, "stm", fileid)
+        with open(transcript_path + ".stm") as f:
+            transcript = f.readlines()[line]
+            talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6)
+
+        wave_path = os.path.join(path, "sph", fileid)
+        waveform, sample_rate = self._load_audio(wave_path + self._ext_audio, start_time=start_time, end_time=end_time)
+
+        return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier)
+    def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate: int = 16000) -> Tuple[Tensor, int]:
+        """Default load function used by the TEDLIUM dataset. It can be overridden to customize how
+        individual sentences are loaded from a full TED talk audio file.
+
+        Args:
+            path (str): Path to the audio file
+            start_time (float): Time in seconds where the sample sentence starts
+            end_time (float): Time in seconds where the sample sentence ends
+            sample_rate (int, optional): Sample rate used to convert times into frame offsets. (default: ``16000``)
+
+        Returns:
+            Tuple of ``(waveform, sample_rate)``
+        """
+        start_time = int(float(start_time) * sample_rate)
+        end_time = int(float(end_time) * sample_rate)
+        if torchaudio.get_audio_backend() == "sox_io":
+            return torchaudio.load(path, frame_offset=start_time, num_frames=end_time - start_time)
+        waveform, sample_rate = torchaudio.load(path)
+        return waveform[:, start_time:end_time], sample_rate
+
+    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
+        """Load the n-th sample from the dataset.
+
+        Args:
+            n (int): Index of the sample to be loaded
+
+        Returns:
+            Tuple of ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
+        """
+        fileid, line = self._filelist[n]
+        return self._load_tedlium_item(fileid, line, self._path)
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset.
+
+        Returns:
+            int: TEDLIUM dataset length
+        """
+        return len(self._filelist)
+
+    @property
+    def phoneme_dict(self):
+        """Returns the phoneme dictionary of a TEDLIUM release.
+
+        Returns:
+            dict: Mapping from a word to a tuple of its phonemes for the current TEDLIUM release.
+        """
+        # Read the phoneme dictionary lazily on first access
+        if not self._phoneme_dict:
+            self._phoneme_dict = {}
+            with open(self._dict_path, "r", encoding="utf-8") as f:
+                for line in f.readlines():
+                    content = line.strip().split()
+                    self._phoneme_dict[content[0]] = tuple(content[1:])  # content[1:] can be an empty list
+        return self._phoneme_dict.copy()
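As a usage sketch of the hooks added above: ``_load_audio`` is documented as overridable and ``phoneme_dict`` exposes the release dictionary. The subclass name, padding value, root path, and lookup word below are illustrative only, not part of the patch: ::

    import torchaudio

    class PaddedTEDLIUM(torchaudio.datasets.TEDLIUM):
        # Hypothetical subclass: include 0.5 s of extra context around each sentence
        # before delegating to the default loader.
        def _load_audio(self, path, start_time, end_time, sample_rate=16000):
            start_time = max(float(start_time) - 0.5, 0.0)
            end_time = float(end_time) + 0.5
            return super()._load_audio(path, start_time, end_time, sample_rate=sample_rate)

    data = PaddedTEDLIUM("./data", release="release1", download=True)
    phonemes = data.phoneme_dict  # dict: word -> tuple of phonemes, e.g. phonemes.get("aachen")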