From 3cdcd7ba97cffc5c947242f4320d58e089896030 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 23 Jul 2020 12:40:53 -0400 Subject: [PATCH] Refactor datasets test (#817) --- test/datasets/__init__.py | 0 test/datasets/datasets_test.py | 65 +++++++++++++ test/datasets/libritts_test.py | 71 ++++++++++++++ test/datasets/yesno_test.py | 48 +++++++++ test/test_datasets.py | 173 --------------------------------- 5 files changed, 184 insertions(+), 173 deletions(-) create mode 100644 test/datasets/__init__.py create mode 100644 test/datasets/datasets_test.py create mode 100644 test/datasets/libritts_test.py create mode 100644 test/datasets/yesno_test.py delete mode 100644 test/test_datasets.py diff --git a/test/datasets/__init__.py b/test/datasets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/datasets/datasets_test.py b/test/datasets/datasets_test.py new file mode 100644 index 0000000000..5026b270a4 --- /dev/null +++ b/test/datasets/datasets_test.py @@ -0,0 +1,65 @@ +from torchaudio.datasets.commonvoice import COMMONVOICE +from torchaudio.datasets.librispeech import LIBRISPEECH +from torchaudio.datasets.speechcommands import SPEECHCOMMANDS +from torchaudio.datasets.utils import diskcache_iterator, bg_iterator +from torchaudio.datasets.vctk import VCTK +from torchaudio.datasets.ljspeech import LJSPEECH +from torchaudio.datasets.gtzan import GTZAN +from torchaudio.datasets.cmuarctic import CMUARCTIC + +from ..common_utils import ( + TorchaudioTestCase, + get_asset_path, +) + + +class TestDatasets(TorchaudioTestCase): + backend = 'default' + path = get_asset_path() + + def test_vctk(self): + data = VCTK(self.path) + data[0] + + def test_librispeech(self): + data = LIBRISPEECH(self.path, "dev-clean") + data[0] + + def test_ljspeech(self): + data = LJSPEECH(self.path) + data[0] + + def test_speechcommands(self): + data = SPEECHCOMMANDS(self.path) + data[0] + + def test_gtzan(self): + data = GTZAN(self.path) + data[0] + + def test_cmuarctic(self): + data = CMUARCTIC(self.path) + data[0] + + +class TestCommonVoice(TorchaudioTestCase): + backend = 'default' + path = get_asset_path() + + def test_commonvoice(self): + data = COMMONVOICE(self.path, url="tatar") + data[0] + + def test_commonvoice_diskcache(self): + data = COMMONVOICE(self.path, url="tatar") + data = diskcache_iterator(data) + # Save + data[0] + # Load + data[0] + + def test_commonvoice_bg(self): + data = COMMONVOICE(self.path, url="tatar") + data = bg_iterator(data, 5) + for _ in data: + pass diff --git a/test/datasets/libritts_test.py b/test/datasets/libritts_test.py new file mode 100644 index 0000000000..b84cfdac30 --- /dev/null +++ b/test/datasets/libritts_test.py @@ -0,0 +1,71 @@ +import os + +from torchaudio.datasets.libritts import LIBRITTS + +from ..common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + + +class TestLibriTTS(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + data = [] + utterance_ids = [ + [19, 198, '000000', '000000'], + [26, 495, '000004', '000000'], + ] + original_text = 'this is the original text.' + normalized_text = 'this is the normalized text.' + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + base_dir = os.path.join(cls.root_dir, 'LibriTTS', 'train-clean-100') + for i, utterance_id in enumerate(cls.utterance_ids): + filename = f'{"_".join(str(u) for u in utterance_id)}.wav' + file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1])) + os.makedirs(file_dir, exist_ok=True) + path = os.path.join(file_dir, filename) + + data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16', seed=i) + save_wav(path, data, 8000) + cls.data.append(normalize_wav(data)) + + original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt' + path_original = os.path.join(file_dir, original_text_filename) + with open(path_original, 'w') as file_: + file_.write(cls.original_text) + + normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt' + path_normalized = os.path.join(file_dir, normalized_text_filename) + with open(path_normalized, 'w') as file_: + file_.write(cls.normalized_text) + + def test_libritts(self): + dataset = LIBRITTS(self.root_dir) + samples = list(dataset) + samples.sort(key=lambda s: s[4]) + + for i, (waveform, + sample_rate, + original_text, + normalized_text, + speaker_id, + chapter_id, + utterance_id) in enumerate(samples): + + expected_ids = self.utterance_ids[i] + expected_data = self.data[i] + self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) + assert sample_rate == 8000 + assert speaker_id == expected_ids[0] + assert chapter_id == expected_ids[1] + assert original_text == self.original_text + assert normalized_text == self.normalized_text + assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}' diff --git a/test/datasets/yesno_test.py b/test/datasets/yesno_test.py new file mode 100644 index 0000000000..42b6112e7f --- /dev/null +++ b/test/datasets/yesno_test.py @@ -0,0 +1,48 @@ +import os + +from torchaudio.datasets import yesno + +from ..common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + + +class TestYesNo(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + data = [] + labels = [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 1, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1], + ] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + base_dir = os.path.join(cls.root_dir, 'waves_yesno') + os.makedirs(base_dir, exist_ok=True) + for label in cls.labels: + filename = f'{"_".join(str(l) for l in label)}.wav' + path = os.path.join(base_dir, filename) + data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16') + save_wav(path, data, 8000) + cls.data.append(normalize_wav(data)) + + def test_yesno(self): + dataset = yesno.YESNO(self.root_dir) + samples = list(dataset) + samples.sort(key=lambda s: s[2]) + for i, (waveform, sample_rate, label) in enumerate(samples): + expected_label = self.labels[i] + expected_data = self.data[i] + self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) + assert sample_rate == 8000 + assert label == expected_label diff --git a/test/test_datasets.py b/test/test_datasets.py deleted file mode 100644 index c77aa4917e..0000000000 --- a/test/test_datasets.py +++ /dev/null @@ -1,173 +0,0 @@ -import os -import unittest - -from torchaudio.datasets.commonvoice import COMMONVOICE -from torchaudio.datasets.librispeech import LIBRISPEECH -from torchaudio.datasets.speechcommands import SPEECHCOMMANDS -from torchaudio.datasets.utils import diskcache_iterator, bg_iterator -from torchaudio.datasets.vctk import VCTK -from torchaudio.datasets.yesno import YESNO -from torchaudio.datasets.ljspeech import LJSPEECH -from torchaudio.datasets.gtzan import GTZAN -from torchaudio.datasets.cmuarctic import CMUARCTIC -from torchaudio.datasets.libritts import LIBRITTS - -from .common_utils import ( - TempDirMixin, - TorchaudioTestCase, - get_asset_path, - get_whitenoise, - save_wav, - normalize_wav, -) - - -class TestDatasets(TorchaudioTestCase): - backend = 'default' - path = get_asset_path() - - def test_vctk(self): - data = VCTK(self.path) - data[0] - - def test_librispeech(self): - data = LIBRISPEECH(self.path, "dev-clean") - data[0] - - def test_ljspeech(self): - data = LJSPEECH(self.path) - data[0] - - def test_speechcommands(self): - data = SPEECHCOMMANDS(self.path) - data[0] - - def test_gtzan(self): - data = GTZAN(self.path) - data[0] - - def test_cmuarctic(self): - data = CMUARCTIC(self.path) - data[0] - - -class TestCommonVoice(TorchaudioTestCase): - backend = 'default' - path = get_asset_path() - - def test_commonvoice(self): - data = COMMONVOICE(self.path, url="tatar") - data[0] - - def test_commonvoice_diskcache(self): - data = COMMONVOICE(self.path, url="tatar") - data = diskcache_iterator(data) - # Save - data[0] - # Load - data[0] - - def test_commonvoice_bg(self): - data = COMMONVOICE(self.path, url="tatar") - data = bg_iterator(data, 5) - for _ in data: - pass - - -class TestYesNo(TempDirMixin, TorchaudioTestCase): - backend = 'default' - - root_dir = None - data = [] - labels = [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 1], - [0, 1, 0, 1, 0, 1, 1, 0], - [1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1], - ] - - @classmethod - def setUpClass(cls): - cls.root_dir = cls.get_base_temp_dir() - base_dir = os.path.join(cls.root_dir, 'waves_yesno') - os.makedirs(base_dir, exist_ok=True) - for label in cls.labels: - filename = f'{"_".join(str(l) for l in label)}.wav' - path = os.path.join(base_dir, filename) - data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16') - save_wav(path, data, 8000) - cls.data.append(normalize_wav(data)) - - def test_yesno(self): - dataset = YESNO(self.root_dir) - samples = list(dataset) - samples.sort(key=lambda s: s[2]) - for i, (waveform, sample_rate, label) in enumerate(samples): - expected_label = self.labels[i] - expected_data = self.data[i] - self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) - assert sample_rate == 8000 - assert label == expected_label - - -class TestLibriTTS(TempDirMixin, TorchaudioTestCase): - backend = 'default' - - root_dir = None - data = [] - utterance_ids = [ - [19, 198, '000000', '000000'], - [26, 495, '000004', '000000'], - ] - original_text = 'this is the original text.' - normalized_text = 'this is the normalized text.' - - @classmethod - def setUpClass(cls): - cls.root_dir = cls.get_base_temp_dir() - base_dir = os.path.join(cls.root_dir, 'LibriTTS', 'train-clean-100') - for i, utterance_id in enumerate(cls.utterance_ids): - filename = f'{"_".join(str(u) for u in utterance_id)}.wav' - file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1])) - os.makedirs(file_dir, exist_ok=True) - path = os.path.join(file_dir, filename) - - data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16', seed=i) - save_wav(path, data, 8000) - cls.data.append(normalize_wav(data)) - - original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt' - path_original = os.path.join(file_dir, original_text_filename) - f = open(path_original, 'w') - f.write(cls.original_text) - f.close() - - normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt' - path_normalized = os.path.join(file_dir, normalized_text_filename) - f = open(path_normalized, 'w') - f.write(cls.normalized_text) - f.close() - - def test_libritts(self): - dataset = LIBRITTS(self.root_dir) - samples = list(dataset) - samples.sort(key=lambda s: s[4]) - - for i, (waveform, - sample_rate, - original_text, - normalized_text, - speaker_id, - chapter_id, - utterance_id) in enumerate(samples): - - expected_ids = self.utterance_ids[i] - expected_data = self.data[i] - self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) - assert sample_rate == 8000 - assert speaker_id == expected_ids[0] - assert chapter_id == expected_ids[1] - assert original_text == self.original_text - assert normalized_text == self.normalized_text - assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}'