Skip to content

Commit

Permalink
Fix load from file object for small files and shorter bytes (#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 27, 2021
1 parent 41c76a1 commit 47d97e3
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 61 deletions.
75 changes: 75 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,19 @@ def test_mp3(self):
assert sr == 16000


class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b''

def read(self, n):
if not self.buffer:
self.buffer += self.fileobj.read(n)
ret = self.buffer[:2]
self.buffer = self.buffer[2:]
return ret


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
Expand Down Expand Up @@ -444,6 +457,68 @@ def test_bytesio(self, ext, compression):
assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio_clogged(self, ext, compression):
"""Loading audio via clogged file object returns the same result as via file path.
This test case validates the case where fileobject returns shorter bytes than requeted.
"""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')

sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)

with open(path, 'rb') as file_:
fileobj = CloggedFileObj(io.BytesIO(file_.read()))
found, sr = sox_io_backend.load(fileobj, format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio_tiny(self, ext, compression):
"""Loading very small audio via file object returns the same result as via file path.
"""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')

sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression, duration=1 / 1600)
expected, _ = sox_io_backend.load(path)

with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)

assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand([
('wav', None),
('mp3', 128),
Expand Down
84 changes: 49 additions & 35 deletions torchaudio/csrc/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,51 +137,65 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(

#ifdef TORCH_API_INCLUDE_EXTENSION_H

// Streaming decoding over file-like object is tricky because libsox operates on
// FILE pointer. The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// chunk of byte string
// This will perform header-based format detection, if necessary, then fill
// the metadata of sox_format_t. Internally, sox_open_mem_read uses fmemopen,
// which returns FILE* which points the buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// For Step 2. see `fileobj_input_drain` function in effects_chain.cpp
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
py::object fileobj,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
// Streaming decoding over file-like object is tricky because libsox operates
// on FILE pointer. The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
// For certain format (such as FLAC), libsox keeps reading the content at
// the initialization unless it reaches EOF even when the header is properly
// parsed. (Making buffer size 8192, which is way bigger than the header,
// resulted in libsox consuming all the buffer content at the time it opens
// the file.) Therefore buffer has to always contain valid data, except after
// EOF. We default to `sox_get_globals()->bufsiz`* for buffer size and we
// first check if there is enough data to fill the buffer. `read_fileobj`
// repeatedly calls `read` method until it receives the requested lenght of
// bytes or it reaches EOF. If we get bytes shorter than requested, that means
// the whole audio data are fetched.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// chunk of byte string
// This will perform header-based format detection, if necessary, then fill
// the metadata of sox_format_t. Internally, sox_open_mem_read uses
// fmemopen, which returns FILE* which points the buffer of the provided
// byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.

// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
// Using std::string and let it manage memory.
// 4096 is minimum size requried by auto_detect_format
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L48
const size_t in_buffer_size = 4096;
std::string in_buffer(in_buffer_size, 'x');
auto* in_buf = const_cast<char*>(in_buffer.data());

// Fetch the header, and copy it to the buffer.
auto header = static_cast<std::string>(
static_cast<py::bytes>(fileobj.attr("read")(4096)));
memcpy(
static_cast<void*>(in_buf),
static_cast<void*>(const_cast<char*>(header.data())),
header.length());
// * This can be changed with `torchaudio.utils.sox_utils.set_buffer_size`.
auto capacity =
(sox_get_globals()->bufsiz > 256) ? sox_get_globals()->bufsiz : 256;
std::string buffer(capacity, '\0');
auto* in_buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, in_buf);
// If the file is shorter than 256, then libsox cannot read the header.
auto in_buffer_size = (num_read > 256) ? num_read : 256;

// Open file (this starts reading the header)
// When opening a file there are two functions that can touches FILE*.
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
SoxFormat sf(sox_open_mem_read(
in_buf,
in_buffer_size,
Expand Down
45 changes: 22 additions & 23 deletions torchaudio/csrc/sox/effects_chain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ namespace {
struct FileObjInputPriv {
sox_format_t* sf;
py::object* fileobj;
bool eof_reached;
char* buffer;
uint64_t buffer_size;
};
Expand All @@ -312,9 +313,7 @@ struct FileObjOutputPriv {
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
auto priv = static_cast<FileObjInputPriv*>(effp->priv);
auto sf = priv->sf;
auto fileobj = priv->fileobj;
auto buffer = priv->buffer;
auto buffer_size = priv->buffer_size;

// 1. Refresh the buffer
//
Expand All @@ -326,32 +325,30 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
//
// Before:
//
// |<--------consumed------->|<-remaining->|
// |*************************|-------------|
// ^ ftell
// |<-------consumed------>|<---remaining--->|
// |***********************|-----------------|
// ^ ftell
//
// After:
//
// |<-offset->|<-remaining->|<--new data-->|
// |**********|-------------|++++++++++++++|
// |<-offset->|<---remaining--->|<-new data->|
// |**********|-----------------|++++++++++++|
// ^ ftell

const auto num_consumed = sf->tell_off;
const auto num_remain = buffer_size - num_consumed;

// 1.1. First, we fetch the data to see if there is data to fill the buffer
py::bytes chunk_ = fileobj->attr("read")(num_consumed);
const auto num_refill = py::len(chunk_);
const auto offset = buffer_size - (num_remain + num_refill);

if (num_refill > num_consumed) {
std::ostringstream message;
message
<< "Tried to read up to " << num_consumed << " bytes but, "
<< "recieved " << num_refill << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
const auto num_remain = priv->buffer_size - num_consumed;

// 1.1. Fetch the data to see if there is data to fill the buffer
uint64_t num_refill = 0;
std::string chunk(num_consumed, '\0');
if (num_consumed && !priv->eof_reached) {
num_refill = read_fileobj(
priv->fileobj, num_consumed, const_cast<char*>(chunk.data()));
if (num_refill < num_consumed) {
priv->eof_reached = true;
}
}
const auto offset = num_consumed - num_refill;

// 1.2. Move the unconsumed data towards the beginning of buffer.
if (num_remain) {
Expand All @@ -362,7 +359,6 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {

// 1.3. Refill the remaining buffer.
if (num_refill) {
auto chunk = static_cast<std::string>(chunk_);
auto src = static_cast<void*>(const_cast<char*>(chunk.c_str()));
auto dst = buffer + offset + num_remain;
memcpy(dst, src, num_refill);
Expand All @@ -383,7 +379,9 @@ int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp);

return *osamp ? SOX_SUCCESS : SOX_EOF;
// Decoding is finished when fileobject is exhausted and sox can no longer
// decode a sample.
return (priv->eof_reached && !*osamp) ? SOX_EOF : SOX_SUCCESS;
}

int fileobj_output_flow(
Expand Down Expand Up @@ -469,6 +467,7 @@ void SoxEffectsChain::addInputFileObj(
auto priv = static_cast<FileObjInputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->eof_reached = false;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
Expand Down
3 changes: 0 additions & 3 deletions torchaudio/csrc/sox/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,6 @@ uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) {
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}

std::cerr << "req: " << request << ", fetched: " << chunk_len << std::endl;
std::cerr << "buffer: " << (void*)buffer << std::endl;
memcpy(buffer, chunk.data(), chunk_len);
buffer += chunk_len;
num_read += chunk_len;
Expand Down

0 comments on commit 47d97e3

Please sign in to comment.