Skip to content

Commit

Permalink
Make TestCases backend-aware (#719)
Browse files Browse the repository at this point in the history
* Make tests backend aware by introducing TorchaudioTestCase and reset backend for each TestCase.

* Set backends for the test cases that require specific backend.
  • Loading branch information
mthrok authored Jun 18, 2020
1 parent 03da871 commit b17da7a
Show file tree
Hide file tree
Showing 21 changed files with 160 additions and 217 deletions.
10 changes: 10 additions & 0 deletions test/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ The following test modules are defined for corresponding `torchaudio` module/fun

## Adding test

The following is the current practice of torchaudio test suite.

1. Unless the tests are related to I/O, use synthetic data. [`common_utils`](./common_utils.py) has some data generator functions.
1. When you add a new test case, use `common_utils.TorchaudioTestCase` as base class unless you are writing tests that are common to CPU / CUDA.
- Set class memeber `dtype`, `device` and `backend` for the desired behavior.
- If you do not set `backend` value in your test suite, then I/O functions will be unassigned and attempt to load/save file will fail.
- For `backend` value, in addition to available backends, you can also provide the value "default" and backend will be picked automatically based on availability.
1. If you are writing tests that should pass on diffrent dtype/devices, write a common class inheriting `common_utils.TestBaseMixin`, then inherit `common_utils.PytorchTestCase` and define class attributes (`dtype` / `device` / `backend`) there. See [Torchscript consistency test implementation](./torchscript_consistency_impl.py) and test definitions for [CPU](./torchscript_consistency_cpu_test.py) and [CUDA](./torchscript_consistency_cuda_test.py) devices.
1. For numerically comparing Tensors, use `assertEqual` method from `common_utils.PytorchTestCase` class. This method has a better support for a wide variety of Tensor types.

When you add a new feature(functional/transform), consider the following

1. When you add a new feature, please make it Torchscript-able and batch-consistent unless it degrades the performance. Please add the tests to see if the new feature meet these requirements.
Expand Down
52 changes: 29 additions & 23 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import tempfile
import unittest
from typing import Iterable, Union
from contextlib import contextmanager
from typing import Union
from shutil import copytree

import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -55,24 +54,14 @@ def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
return torch.tensor(arr).float().view(size) / m


@contextmanager
def AudioBackendScope(new_backend):
previous_backend = torchaudio.get_audio_backend()
try:
torchaudio.set_audio_backend(new_backend)
yield
finally:
torchaudio.set_audio_backend(previous_backend)


def filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')

def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
with AudioBackendScope(backend):
torchaudio.load(test_filepath)
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False
Expand All @@ -83,21 +72,38 @@ def supports_mp3(backend):
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)


def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend

torchaudio.set_audio_backend(be)


class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
device = None
backend = None

def setUp(self):
super().setUp()
set_audio_backend(self.backend)

skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')

class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass

def common_test_class_parameters(
dtypes: Iterable[str] = ("float32", "float64"),
devices: Iterable[str] = ("cpu", "cuda"),
):
for device in devices:
for dtype in dtypes:
yield {"device": torch.device(device), "dtype": getattr(torch, dtype)}

skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')


def get_whitenoise(
Expand Down
14 changes: 8 additions & 6 deletions test/functional_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
from .functional_impl import Lfilter


class TestLFilterFloat32(Lfilter, common_utils.TestCase):
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestLFilterFloat64(Lfilter, common_utils.TestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestComputeDeltas(unittest.TestCase):
class TestComputeDeltas(common_utils.TorchaudioTestCase):
"""Test suite for correctness of compute_deltas"""
def test_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
Expand Down Expand Up @@ -57,7 +57,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
_compare_estimate(sound, estimate)


class TestIstft(unittest.TestCase):
class TestIstft(common_utils.TorchaudioTestCase):
"""Test suite for correctness of istft with various input"""
number_of_trials = 100

Expand Down Expand Up @@ -273,7 +273,9 @@ def test_linearity_of_istft4(self):
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)


class TestDetectPitchFrequency(unittest.TestCase):
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
backend = 'default'

def test_pitch(self):
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
Expand All @@ -294,7 +296,7 @@ def test_pitch(self):
self.assertFalse(s)


class TestDB_to_amplitude(unittest.TestCase):
class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
Expand Down
4 changes: 2 additions & 2 deletions test/functional_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


@common_utils.skipIfNoCuda
class TestLFilterFloat32(Lfilter, common_utils.TestCase):
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestLFilterFloat64(Lfilter, common_utils.TestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
4 changes: 2 additions & 2 deletions test/kaldi_compatibility_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from .kaldi_compatibility_impl import Kaldi


class TestKaldiFloat32(Kaldi, common_utils.TestCase):
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestKaldiFloat64(Kaldi, common_utils.TestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
4 changes: 2 additions & 2 deletions test/kaldi_compatibility_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@


@common_utils.skipIfNoCuda
class TestKaldiFloat32(Kaldi, common_utils.TestCase):
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestKaldiFloat64(Kaldi, common_utils.TestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
2 changes: 2 additions & 0 deletions test/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def _load_params(path):


class Kaldi(common_utils.TestBaseMixin):
backend = 'sox'

def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
Expand Down
10 changes: 6 additions & 4 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import torchaudio
from torchaudio._internal.module_utils import is_module_available

from . import common_utils

class BackendSwitch:

class BackendSwitchMixin:
"""Test set/get_audio_backend works"""
backend = None
backend_module = None
Expand All @@ -21,20 +23,20 @@ def test_switch(self):
assert torchaudio.info == self.backend_module.info


class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = None
backend_module = torchaudio.backend.no_backend


@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox'
backend_module = torchaudio.backend.sox_backend


@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend
9 changes: 6 additions & 3 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import unittest

import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F

from . import common_utils


class TestFunctional(TestCase):
class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module"""
def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
Expand Down Expand Up @@ -98,12 +98,15 @@ def test_sliding_window_cmn(self):
self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=False)

def test_vad(self):
common_utils.set_audio_backend('default')
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)


class TestTransforms(TestCase):
class TestTransforms(common_utils.TorchaudioTestCase):
backend = 'default'

"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
Expand Down
13 changes: 7 additions & 6 deletions test/test_compliance_kaldi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import math
import os
import math
import unittest

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest

from . import common_utils
from .compliance import utils as compliance_utils
from .common_utils import AudioBackendScope, BACKENDS


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
Expand Down Expand Up @@ -46,7 +46,10 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
window[f, s] = wave[s_in_wave]


class Test_Kaldi(unittest.TestCase):
@common_utils.skipIfNoSoxBackend
class Test_Kaldi(common_utils.TorchaudioTestCase):
backend = 'sox'

test_filepath = common_utils.get_asset_path('kaldi_file.wav')
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
kaldi_output_dir = common_utils.get_asset_path('kaldi')
Expand Down Expand Up @@ -162,8 +165,6 @@ def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_resample_waveform(self):
def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound, args[1], args[2])
Expand Down
8 changes: 4 additions & 4 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch.utils.data import Dataset, DataLoader

from . import common_utils
from .common_utils import AudioBackendScope, BACKENDS


class TORCHAUDIODS(Dataset):
Expand All @@ -28,9 +27,10 @@ def __len__(self):
return len(self.data)


class Test_DataLoader(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
class Test_DataLoader(common_utils.TorchaudioTestCase):
backend = 'sox'

@common_utils.skipIfNoSoxBackend
def test_1(self):
expected_size = (2, 1, 16000)
ds = TORCHAUDIODS()
Expand Down
Loading

0 comments on commit b17da7a

Please sign in to comment.