From 97db145b101edee591dea16e1c1b53403795315a Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 24 Dec 2020 04:21:58 +0000 Subject: [PATCH] Support file-like object in save func --- .../soundfile_backend/save_test.py | 42 +++++++- .../sox_io_backend/save_test.py | 86 +++++++++++++++++ torchaudio/backend/_soundfile_backend.py | 20 ++-- torchaudio/backend/common.py | 17 ++++ torchaudio/backend/sox_io_backend.py | 48 ++++++---- torchaudio/csrc/pybind.cpp | 4 + torchaudio/csrc/sox/effects.cpp | 8 +- torchaudio/csrc/sox/effects_chain.cpp | 77 +++++++++++++++ torchaudio/csrc/sox/effects_chain.h | 6 ++ torchaudio/csrc/sox/io.cpp | 95 ++++++++++++++++--- torchaudio/csrc/sox/io.h | 17 +++- torchaudio/csrc/sox/utils.cpp | 46 ++++----- torchaudio/csrc/sox/utils.h | 8 +- 13 files changed, 406 insertions(+), 68 deletions(-) diff --git a/test/torchaudio_unittest/soundfile_backend/save_test.py b/test/torchaudio_unittest/soundfile_backend/save_test.py index a99c69b0c86..5c36e0d126f 100644 --- a/test/torchaudio_unittest/soundfile_backend/save_test.py +++ b/test/torchaudio_unittest/soundfile_backend/save_test.py @@ -1,9 +1,9 @@ +import io import itertools from unittest.mock import patch from torchaudio._internal import module_utils as _mod_utils from torchaudio.backend import _soundfile_backend as soundfile_backend -from parameterized import parameterized from torchaudio_unittest.common_utils import ( TempDirMixin, @@ -209,3 +209,43 @@ def test_channels_first(self, channels_first): found = load_wav(path)[0] expected = data if channels_first else data.transpose(1, 0) self.assertEqual(found, expected, atol=1e-4, rtol=1e-8) + + +@skipIfNoModule("soundfile") +class TestFileObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext): + """Saving audio to file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + subtype = 'FLOAT' if ext == 'wav' else None + data = get_wav_data('float32', num_channels=2) + 
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype) + expected = soundfile.read(path, dtype='float32')[0] + + fileobj = io.BytesIO() + soundfile_backend.save(fileobj, data, sample_rate, format=ext) + fileobj.seek(0) + found, sr = soundfile.read(fileobj, dtype='float32') + + assert sr == sample_rate + self.assertEqual(expected, found) + + def test_fileobj_wav(self): + """Saving audio via file-like object works""" + self._test_fileobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Saving audio via file-like object works""" + self._test_fileobj('flac') + + @skipIfFormatNotSupported("NIST") + def test_fileobj_nist(self): + """Saving audio via file-like object works""" + self._test_fileobj('NIST') + + @skipIfFormatNotSupported("OGG") + def test_fileobj_ogg(self): + """Saving audio via file-like object works""" + self._test_fileobj('OGG') diff --git a/test/torchaudio_unittest/sox_io_backend/save_test.py b/test/torchaudio_unittest/sox_io_backend/save_test.py index b0ee25e01c1..da5ea4a1538 100644 --- a/test/torchaudio_unittest/sox_io_backend/save_test.py +++ b/test/torchaudio_unittest/sox_io_backend/save_test.py @@ -1,3 +1,4 @@ +import io import itertools from torchaudio.backend import sox_io_backend @@ -417,3 +418,88 @@ def test_tensor_preserve(self, dtype): sox_io_backend.save(path, data, 8000) self.assertEqual(data, expected) + + +@skipIfNoExtension +@skipIfNoExec('sox') +class TestFileLikeObject(SaveTestBase): + """ + We compare the result of file-like object input against file path input because + `save` function is rigorously tested for file path inputs to match libsox's result. + """ + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_fileobj(self, ext, compression): + """Saving audio to file object returns the same result as via file path.""" + sample_rate = 16000 + dtype = 
'float32' + num_channels = 2 + num_frames = 16000 + channels_first = True + + data = get_wav_data(dtype, num_channels, num_frames=num_frames) + + ref_path = self.get_temp_path(f'reference.{ext}') + res_path = self.get_temp_path(f'test.{ext}') + sox_io_backend.save( + ref_path, data, channels_first=channels_first, + sample_rate=sample_rate, compression=compression) + with open(res_path, 'wb') as fileobj: + sox_io_backend.save( + fileobj, data, channels_first=channels_first, + sample_rate=sample_rate, compression=compression) + + expected_data, _ = sox_io_backend.load(ref_path) + data, sr = sox_io_backend.load(res_path) + + assert sample_rate == sr + self.assertEqual(expected_data, data) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytesio(self, ext, compression): + """Saving audio to BytesIO object returns the same result as via file path.""" + sample_rate = 16000 + dtype = 'float32' + num_channels = 2 + num_frames = 16000 + channels_first = True + + data = get_wav_data(dtype, num_channels, num_frames=num_frames) + + ref_path = self.get_temp_path(f'reference.{ext}') + res_path = self.get_temp_path(f'test.{ext}') + sox_io_backend.save( + ref_path, data, channels_first=channels_first, + sample_rate=sample_rate, compression=compression) + fileobj = io.BytesIO() + sox_io_backend.save( + fileobj, data, channels_first=channels_first, + sample_rate=sample_rate, compression=compression, format=ext) + fileobj.seek(0) + with open(res_path, 'wb') as file_: + file_.write(fileobj.read()) + + expected_data, _ = sox_io_backend.load(ref_path) + data, sr = sox_io_backend.load(res_path) + + assert sample_rate == sr + self.assertEqual(expected_data, data) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index 719224b8276..a1d16c3537e 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ 
b/torchaudio/backend/_soundfile_backend.py @@ -1,10 +1,11 @@ """The new soundfile backend which will become default in 0.8.0 onward""" +import os from typing import Tuple, Optional import warnings import torch from torchaudio._internal import module_utils as _mod_utils -from .common import AudioMetaData +from .common import AudioMetaData, get_ext if _mod_utils.is_module_available("soundfile"): @@ -138,6 +139,7 @@ def save( sample_rate: int, channels_first: bool = True, compression: Optional[float] = None, + format: Optional[str] = None, ): """Save audio data to file. @@ -168,6 +170,9 @@ def save( otherwise ``[time, channel]``. compression (Optional[float]): Not used. It is here only for interface compatibility reson with "sox_io" backend. + format (str, optional): + Output audio format. This is required when the output audio format cannot be inferred from + ``filepath`` (such as file extension or ``name`` attribute of the given file object). """ if src.ndim != 2: raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.") @@ -177,7 +182,11 @@ def save( "The argument is silently ignored." ) - ext = str(filepath).split(".")[-1].lower() + try: + ext = get_ext(filepath, format) + except Exception: + raise RuntimeError('Cannot detect the output format. 
Provide `format` argument.') from None + if ext != "wav": subtype = None elif src.dtype == torch.uint8: @@ -193,17 +202,16 @@ def save( else: raise ValueError(f"Unsupported dtype for WAV: {src.dtype}") - format_ = None # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format, # so we extend the extensions manually here - if ext in ["nis", "nist", "sph"]: - format_ = "NIST" + if ext in ["nis", "nist", "sph"] and format is None: + format = "NIST" if channels_first: src = src.t() soundfile.write( - file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format_ + file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format ) diff --git a/torchaudio/backend/common.py b/torchaudio/backend/common.py index 135a18caee8..d5a76078b71 100644 --- a/torchaudio/backend/common.py +++ b/torchaudio/backend/common.py @@ -1,3 +1,4 @@ +import os from typing import Any, Optional from torchaudio._internal import module_utils as _mod_utils @@ -19,6 +20,22 @@ def __init__(self, sample_rate: int, num_frames: int, num_channels: int): self.num_channels = num_channels +def get_ext( + src: Any, + format: Optional[str]): + """Get the file extension from either the given format or target file information + + Args: + src (path-like object or file-like object): Target file. + format (optional, str): format provided by user. + """ + if format is not None: + return format.lower() + if hasattr(src, 'name'): + src = src.name + return os.path.splitext(src)[-1][1:].lower() + + @_mod_utils.deprecated('Please migrate to `AudioMetaData`.', '0.9.0') class SignalInfo: """One of return types of ``torchaudio.info`` functions. 
diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 1e6d417cb82..696466cf0f0 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -7,7 +7,7 @@ ) import torchaudio -from .common import AudioMetaData +from .common import AudioMetaData, get_ext @_mod_utils.requires_module('torchaudio._torchaudio') @@ -134,6 +134,27 @@ def load( return signal.get_tensor(), signal.get_sample_rate() +@torch.jit.unused +def _save( + filepath: str, + src: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + format: Optional[str] = None, +): + try: + ext = get_ext(filepath, format) + except Exception: + raise RuntimeError('Cannot detect the output format. Provide `format` argument.') from None + if hasattr(filepath, 'write'): + torchaudio._torchaudio.save_audio_fileobj( + filepath, src, sample_rate, channels_first, compression, ext) + else: + torch.ops.torchaudio.sox_io_save_audio_file( + os.fspath(filepath), src, sample_rate, channels_first, compression, ext) + + @_mod_utils.requires_module('torchaudio._torchaudio') def save( filepath: str, @@ -141,6 +162,7 @@ def save( sample_rate: int, channels_first: bool = True, compression: Optional[float] = None, + format: Optional[str] = None, ): """Save audio data to file. @@ -184,23 +206,15 @@ def save( | and lowest quality. Default: ``3``. See the detail at http://sox.sourceforge.net/soxformat.html. + format (str, optional): + Output audio format. This is required when the output audio format cannot be inferred from + ``filepath`` (such as file extension or ``name`` attribute of the given file object). """ - # Cast to str in case type is `pathlib.Path` - filepath = str(filepath) - if compression is None: - ext = str(filepath).split('.')[-1].lower() - if ext in ['wav', 'sph', 'amb', 'amr-nb']: - compression = 0. - elif ext == 'mp3': - compression = -4.5 - elif ext == 'flac': - compression = 8. 
- elif ext in ['ogg', 'vorbis']: - compression = 3. - else: - raise RuntimeError(f'Unsupported file type: "{ext}"') - signal = torch.classes.torchaudio.TensorSignal(src, sample_rate, channels_first) - torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression) + if not torch.jit.is_scripting(): + _save(filepath, src, sample_rate, channels_first, compression, format) + return + torch.ops.torchaudio.sox_io_save_audio_file( + filepath, src, sample_rate, channels_first, compression, format) @_mod_utils.requires_module('torchaudio._torchaudio') diff --git a/torchaudio/csrc/pybind.cpp b/torchaudio/csrc/pybind.cpp index caf9ad9b198..eb8c30b96ac 100644 --- a/torchaudio/csrc/pybind.cpp +++ b/torchaudio/csrc/pybind.cpp @@ -100,4 +100,8 @@ PYBIND11_MODULE(_torchaudio, m) { "load_audio_fileobj", &torchaudio::sox_io::load_audio_fileobj, "Load audio from file object."); + m.def( + "save_audio_fileobj", + &torchaudio::sox_io::save_audio_fileobj, + "Save audio to file obj."); } diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp index 6e36638a330..03f06c2602f 100644 --- a/torchaudio/csrc/sox/effects.cpp +++ b/torchaudio/csrc/sox/effects.cpp @@ -59,8 +59,8 @@ c10::intrusive_ptr apply_effects_tensor( // Create SoxEffectsChain const auto dtype = in_tensor.dtype(); torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", dtype, 0.), - /*output_encoding=*/get_encodinginfo("wav", dtype, 0.)); + /*input_encoding=*/get_encodinginfo("wav", dtype), + /*output_encoding=*/get_encodinginfo("wav", dtype)); // Prepare output buffer std::vector out_buffer; @@ -112,7 +112,7 @@ c10::intrusive_ptr apply_effects_file( // Create and run SoxEffectsChain torchaudio::sox_effects_chain::SoxEffectsChain chain( /*input_encoding=*/sf->encoding, - /*output_encoding=*/get_encodinginfo("wav", dtype, 0.)); + /*output_encoding=*/get_encodinginfo("wav", dtype)); chain.addInputFile(sf); for (const auto& effect : effects) { @@ -193,7 
+193,7 @@ std::tuple apply_effects_fileobj( const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); torchaudio::sox_effects_chain::SoxEffectsChain chain( /*input_encoding=*/sf->encoding, - /*output_encoding=*/get_encodinginfo("wav", dtype, 0.)); + /*output_encoding=*/get_encodinginfo("wav", dtype)); chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj); for (const auto& effect : effects) { chain.addEffect(effect); diff --git a/torchaudio/csrc/sox/effects_chain.cpp b/torchaudio/csrc/sox/effects_chain.cpp index b3976734395..9c365b72aa4 100644 --- a/torchaudio/csrc/sox/effects_chain.cpp +++ b/torchaudio/csrc/sox/effects_chain.cpp @@ -295,6 +295,13 @@ struct FileObjInputPriv { uint64_t buffer_size; }; +struct FileObjOutputPriv { + sox_format_t* sf; + py::object* fileobj; + char** buffer; + size_t* buffer_size; +}; + /// Callback function to feed byte string /// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278 int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { @@ -373,6 +380,45 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { return *osamp? SOX_SUCCESS : SOX_EOF; } +int fileobj_output_flow( + sox_effect_t* effp, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) { + *osamp = 0; + if (*isamp) { + auto priv = static_cast(effp->priv); + auto sf = priv->sf; + auto fp = static_cast(sf->fp); + auto fileobj = priv->fileobj; + auto buffer = priv->buffer; + auto buffer_size = priv->buffer_size; + + // Encode chunk + auto num_samples_written = sox_write(sf, ibuf, *isamp); + fflush(fp); + + // Copy the encoded chunk to python object. 
+ fileobj->attr("write")(py::bytes(*buffer, *buffer_size)); + + // Reset FILE* + sf->tell_off = 0; + fseek(fp, 0, SEEK_SET); + + if (num_samples_written != *isamp) { + if (sf->sox_errno) { + std::ostringstream stream; + stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " " + << sf->filename; + throw std::runtime_error(stream.str()); + } + return SOX_EOF; + } + } + return SOX_SUCCESS; +} + sox_effect_handler_t* get_fileobj_input_handler() { static sox_effect_handler_t handler{/*name=*/"input_fileobj_object", /*usage=*/NULL, @@ -387,6 +433,20 @@ sox_effect_handler_t* get_fileobj_input_handler() { return &handler; } +sox_effect_handler_t* get_fileobj_output_handler() { + static sox_effect_handler_t handler{/*name=*/"output_fileobj_object", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/fileobj_output_flow, + /*drain=*/NULL, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(FileObjOutputPriv)}; + return &handler; +} + } // namespace void SoxEffectsChain::addInputFileObj( @@ -408,6 +468,23 @@ void SoxEffectsChain::addInputFileObj( } } +void SoxEffectsChain::addOutputFileObj( + sox_format_t* sf, + char** buffer, + size_t* buffer_size, + py::object* fileobj) { + out_sig_ = sf->signal; + SoxEffect e(sox_create_effect(get_fileobj_output_handler())); + auto priv = static_cast(e->priv); + priv->sf = sf; + priv->fileobj = fileobj; + priv->buffer = buffer; + priv->buffer_size = buffer_size; + if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { + throw std::runtime_error("Internal Error: Failed to add effect: output fileobj"); + } +} + #endif // TORCH_API_INCLUDE_EXTENSION_H } // namespace sox_effects_chain diff --git a/torchaudio/csrc/sox/effects_chain.h b/torchaudio/csrc/sox/effects_chain.h index b096b3eb3df..a4797b3486e 100644 --- a/torchaudio/csrc/sox/effects_chain.h +++ b/torchaudio/csrc/sox/effects_chain.h @@ -46,6 +46,12 @@ class SoxEffectsChain { uint64_t buffer_size, py::object* 
fileobj); + void addOutputFileObj( + sox_format_t* sf, + char** buffer, + size_t* buffer_size, + py::object* fileobj); + #endif // TORCH_API_INCLUDE_EXTENSION_H }; diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index e381be14f8e..52df63aa9df 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -95,26 +95,28 @@ c10::intrusive_ptr load_audio_file( } void save_audio_file( - const std::string& file_name, - const c10::intrusive_ptr& signal, - const double compression) { - auto tensor = signal->tensor; - + const std::string& path, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + c10::optional format) { validate_input_tensor(tensor); - const auto filetype = get_filetype(file_name); + auto signal = TensorSignal(tensor, sample_rate, channels_first); + + const auto filetype = format.value_or(get_filetype(path)); if (filetype == "amr-nb") { - const auto num_channels = tensor.size(signal->channels_first ? 0 : 1); + const auto num_channels = tensor.size(channels_first ? 
0 : 1); TORCH_CHECK( num_channels == 1, "amr-nb format only supports single channel audio."); tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); } - const auto signal_info = get_signalinfo(signal.get(), filetype); - const auto encoding_info = - get_encodinginfo(filetype, tensor.dtype(), compression); + const auto signal_info = get_signalinfo(&signal, filetype); + const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression); SoxFormat sf(sox_open_write( - file_name.c_str(), + path.c_str(), &signal_info, &encoding_info, /*filetype=*/filetype.c_str(), @@ -126,9 +128,9 @@ void save_audio_file( } torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", tensor.dtype(), 0.), + /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), /*output_encoding=*/sf->encoding); - chain.addInputTensor(signal.get()); + chain.addInputTensor(&signal); chain.addOutputFile(sf); chain.run(); } @@ -147,6 +149,73 @@ std::tuple load_audio_fileobj( fileobj, effects, normalize, channels_first, format); } +namespace { + +// helper class to automatically release buffer, to be used by save_audio_fileobj +struct AutoReleaseBuffer { + char* ptr; + size_t size; + + AutoReleaseBuffer() : ptr(nullptr), size(0) {} + AutoReleaseBuffer(const AutoReleaseBuffer& other) = delete; + AutoReleaseBuffer(AutoReleaseBuffer&& other) = delete; + AutoReleaseBuffer& operator=(const AutoReleaseBuffer& other) = delete; + AutoReleaseBuffer& operator=(AutoReleaseBuffer&& other) = delete; + ~AutoReleaseBuffer() { + if (ptr) { + free(ptr); + } + } +}; + +} // namespace + +void save_audio_fileobj( + py::object fileobj, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + std::string filetype) { + validate_input_tensor(tensor); + + auto signal = TensorSignal(tensor, sample_rate, channels_first); + + if (filetype == "amr-nb") { + const auto num_channels = tensor.size(channels_first ? 
0 : 1); + TORCH_CHECK( + num_channels == 1, "amr-nb format only supports single channel audio."); + tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); + } + const auto signal_info = get_signalinfo(&signal, filetype); + const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression); + + AutoReleaseBuffer buffer; + + SoxFormat sf(sox_open_memstream_write( + &buffer.ptr, + &buffer.size, + &signal_info, + &encoding_info, + filetype.c_str(), + /*oob=*/nullptr)); + + if (static_cast(sf) == nullptr) { + throw std::runtime_error("Error saving audio file: failed to open file."); + } + + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), + /*output_encoding=*/sf->encoding); + chain.addInputTensor(&signal); + chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); + chain.run(); + + sf.close(); + + fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size)); +} + #endif // TORCH_API_INCLUDE_EXTENSION_H } // namespace sox_io diff --git a/torchaudio/csrc/sox/io.h b/torchaudio/csrc/sox/io.h index d6e5310077a..ac7191527f8 100644 --- a/torchaudio/csrc/sox/io.h +++ b/torchaudio/csrc/sox/io.h @@ -38,9 +38,12 @@ c10::intrusive_ptr load_audio_file( c10::optional& format); void save_audio_file( - const std::string& file_name, - const c10::intrusive_ptr& signal, - const double compression = 0.); + const std::string& path, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + c10::optional format); #ifdef TORCH_API_INCLUDE_EXTENSION_H @@ -52,6 +55,14 @@ std::tuple load_audio_fileobj( c10::optional& channels_first, c10::optional& format); +void save_audio_fileobj( + py::object fileobj, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + std::string filetype); + #endif // TORCH_API_INCLUDE_EXTENSION_H } // namespace sox_io diff --git a/torchaudio/csrc/sox/utils.cpp 
b/torchaudio/csrc/sox/utils.cpp index 44f00084e8d..0ea95cbbb30 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -80,11 +80,8 @@ bool TensorSignal::getChannelsFirst() const { } SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} -SoxFormat::~SoxFormat() { - if (fd_ != nullptr) { - sox_close(fd_); - } -} +SoxFormat::~SoxFormat() { close(); } + sox_format_t* SoxFormat::operator->() const noexcept { return fd_; } @@ -92,6 +89,13 @@ SoxFormat::operator sox_format_t*() const noexcept { return fd_; } +void SoxFormat::close() { + if (fd_ != nullptr) { + sox_close(fd_); + fd_ = nullptr; + } +} + void validate_input_file(const SoxFormat& sf, bool check_length) { if (static_cast(sf) == nullptr) { throw std::runtime_error("Error loading audio file: failed to open file."); @@ -286,27 +290,23 @@ sox_signalinfo_t get_signalinfo( sox_encodinginfo_t get_encodinginfo( const std::string filetype, - const caffe2::TypeMeta dtype, - const double compression) { - const double compression_ = [&]() { - if (filetype == "mp3") - return compression; - if (filetype == "flac") - return compression; - if (filetype == "ogg" || filetype == "vorbis") - return compression; - if (filetype == "wav" || filetype == "amb") - return 0.; - if (filetype == "sph") - return 0.; - if (filetype == "amr-nb") - return 0.; - throw std::runtime_error("Unsupported file type: " + filetype); - }(); + const caffe2::TypeMeta dtype) { + return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), + /*bits_per_sample=*/get_precision(filetype, dtype), + /*compression=*/HUGE_VAL, + /*reverse_bytes=*/sox_option_default, + /*reverse_nibbles=*/sox_option_default, + /*reverse_bits=*/sox_option_default, + /*opposite_endian=*/sox_false}; +} +sox_encodinginfo_t get_encodinginfo( + const std::string filetype, + const caffe2::TypeMeta dtype, + c10::optional& compression) { return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), 
/*bits_per_sample=*/get_precision(filetype, dtype), - /*compression=*/compression_, + /*compression=*/compression.value_or(HUGE_VAL), /*reverse_bytes=*/sox_option_default, /*reverse_nibbles=*/sox_option_default, /*reverse_bits=*/sox_option_default, diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index ee8d1baa66d..2d434d6f72d 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -61,6 +61,8 @@ struct SoxFormat { sox_format_t* operator->() const noexcept; operator sox_format_t*() const noexcept; + void close(); + private: sox_format_t* fd_; }; @@ -116,10 +118,14 @@ sox_signalinfo_t get_signalinfo( const std::string filetype); /// Get sox_encofinginfo_t for saving audoi file +sox_encodinginfo_t get_encodinginfo( + const std::string filetype, + const caffe2::TypeMeta dtype); + sox_encodinginfo_t get_encodinginfo( const std::string filetype, const caffe2::TypeMeta dtype, - const double compression); + c10::optional& compression); } // namespace sox_utils } // namespace torchaudio