diff --git a/test/sox_io_backend/test_torchscript.py b/test/sox_io_backend/test_torchscript.py index fd7d6f8b64..522c1319ce 100644 --- a/test/sox_io_backend/test_torchscript.py +++ b/test/sox_io_backend/test_torchscript.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from torchaudio.backend import sox_io_backend +import torchaudio from parameterized import parameterized from ..common_utils import ( @@ -21,11 +21,11 @@ def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: - return sox_io_backend.info(filepath) + return torchaudio.info(filepath) def py_load_func(filepath: str, normalize: bool, channels_first: bool): - return sox_io_backend.load( + return torchaudio.load( filepath, normalize=normalize, channels_first=channels_first) @@ -36,13 +36,15 @@ def py_save_func( channels_first: bool = True, compression: Optional[float] = None, ): - sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression) + torchaudio.save(filepath, tensor, sample_rate, channels_first, compression) @skipIfNoExec('sox') @skipIfNoExtension class SoxIO(TempDirMixin, TorchaudioTestCase): """TorchScript-ability Test suite for `sox_io_backend`""" + backend = 'sox_io' + @parameterized.expand(list(itertools.product( ['float32', 'int32', 'int16', 'uint8'], [8000, 16000], diff --git a/test/test_backend.py b/test/test_backend.py index a13dd5088a..6b67cb2898 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,5 +1,3 @@ -import unittest - import torchaudio from . import common_utils @@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase) backend_module = torchaudio.backend.sox_backend +@common_utils.skipIfNoExtension +class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase): + backend = 'sox_io' + backend_module = torchaudio.backend.sox_io_backend + + @common_utils.skipIfNoModule('soundfile') class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): backend = 'soundfile' diff --git a/test/test_io.py b/test/test_io.py index 3c5f00006d..2c8aece85e 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -30,11 +30,15 @@ class Test_LoadSave(unittest.TestCase): def test_1_save(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_1_save(self.test_filepath, False) for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_1_save(self.test_filepath_wav, True) @@ -81,6 +85,8 @@ def _test_1_save(self, test_filepath, normalization): def test_1_save_sine(self): for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_1_save_sine() @@ -114,11 +120,15 @@ def _test_1_save_sine(self): def test_2_load(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_2_load(self.test_filepath, 278756) for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_2_load(self.test_filepath_wav, 276858) @@ -155,6 +165,8 @@ def _test_2_load(self, test_filepath, length): def test_2_load_nonormalization(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_2_load_nonormalization(self.test_filepath, 278756) @@ -172,6 +184,8 @@ def _test_2_load_nonormalization(self, test_filepath, length): def test_3_load_and_save_is_identity(self): for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_3_load_and_save_is_identity() @@ -210,6 +224,8 @@ def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2): def test_4_load_partial(self): for backend in BACKENDS_MP3: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_4_load_partial() @@ -252,6 +268,8 @@ def _test_4_load_partial(self): def test_5_get_info(self): for backend in BACKENDS: + if backend == 'sox_io': + continue with self.subTest(): torchaudio.set_audio_backend(backend) self._test_5_get_info() diff --git a/torchaudio/backend/utils.py b/torchaudio/backend/utils.py index d537f01daf..cb53b3e02f 100644 --- a/torchaudio/backend/utils.py +++ b/torchaudio/backend/utils.py @@ -7,6 +7,7 @@ from . import ( no_backend, sox_backend, + sox_io_backend, soundfile_backend, ) @@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]: backends.append('soundfile') if is_module_available('torchaudio._torchaudio'): backends.append('sox') + backends.append('sox_io') return backends @@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None: module = no_backend elif backend == 'sox': module = sox_backend + elif backend == 'sox_io': + module = sox_io_backend elif backend == 'soundfile': module = soundfile_backend else: @@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]: return None if torchaudio.load == sox_backend.load: return 'sox' + if torchaudio.load == sox_io_backend.load: + return 'sox_io' if torchaudio.load == soundfile_backend.load: return 'soundfile' raise ValueError('Unknown backend.')