diff --git a/test/torchaudio_unittest/sox_io_backend/load_test.py b/test/torchaudio_unittest/sox_io_backend/load_test.py
index 59bcaea386..5e3e0100e6 100644
--- a/test/torchaudio_unittest/sox_io_backend/load_test.py
+++ b/test/torchaudio_unittest/sox_io_backend/load_test.py
@@ -291,3 +291,47 @@ def test_channels_first(self, channels_first):
         found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
         expected = self.original if channels_first else self.original.transpose(1, 0)
         self.assertEqual(found, expected)
+
+
+@skipIfNoExec('sox')
+@skipIfNoExtension
+class TestSampleRate(TempDirMixin, PytorchTestCase):
+    """Test the correctness of `sample_rate` parameter of `sox_io_backend.load`"""
+    path = None
+
+    def setUp(self):
+        super().setUp()
+        sample_rate = 16000
+        original = get_wav_data('float32', num_channels=2)
+        self.path = self.get_temp_path('original.wave')
+        save_wav(self.path, original, sample_rate)
+
+    @parameterized.expand([(8000, ), (44100, )], name_func=name_func)
+    def test_sample_rate(self, sample_rate):
+        """sample_rate changes sample rate"""
+        found, rate = sox_io_backend.load(self.path, sample_rate=sample_rate)
+        ref_path = self.get_temp_path('reference.wav')
+        sox_utils.run_sox_effect(self.path, ref_path, ['rate', f'{sample_rate}'])
+        expected, expected_rate = load_wav(ref_path)
+
+        self.assertEqual(rate, expected_rate)
+        self.assertEqual(found, expected)
+
+    @parameterized.expand(list(itertools.product(
+        [8000, 44100],
+        [0, 1, 10, 100, 1000],
+        [-1, 1, 10, 100, 1000],
+    )), name_func=name_func)
+    def test_frame(self, sample_rate, frame_offset, num_frames):
+        """frame_offset and num_frames applied after sample_rate"""
+        found, rate = sox_io_backend.load(
+            self.path, frame_offset=frame_offset, num_frames=num_frames, sample_rate=sample_rate)
+
+        ref_path = self.get_temp_path('reference.wav')
+        sox_utils.run_sox_effect(self.path, ref_path, ['rate', f'{sample_rate}'])
+        reference, expected_rate = load_wav(ref_path)
+        frame_end = None if num_frames == -1 else frame_offset + num_frames
+        expected = reference[:, frame_offset:frame_end]
+
+        self.assertEqual(rate, expected_rate)
+        self.assertEqual(found, expected)
diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py
index 31e69c443e..ee17ba9283 100644
--- a/torchaudio/backend/sox_io_backend.py
+++ b/torchaudio/backend/sox_io_backend.py
@@ -40,6 +40,7 @@ def load(
         num_frames: int = -1,
         normalize: bool = True,
         channels_first: bool = True,
+        sample_rate: Optional[int] = None,
 ) -> Tuple[torch.Tensor, int]:
     """Load audio data from file.
 
@@ -84,11 +85,13 @@ def load(
             Path to audio file
         frame_offset (int):
             Number of frames to skip before start reading data.
+            If ``sample_rate`` is given, frame counts start after the audio is resampled.
         num_frames (int):
             Maximum number of frames to read. ``-1`` reads all the remaining samples,
             starting from ``frame_offset``.
             This function may return the less number of frames if there is not enough
             frames in the given file.
+            If ``sample_rate`` is given, frame counts start after the audio is resampled.
         normalize (bool):
             When ``True``, this function always return ``float32``, and sample values are
             normalized to ``[-1.0, 1.0]``.
@@ -98,6 +101,8 @@ def load(
         channels_first (bool):
             When True, the returned Tensor has dimension ``[channel, time]``.
             Otherwise, the returned Tensor's dimension is ``[time, channel]``.
+        sample_rate (int, optional):
+            If given, resample the audio to this rate. ``None`` keeps the file's rate.
 
     Returns:
         torch.Tensor:
@@ -105,8 +110,9 @@ def load(
             integer type, else ``float32`` type. If ``channels_first=True``, it has
             ``[channel, time]`` else ``[time, channel]``.
     """
-    signal = torch.ops.torchaudio.sox_io_load_audio_file(
-        filepath, frame_offset, num_frames, normalize, channels_first)
+    sample_rate = -1 if sample_rate is None else sample_rate
+    signal = torch.ops.torchaudio.sox_io_load_audio_file_v1(
+        filepath, frame_offset, num_frames, normalize, channels_first, sample_rate)
     return signal.get_tensor(), signal.get_sample_rate()
 
 
diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp
index a1ec8f2254..63c9ffd1d4 100644
--- a/torchaudio/csrc/register.cpp
+++ b/torchaudio/csrc/register.cpp
@@ -51,6 +51,9 @@ TORCH_LIBRARY(torchaudio, m) {
   m.def(
       "torchaudio::sox_io_load_audio_file",
       &torchaudio::sox_io::load_audio_file);
+  m.def(
+      "torchaudio::sox_io_load_audio_file_v1",
+      &torchaudio::sox_io::load_audio_file_v1);
   m.def(
       "torchaudio::sox_io_save_audio_file",
       &torchaudio::sox_io::save_audio_file);
diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp
index 4092d215d0..c7e4b8ad3c 100644
--- a/torchaudio/csrc/sox_io.cpp
+++ b/torchaudio/csrc/sox_io.cpp
@@ -53,6 +53,17 @@ c10::intrusive_ptr load_audio_file(
     const int64_t num_frames,
     const bool normalize,
     const bool channels_first) {
+  return load_audio_file_v1(
+      path, frame_offset, num_frames, normalize, channels_first, -1);
+}
+
+c10::intrusive_ptr load_audio_file_v1(
+    const std::string& path,
+    const int64_t frame_offset,
+    const int64_t num_frames,
+    const bool normalize,
+    const bool channels_first,
+    const int64_t sample_rate) {
   if (frame_offset < 0) {
     throw std::runtime_error(
         "Invalid argument: frame_offset must be non-negative.");
@@ -61,8 +72,16 @@ c10::intrusive_ptr load_audio_file(
     throw std::runtime_error(
         "Invalid argument: num_frames must be -1 or greater than 0.");
   }
+  if (sample_rate == 0 || sample_rate < -1) {
+    throw std::runtime_error(
+        "Invalid argument: sample_rate must be -1 or greater than 0.");
+  }
 
   std::vector> effects;
+  if (sample_rate != -1) {
+    effects.emplace_back(
+        std::vector{"rate", std::to_string(sample_rate)});
+  }
   if (num_frames != -1) {
     std::ostringstream offset, frames;
     offset << frame_offset << "s";
diff --git a/torchaudio/csrc/sox_io.h b/torchaudio/csrc/sox_io.h
index 5288e911e8..a92bbc43d4 100644
--- a/torchaudio/csrc/sox_io.h
+++ b/torchaudio/csrc/sox_io.h
@@ -23,6 +23,7 @@ struct SignalInfo : torch::CustomClassHolder {
 
 c10::intrusive_ptr get_info(const std::string& path);
 
+// ver. 0
 c10::intrusive_ptr load_audio_file(
     const std::string& path,
     const int64_t frame_offset = 0,
@@ -30,6 +31,15 @@ c10::intrusive_ptr load_audio_file(
     const bool normalize = true,
     const bool channels_first = true);
 
+// ver. 1 sample_rate is added
+c10::intrusive_ptr load_audio_file_v1(
+    const std::string& path,
+    const int64_t frame_offset = 0,
+    const int64_t num_frames = -1,
+    const bool normalize = true,
+    const bool channels_first = true,
+    const int64_t sample_rate = -1);
+
 void save_audio_file(
     const std::string& file_name,
     const c10::intrusive_ptr& signal,