Skip to content

Commit

Permalink
Set individual backends
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 17, 2020
1 parent e7aa423 commit 2d4eb39
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ The following test modules are defined for corresponding `torchaudio` module/fun
The following is the current practice of torchaudio test suite.

1. Unless the tests are related to I/O, use synthetic data. [`common_utils`](./common_utils.py) has some data generator functions.
1. When you add a new test case, use `common_utils.TorchaudioTestCase` as base class unless your are writing tests that are common to CPU / CUDA.
1. When you add a new test case, use `common_utils.TorchaudioTestCase` as base class unless you are writing tests that are common to CPU / CUDA.
- Set class memeber `dtype`, `device` and `backend` for the desired behavior.
- If you do not set `backend` value in your test suite, then I/O functions will be unassigned and attempt to load/save file will fail.
- For `backend` value, in addition to available backends, you can also provide the value "default" and backend will be picked automatically based on availability.
Expand Down
4 changes: 4 additions & 0 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module"""
def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
Expand Down Expand Up @@ -97,12 +98,15 @@ def test_sliding_window_cmn(self):
self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=False)

def test_vad(self):
common_utils.set_audio_backend('default')
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)


class TestTransforms(common_utils.TorchaudioTestCase):
backend = 'default'

"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
Expand Down
2 changes: 1 addition & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_cmuarctic(self):


@common_utils.skipIfNoSoxBackend
class TestCommonVoise(common_utils.TorchaudioTestCase):
class TestCommonVoice(common_utils.TorchaudioTestCase):
backend = 'sox'
path = common_utils.get_asset_path()

Expand Down
4 changes: 4 additions & 0 deletions test/test_librosa_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def _load_audio_asset(*asset_paths, **kwargs):
class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)

Expand Down Expand Up @@ -268,13 +269,15 @@ def test_basics4(self):
}
self.assert_compatibilities(**kwargs)

@unittest.skipIf(not common_utils.BACKENDS_MP3, 'no backend to read mp3')
def test_MelScale(self):
"""MelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
hop_length = n_fft // 4

# Prepare spectrogram input. We use torchaudio to compute one.
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset('whitenoise_1min.mp3')
sound = sound.mean(dim=0, keepdim=True)
spec_ta = F.spectrogram(
Expand All @@ -297,6 +300,7 @@ def test_InverseMelScale(self):
hop_length = n_fft // 4

# Prepare mel spectrogram input. We use torchaudio to compute one.
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset(
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
sound = sound.mean(dim=0, keepdim=True)
Expand Down

0 comments on commit 2d4eb39

Please sign in to comment.