Skip to content

Commit

Permalink
Add tedlium dataset (#882)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiwidi authored Sep 15, 2020
1 parent b6a61c3 commit 914a846
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 5 deletions.
15 changes: 11 additions & 4 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ torchaudio.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
For example: ::

yesno_data = torchaudio.datasets.YESNO('.', download=True)
data_loader = torch.utils.data.DataLoader(yesno_data,
data_loader = torch.utils.data.DataLoader(yesno_data,
batch_size=1,
shuffle=True,
num_workers=args.nThreads)
Expand All @@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.


.. currentmodule:: torchaudio.datasets
.. currentmodule:: torchaudio.datasets


CMUARCTIC
Expand Down Expand Up @@ -81,6 +81,13 @@ SPEECHCOMMANDS
:special-members:


TEDLIUM
~~~~~~~~~~~~~~

.. autoclass:: TEDLIUM
:members: __getitem__
:special-members: get_phoneme_dict

VCTK
~~~~

Expand Down
153 changes: 153 additions & 0 deletions test/torchaudio_unittest/datasets/tedlium_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os

from torchaudio.datasets import tedlium

from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
normalize_wav,
)

# Used to generate a unique utterance for each dummy audio file
UTTERANCES = [
"AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
"AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
"AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
"AaronHuey_2010X 1 AaronHuey_2010X 6.0 8.0 <o,f0,female> script4\n",
"AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
]

PHONEME = [
"a AH",
"a(2) EY",
"aachen AA K AH N",
"aad AE D",
"aaden EY D AH N",
"aadmi AE D M IY",
"aae EY EY",
]


class TestTedlium(TempDirMixin, TorchaudioTestCase):
backend = "default"

root_dir = None
samples = {}

@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium")
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz
seed = 0

for release in ["release1", "release2", "release3"]:
data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)
if release in ["release1", "release2"]:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["subset"],
)
else:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["data_path"],
)
os.makedirs(release_dir, exist_ok=True)
os.makedirs(os.path.join(release_dir, "stm"), exist_ok=True) # Subfolder for transcripts
os.makedirs(os.path.join(release_dir, "sph"), exist_ok=True) # Subfolder for audio files
filename = f"{release}.sph"
path = os.path.join(os.path.join(release_dir, "sph"), filename)
save_wav(path, data, sample_rate)

trans_filename = f"{release}.stm"
trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename)
with open(trans_path, "w") as f:
f.write("".join(UTTERANCES))

dict_filename = f"{release}.dic"
dict_path = os.path.join(release_dir, dict_filename)
with open(dict_path, "w") as f:
f.write("\n".join(PHONEME))

# Create a samples list to compare with
cls.samples[release] = []
for utterance in UTTERANCES:
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
start_time = int(float(start_time)) * sample_rate
end_time = int(float(end_time)) * sample_rate
sample = (
data[:, start_time:end_time],
sample_rate,
transcript,
talk_id,
speaker_id,
identifier,
)
cls.samples[release].append(sample)
seed += 1

def test_tedlium_release1(self):
release = "release1"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1

assert num_samples == len(self.samples[release])

dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME

def test_tedlium_release2(self):
release = "release2"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1

assert num_samples == len(self.samples[release])

dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME

def test_tedlium_release3(self):
release = "release3"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1

assert num_samples == len(self.samples[release])

dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME
4 changes: 3 additions & 1 deletion torchaudio/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC
from .libritts import LIBRITTS
from .tedlium import TEDLIUM

__all__ = (
"COMMONVOICE",
Expand All @@ -19,7 +20,8 @@
"LJSPEECH",
"GTZAN",
"CMUARCTIC",
"LIBRITTS"
"LIBRITTS",
"diskcache_iterator",
"bg_iterator",
"TEDLIUM",
)
Loading

0 comments on commit 914a846

Please sign in to comment.