Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tedlium dataset (all 3 releases) #882

Merged
merged 16 commits into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ torchaudio.datasets
All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
which can load multiple samples parallelly using ``torch.multiprocessing`` workers.
For example: ::

yesno_data = torchaudio.datasets.YESNO('.', download=True)
data_loader = torch.utils.data.DataLoader(yesno_data,
data_loader = torch.utils.data.DataLoader(yesno_data,
batch_size=1,
shuffle=True,
num_workers=args.nThreads)
Expand All @@ -22,7 +22,7 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.


.. currentmodule:: torchaudio.datasets
.. currentmodule:: torchaudio.datasets


CMUARCTIC
Expand Down Expand Up @@ -81,6 +81,13 @@ SPEECHCOMMANDS
:special-members:


TEDLIUM
~~~~~~~~~~~~~~

.. autoclass:: TEDLIUM
:members: __getitem__
:special-members: get_phoneme_dict
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mthrok What does special-members mean here? I interpretate it as extra functions to include in the docs? Thats why I included get_phoneme_dict

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are Sphinx's directive. Check out their documentation

  • :members: is where you list the members you want document
  • :special-members: is where you list the special methods like __init__, __len__, __getitem__ etc...

I think the other documentations are wrong, (__getitem__ should be under :special-member: but it will not show up either way because they don't have a docstring.)

I think you can just do .. autoclass:: TEDLIUM and the rest (get_phoneme_dict) will be handled.

You can build the documentation and check how the resulting documentation looks like.

cd docs
pip install -r requirements.txt
make html
# open ./build/html/index.html


VCTK
~~~~

Expand Down
154 changes: 154 additions & 0 deletions test/torchaudio_unittest/datasets/tedlium_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os

from torchaudio.datasets import tedlium

from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
normalize_wav,
)

# Used to generate a unique utterance for each dummy audio file
UTTERANCES = [
"AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
"AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
"AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
"AaronHuey_2010X 1 AaronHuey_2010X 6.0 8.0 <o,f0,female> script4\n",
"AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
]

PHONEME = [
"a AH",
"a(2) EY",
"aachen AA K AH N",
"aad AE D",
"aaden EY D AH N",
"aadmi AE D M IY",
"aae EY EY",
]


class TestTedlium(TempDirMixin, TorchaudioTestCase):
backend = "default"

root_dir = None
samples = {}

@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium")
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz
seed = 0

for release in ["release1", "release2", "release3"]:
data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)
if release in ["release1", "release2"]:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["subset"],
)
else:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["data_path"],
)
os.makedirs(release_dir, exist_ok=True)
os.makedirs(os.path.join(release_dir, "stm"), exist_ok=True) # Subfolder for transcripts
os.makedirs(os.path.join(release_dir, "sph"), exist_ok=True) # Subfolder for audio files
filename = f"{release}.sph"
path = os.path.join(os.path.join(release_dir, "sph"), filename)
save_wav(path, data, sample_rate)

trans_filename = f"{release}.stm"
trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename)
with open(trans_path, "w") as f:
f.write("".join(UTTERANCES))

dict_filename = f"{release}.dic"
dict_path = os.path.join(release_dir, dict_filename)
with open(dict_path, "w") as f:
f.write("\n".join(PHONEME))

# Create a samples list to compare with
cls.samples[release] = []
for utterance in UTTERANCES:
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
start_time = int(float(start_time)) * sample_rate
end_time = int(float(end_time)) * sample_rate
sample = (
data[:, start_time:end_time],
sample_rate,
transcript,
talk_id,
speaker_id,
identifier,
)
cls.samples[release].append(sample)
seed += 1

def test_tedlium_release1(self):
release = "release1"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1

assert num_samples == len(self.samples[release])

dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME

def test_tedlium_release2(self):
release = "release2"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1

assert num_samples == len(self.samples[release])

dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME

def test_tedlium_release3(self):
release = "release3"
dataset = tedlium.TEDLIUM(self.root_dir, release=release)
num_samples = 0
for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[release][i][1]
assert transcript == self.samples[release][i][2]
assert talk_id == self.samples[release][i][3]
assert speaker_id == self.samples[release][i][4]
assert identifier == self.samples[release][i][5]
num_samples += 1

assert num_samples == len(self.samples[release])

dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME

4 changes: 3 additions & 1 deletion torchaudio/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC
from .libritts import LIBRITTS
from .tedlium import TEDLIUM

__all__ = (
"COMMONVOICE",
Expand All @@ -18,7 +19,8 @@
"LJSPEECH",
"GTZAN",
"CMUARCTIC",
"LIBRITTS"
"LIBRITTS",
mthrok marked this conversation as resolved.
Show resolved Hide resolved
"diskcache_iterator",
"bg_iterator",
"TEDLIUM",
)
Loading