Skip to content

Commit

Permalink
Handle the case where fileobject returns bytes shorter than requested
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 22, 2021
1 parent 6ed4217 commit 4a1640d
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 28 deletions.
35 changes: 24 additions & 11 deletions torchaudio/csrc/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,32 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
// This will trick libsox as if it keeps reading from the FILE* continuously.

// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory.
// 4096 is minimum size requried by auto_detect_format
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L48
const size_t in_buffer_size = 4096;
std::string in_buffer(in_buffer_size, 'x');
auto* in_buf = const_cast<char*>(in_buffer.data());

// Fetch the header, and copy it to the buffer.
auto header = static_cast<std::string>(static_cast<py::bytes>(fileobj.attr("read")(4096)));
memcpy(static_cast<void*>(in_buf),
static_cast<void*>(const_cast<char*>(header.data())), header.length());
// For FLAC format, libsox will try to read as much as possible at initialization, therefore the buffer
// has to be filled with valid data content.
// We default to `sox_get_globals()->bufsiz` for buffer size, which can be changed with
// `sox_utils.set_buffer_size`, but if the bytes object returned by Python object is shorter than that,
// we use the length of returned bytes as buffer size.
// Since read_fileobj function repeatedly call `read` method until there is requested bytes or
// it reaches EOF, smaller buffer size means the whole file is smaller than the buffer size.
// If the fetched chunk is shorter than 256, we pad it with null byte, so that libsox can correctly
// read the header
std::string buffer = read_fileobj(&fileobj, sox_get_globals()->bufsiz);
if (buffer.length() < 256) {
buffer += std::string(256 - buffer.length(), '\0');
}
auto* in_buf = const_cast<char*>(buffer.data());
auto in_buffer_size = buffer.length();

// Open file (this starts reading the header)
// When opening a file there are two functions that can touches FILE*.
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
SoxFormat sf(sox_open_mem_read(
in_buf,
in_buffer_size,
Expand Down
35 changes: 18 additions & 17 deletions torchaudio/csrc/sox/effects_chain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ namespace {
struct FileObjInputPriv {
sox_format_t* sf;
py::object* fileobj;
bool read_finished;
char* buffer;
uint64_t buffer_size;
};
Expand All @@ -307,9 +308,7 @@ struct FileObjOutputPriv {
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
auto priv = static_cast<FileObjInputPriv *>(effp->priv);
auto sf = priv->sf;
auto fileobj = priv->fileobj;
auto buffer = priv->buffer;
auto buffer_size = priv->buffer_size;

// 1. Refresh the buffer
//
Expand All @@ -332,20 +331,21 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// ^ ftell

const auto num_consumed = sf->tell_off;
const auto num_remain = buffer_size - num_consumed;

// 1.1. First, we fetch the data to see if there is data to fill the buffer
py::bytes chunk_ = fileobj->attr("read")(num_consumed);
const auto num_refill = py::len(chunk_);
const auto offset = buffer_size - (num_remain + num_refill);

if(num_refill > num_consumed) {
std::ostringstream message;
message << "Tried to read up to " << num_consumed << " bytes but, "
<< "recieved " << num_refill << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
const auto num_remain = priv->buffer_size - num_consumed;

// 1.1. Fetch the data to see if there is data to fill the buffer
uint64_t num_refill = 0;
std::string chunk;
if (num_consumed && !priv->read_finished) {
chunk = read_fileobj(priv->fileobj, num_consumed);
num_refill = chunk.length();
if (!num_refill) {
// https://docs.python.org/3/library/io.html#io.BufferedIOBase.read
// > An empty bytes object is returned if the stream is already at EOF.
priv->read_finished = true;
}
}
const auto offset = num_consumed - num_refill;

// 1.2. Move the unconsumed data towards the beginning of buffer.
if (num_remain) {
Expand All @@ -356,7 +356,6 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {

// 1.3. Refill the remaining buffer.
if (num_refill) {
auto chunk = static_cast<std::string>(chunk_);
auto src = static_cast<void*>(const_cast<char*>(chunk.c_str()));
auto dst = buffer + offset + num_remain;
memcpy(dst, src, num_refill);
Expand All @@ -377,7 +376,8 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp);

return *osamp? SOX_SUCCESS : SOX_EOF;
// Decoding is finished when fileobject is exhausted and sox can no longer decode a sample.
return (priv->read_finished && !*osamp) ? SOX_EOF : SOX_SUCCESS;
}

int fileobj_output_flow(
Expand Down Expand Up @@ -461,6 +461,7 @@ void SoxEffectsChain::addInputFileObj(
auto priv = static_cast<FileObjInputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->read_finished = false;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
Expand Down
28 changes: 28 additions & 0 deletions torchaudio/csrc/sox/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,5 +313,33 @@ sox_encodinginfo_t get_encodinginfo(
/*opposite_endian=*/sox_false};
}

#ifdef TORCH_API_INCLUDE_EXTENSION_H

std::string read_fileobj(py::object* fileobj, uint64_t size) {
std::string bytes;
bytes.reserve(size+1);

uint64_t remain = size;
while(remain > 0) {
auto chunk = static_cast<std::string>(static_cast<py::bytes>(fileobj->attr("read")(remain)));
auto chunk_len = chunk.length();
if (chunk_len ==0) {
break;
} else if (chunk_len > remain) {
std::ostringstream message;
message << "Tried to read up to " << remain << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
bytes += chunk;
remain -= chunk_len;
}
std::cout << " - fetched: " << bytes.length() << std::endl;
return bytes;
}

#endif // TORCH_API_INCLUDE_EXTENSION_H

} // namespace sox_utils
} // namespace torchaudio
10 changes: 10 additions & 0 deletions torchaudio/csrc/sox/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include <sox.h>
#include <torch/script.h>

#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H

namespace torchaudio {
namespace sox_utils {

Expand Down Expand Up @@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo(
const caffe2::TypeMeta dtype,
c10::optional<double>& compression);

#ifdef TORCH_API_INCLUDE_EXTENSION_H

std::string read_fileobj(py::object* fileobj, uint64_t size);

#endif // TORCH_API_INCLUDE_EXTENSION_H

} // namespace sox_utils
} // namespace torchaudio
#endif

0 comments on commit 4a1640d

Please sign in to comment.