Skip to content

Commit

Permalink
Support file-like object in save func
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 12, 2021
1 parent 72b7680 commit ddac5d7
Show file tree
Hide file tree
Showing 13 changed files with 410 additions and 68 deletions.
42 changes: 41 additions & 1 deletion test/torchaudio_unittest/soundfile_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import io
import itertools
from unittest.mock import patch

from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import _soundfile_backend as soundfile_backend
from parameterized import parameterized

from torchaudio_unittest.common_utils import (
TempDirMixin,
Expand Down Expand Up @@ -209,3 +209,43 @@ def test_channels_first(self, channels_first):
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Tests for ``soundfile_backend.save`` when the destination is a file-like object."""

    def _test_fileobj(self, ext):
        """Saving audio to file-like object works

        The reference is written straight to disk with ``soundfile.write`` and read
        back; the candidate is written into a ``BytesIO`` through the backend and read
        back from that buffer. Both reads must agree.
        """
        sample_rate = 16000
        path = self.get_temp_path(f'test.{ext}')

        # Only WAV gets an explicit subtype; other formats use soundfile's default.
        subtype = None
        if ext == 'wav':
            subtype = 'FLOAT'
        waveform = get_wav_data('float32', num_channels=2)
        soundfile.write(path, waveform.numpy().T, sample_rate, subtype=subtype)
        expected = soundfile.read(path, dtype='float32')[0]

        buffer = io.BytesIO()
        soundfile_backend.save(buffer, waveform, sample_rate, format=ext)
        # Rewind so that the read below starts from the beginning of the buffer.
        buffer.seek(0)
        found, found_rate = soundfile.read(buffer, dtype='float32')

        assert found_rate == sample_rate
        self.assertEqual(expected, found)

    def test_fileobj_wav(self):
        """Saving WAV audio via file-like object works"""
        self._test_fileobj('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Saving FLAC audio via file-like object works"""
        self._test_fileobj('flac')

    @skipIfFormatNotSupported("NIST")
    def test_fileobj_nist(self):
        """Saving NIST audio via file-like object works"""
        self._test_fileobj('NIST')

    @skipIfFormatNotSupported("OGG")
    def test_fileobj_ogg(self):
        """Saving OGG audio via file-like object works"""
        self._test_fileobj('OGG')
86 changes: 86 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import itertools

from torchaudio.backend import sox_io_backend
Expand Down Expand Up @@ -417,3 +418,88 @@ def test_tensor_preserve(self, dtype):
sox_io_backend.save(path, data, 8000)

self.assertEqual(data, expected)


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(SaveTestBase):
    """
    We compare the result of file-like object input against file path input because
    the `save` function is rigorously tested for file path inputs to match libsox's result.
    """
    # (format, compression) combinations exercised by every test in this class.
    _FORMATS = [
        ('wav', None),
        ('mp3', 128),
        ('mp3', 320),
        ('flac', 0),
        ('flac', 5),
        ('flac', 8),
        ('vorbis', -1),
        ('vorbis', 10),
        ('amb', None),
    ]

    def _assert_matches_reference(self, ref_path, res_path, sample_rate):
        """Load both files and check that the sample rate and decoded data agree."""
        expected, _ = sox_io_backend.load(ref_path)
        found, found_rate = sox_io_backend.load(res_path)

        assert sample_rate == found_rate
        self.assertEqual(expected, found)

    @parameterized.expand(_FORMATS)
    def test_fileobj(self, ext, compression):
        """Saving audio to file object returns the same result as via file path."""
        sample_rate = 16000
        save_kwargs = dict(
            channels_first=True, sample_rate=sample_rate, compression=compression)

        waveform = get_wav_data('float32', 2, num_frames=16000)

        ref_path = self.get_temp_path(f'reference.{ext}')
        res_path = self.get_temp_path(f'test.{ext}')
        sox_io_backend.save(ref_path, waveform, **save_kwargs)
        # No ``format`` is passed here: the opened file carries a ``name`` attribute.
        with open(res_path, 'wb') as sink:
            sox_io_backend.save(sink, waveform, **save_kwargs)

        self._assert_matches_reference(ref_path, res_path, sample_rate)

    @parameterized.expand(_FORMATS)
    def test_bytesio(self, ext, compression):
        """Saving audio to BytesIO object returns the same result as via file path."""
        sample_rate = 16000
        save_kwargs = dict(
            channels_first=True, sample_rate=sample_rate, compression=compression)

        waveform = get_wav_data('float32', 2, num_frames=16000)

        ref_path = self.get_temp_path(f'reference.{ext}')
        res_path = self.get_temp_path(f'test.{ext}')
        sox_io_backend.save(ref_path, waveform, **save_kwargs)
        # ``BytesIO`` has no ``name`` attribute, so the format is given explicitly.
        buffer = io.BytesIO()
        sox_io_backend.save(buffer, waveform, format=ext, **save_kwargs)
        # Dump the in-memory result to disk so it can be loaded like the reference.
        with open(res_path, 'wb') as sink:
            sink.write(buffer.getvalue())

        self._assert_matches_reference(ref_path, res_path, sample_rate)
20 changes: 14 additions & 6 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""The new soundfile backend which will become default in 0.8.0 onward"""
import os
from typing import Tuple, Optional
import warnings

import torch
from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData
from .common import AudioMetaData, get_ext


if _mod_utils.is_module_available("soundfile"):
Expand Down Expand Up @@ -138,6 +139,7 @@ def save(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -168,6 +170,9 @@ def save(
otherwise ``[time, channel]``.
compression (Optional[float]):
Not used. It is here only for interface compatibility reasons with "sox_io" backend.
format (str, optional):
Output audio format. This is required when the output audio format cannot be inferred from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
Expand All @@ -177,7 +182,11 @@ def save(
"The argument is silently ignored."
)

ext = str(filepath).split(".")[-1].lower()
try:
ext = get_ext(filepath, format)
except Exception:
raise RuntimeError('Cannot detect the output format. Provide `format` argument.') from None

if ext != "wav":
subtype = None
elif src.dtype == torch.uint8:
Expand All @@ -193,17 +202,16 @@ def save(
else:
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")

format_ = None
# sph is an extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
if ext in ["nis", "nist", "sph"]:
format_ = "NIST"
if ext in ["nis", "nist", "sph"] and format is None:
format = "NIST"

if channels_first:
src = src.t()

soundfile.write(
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format_
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
)


Expand Down
17 changes: 17 additions & 0 deletions torchaudio/backend/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Any, Optional

from torchaudio._internal import module_utils as _mod_utils
Expand All @@ -19,6 +20,22 @@ def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
self.num_channels = num_channels


def get_ext(
        src: Any,
        format: Optional[str]):
    """Get the file extension from either the given format or target file information

    Args:
        src (path-like object or file-like object): Target file.
        format (optional, str): format provided by user.

    Returns:
        str: ``format`` in lower case when it is given. Otherwise the lower-cased
        extension (without the leading dot) taken from ``src``, or from ``src.name``
        when ``src`` has a ``name`` attribute. An empty string when the path has
        no extension.

    Raises:
        TypeError: When ``format`` is ``None`` and neither ``src`` nor ``src.name``
            is a ``str``, ``bytes`` or path-like object (for example ``io.BytesIO``).
    """
    if format is not None:
        return format.lower()
    if hasattr(src, 'name'):
        src = src.name
    # Decode first so that a bytes path yields a ``str`` extension; callers compare
    # the result against string literals such as "wav".
    return os.path.splitext(os.fsdecode(src))[-1][1:].lower()


@_mod_utils.deprecated('Please migrate to `AudioMetaData`.', '0.9.0')
class SignalInfo:
"""One of return types of ``torchaudio.info`` functions.
Expand Down
48 changes: 31 additions & 17 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)

import torchaudio
from .common import AudioMetaData
from .common import AudioMetaData, get_ext


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down Expand Up @@ -134,13 +134,35 @@ def load(
return signal.get_tensor(), signal.get_sample_rate()


@torch.jit.unused
def _save(
        filepath: str,
        src: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        compression: Optional[float] = None,
        format: Optional[str] = None,
):
    """Resolve the output format and hand the data to the matching binding.

    An object exposing a ``write`` attribute is passed to
    ``torchaudio._torchaudio.save_audio_fileobj``. Anything else is converted
    with ``os.fspath`` and passed to ``torch.ops.torchaudio.sox_io_save_audio_file``.

    Raises:
        RuntimeError: When ``get_ext`` raises while determining the output format.
    """
    try:
        audio_format = get_ext(filepath, format)
    except Exception:
        # ``from None`` hides the underlying error in favour of the actionable hint.
        raise RuntimeError('Cannot detect the output format. Provide `format` argument.') from None

    is_file_like = hasattr(filepath, 'write')
    if not is_file_like:
        torch.ops.torchaudio.sox_io_save_audio_file(
            os.fspath(filepath), src, sample_rate, channels_first, compression, audio_format)
        return
    torchaudio._torchaudio.save_audio_fileobj(
        filepath, src, sample_rate, channels_first, compression, audio_format)


@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
"""Save audio data to file.
Expand Down Expand Up @@ -184,23 +206,15 @@ def save(
| and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional):
Output audio format. This is required when the output audio format cannot be inferred from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if compression is None:
ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
compression = 0.
elif ext == 'mp3':
compression = -4.5
elif ext == 'flac':
compression = 8.
elif ext in ['ogg', 'vorbis']:
compression = 3.
else:
raise RuntimeError(f'Unsupported file type: "{ext}"')
signal = torch.classes.torchaudio.TensorSignal(src, sample_rate, channels_first)
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format)
return
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format)


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
m.def(
"save_audio_fileobj",
&torchaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
}
8 changes: 4 additions & 4 deletions torchaudio/csrc/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
// Create SoxEffectsChain
const auto dtype = in_tensor.dtype();
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", dtype, 0.),
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*input_encoding=*/get_encodinginfo("wav", dtype),
/*output_encoding=*/get_encodinginfo("wav", dtype));

// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
Expand Down Expand Up @@ -112,7 +112,7 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
// Create and run SoxEffectsChain
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*output_encoding=*/get_encodinginfo("wav", dtype));

chain.addInputFile(sf);
for (const auto& effect : effects) {
Expand Down Expand Up @@ -193,7 +193,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*output_encoding=*/get_encodinginfo("wav", dtype));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) {
chain.addEffect(effect);
Expand Down
Loading

0 comments on commit ddac5d7

Please sign in to comment.