From 4a1640ded27dcde7b197cf60916e8a02101f5d2f Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 22 Jan 2021 01:37:14 +0000 Subject: [PATCH] Handle the case where fileobject returns bytes shorter than requested --- torchaudio/csrc/sox/effects.cpp | 35 ++++++++++++++++++--------- torchaudio/csrc/sox/effects_chain.cpp | 35 ++++++++++++++------------- torchaudio/csrc/sox/utils.cpp | 28 +++++++++++++++++++++ torchaudio/csrc/sox/utils.h | 10 ++++++++ 4 files changed, 80 insertions(+), 28 deletions(-) diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp index 03f06c2602f..ac3ba88f5f7 100644 --- a/torchaudio/csrc/sox/effects.cpp +++ b/torchaudio/csrc/sox/effects.cpp @@ -162,19 +162,32 @@ std::tuple apply_effects_fileobj( // This will trick libsox as if it keeps reading from the FILE* continuously. // Prepare the buffer used throughout the lifecycle of SoxEffectChain. - // Using std::string and let it manage memory. - // 4096 is minimum size requried by auto_detect_format - // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L48 - const size_t in_buffer_size = 4096; - std::string in_buffer(in_buffer_size, 'x'); - auto* in_buf = const_cast(in_buffer.data()); - - // Fetch the header, and copy it to the buffer. - auto header = static_cast(static_cast(fileobj.attr("read")(4096))); - memcpy(static_cast(in_buf), - static_cast(const_cast(header.data())), header.length()); + // For FLAC format, libsox will try to read as much as possible at initialization, therefore the buffer + // has to be filled with valid data content. + // We default to `sox_get_globals()->bufsiz` for buffer size, which can be changed with + // `sox_utils.set_buffer_size`, but if the bytes object returned by Python object is shorter than that, + // we use the length of returned bytes as buffer size. 
+ // Since read_fileobj function repeatedly calls the `read` method until it has the requested bytes or + // it reaches EOF, a smaller buffer size means the whole file is smaller than the buffer size. + // If the fetched chunk is shorter than 256, we pad it with null bytes, so that libsox can correctly + // read the header + std::string buffer = read_fileobj(&fileobj, sox_get_globals()->bufsiz); + if (buffer.length() < 256) { + buffer += std::string(256 - buffer.length(), '\0'); + } + auto* in_buf = const_cast(buffer.data()); + auto in_buffer_size = buffer.length(); // Open file (this starts reading the header) + // When opening a file there are two functions that can touch FILE*. + // * `auto_detect_format` + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43 + // * `startread` handler of detected format. + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574 + // To see the handler of a particular format, go to + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/.c + // For example, vorbis can be found + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158 SoxFormat sf(sox_open_mem_read( in_buf, in_buffer_size, diff --git a/torchaudio/csrc/sox/effects_chain.cpp b/torchaudio/csrc/sox/effects_chain.cpp index 9c365b72aa4..66d4674ea3f 100644 --- a/torchaudio/csrc/sox/effects_chain.cpp +++ b/torchaudio/csrc/sox/effects_chain.cpp @@ -291,6 +291,7 @@ namespace { struct FileObjInputPriv { sox_format_t* sf; py::object* fileobj; + bool read_finished; char* buffer; uint64_t buffer_size; }; @@ -307,9 +308,7 @@ struct FileObjOutputPriv { int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { auto priv = static_cast(effp->priv); auto sf = priv->sf; - auto fileobj = priv->fileobj; auto buffer = priv->buffer; - auto buffer_size = priv->buffer_size; // 1. 
Refresh the buffer // @@ -332,20 +331,21 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { // ^ ftell const auto num_consumed = sf->tell_off; - const auto num_remain = buffer_size - num_consumed; - - // 1.1. First, we fetch the data to see if there is data to fill the buffer - py::bytes chunk_ = fileobj->attr("read")(num_consumed); - const auto num_refill = py::len(chunk_); - const auto offset = buffer_size - (num_remain + num_refill); - - if(num_refill > num_consumed) { - std::ostringstream message; - message << "Tried to read up to " << num_consumed << " bytes but, " - << "recieved " << num_refill << " bytes. " - << "The given object does not confirm to read protocol of file object."; - throw std::runtime_error(message.str()); + const auto num_remain = priv->buffer_size - num_consumed; + + // 1.1. Fetch the data to see if there is data to fill the buffer + uint64_t num_refill = 0; + std::string chunk; + if (num_consumed && !priv->read_finished) { + chunk = read_fileobj(priv->fileobj, num_consumed); + num_refill = chunk.length(); + if (!num_refill) { + // https://docs.python.org/3/library/io.html#io.BufferedIOBase.read + // > An empty bytes object is returned if the stream is already at EOF. + priv->read_finished = true; + } } + const auto offset = num_consumed - num_refill; // 1.2. Move the unconsumed data towards the beginning of buffer. if (num_remain) { @@ -356,7 +356,6 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { // 1.3. Refill the remaining buffer. if (num_refill) { - auto chunk = static_cast(chunk_); auto src = static_cast(const_cast(chunk.c_str())); auto dst = buffer + offset + num_remain; memcpy(dst, src, num_refill); @@ -377,7 +376,8 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { // store the actual number read back to *osamp *osamp = sox_read(sf, obuf, *osamp); - return *osamp? 
SOX_SUCCESS : SOX_EOF; + // Decoding is finished when fileobject is exhausted and sox can no longer decode a sample. + return (priv->read_finished && !*osamp) ? SOX_EOF : SOX_SUCCESS; } int fileobj_output_flow( @@ -461,6 +461,7 @@ void SoxEffectsChain::addInputFileObj( auto priv = static_cast(e->priv); priv->sf = sf; priv->fileobj = fileobj; + priv->read_finished = false; priv->buffer = buffer; priv->buffer_size = buffer_size; if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 0ea95cbbb30..686b7e16836 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -313,5 +313,33 @@ sox_encodinginfo_t get_encodinginfo( /*opposite_endian=*/sox_false}; } +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +std::string read_fileobj(py::object* fileobj, uint64_t size) { + std::string bytes; + bytes.reserve(size+1); + + uint64_t remain = size; + while(remain > 0) { + auto chunk = static_cast(static_cast(fileobj->attr("read")(remain))); + auto chunk_len = chunk.length(); + if (chunk_len ==0) { + break; + } else if (chunk_len > remain) { + std::ostringstream message; + message << "Tried to read up to " << remain << " bytes but, " + << "received " << chunk_len << " bytes. 
" + << "The given object does not confirm to read protocol of file object."; + throw std::runtime_error(message.str()); + } + bytes += chunk; + remain -= chunk_len; + } + std::cout << " - fetched: " << bytes.length() << std::endl; + return bytes; +} + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace sox_utils } // namespace torchaudio diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index 2d434d6f72d..659c6e76683 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -4,6 +4,10 @@ #include #include +#ifdef TORCH_API_INCLUDE_EXTENSION_H +#include +#endif // TORCH_API_INCLUDE_EXTENSION_H + namespace torchaudio { namespace sox_utils { @@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo( const caffe2::TypeMeta dtype, c10::optional& compression); +#ifdef TORCH_API_INCLUDE_EXTENSION_H + +std::string read_fileobj(py::object* fileobj, uint64_t size); + +#endif // TORCH_API_INCLUDE_EXTENSION_H + } // namespace sox_utils } // namespace torchaudio #endif