Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support file-like object in save func #1141

Merged
merged 3 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -456,7 +455,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: docker run -t --gpus all -e UPLOAD_CHANNEL -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -569,7 +567,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -606,7 +603,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Run style check
command: .circleci/unittest/linux/scripts/run_style_checks.sh
Expand Down
4 changes: 0 additions & 4 deletions .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -456,7 +455,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: docker run -t --gpus all -e UPLOAD_CHANNEL -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -569,7 +567,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Install torchaudio
command: .circleci/unittest/linux/scripts/install.sh
Expand Down Expand Up @@ -606,7 +603,6 @@ jobs:
paths:
- conda
- env
- third_party/install
- run:
name: Run style check
command: .circleci/unittest/linux/scripts/run_style_checks.sh
Expand Down
42 changes: 41 additions & 1 deletion test/torchaudio_unittest/soundfile_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import io
import itertools
from unittest.mock import patch

from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import _soundfile_backend as soundfile_backend
from parameterized import parameterized

from torchaudio_unittest.common_utils import (
TempDirMixin,
Expand Down Expand Up @@ -209,3 +209,43 @@ def test_channels_first(self, channels_first):
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Saving audio to file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')

subtype = 'FLOAT' if ext == 'wav' else None
data = get_wav_data('float32', num_channels=2)
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
expected = soundfile.read(path, dtype='float32')[0]

fileobj = io.BytesIO()
soundfile_backend.save(fileobj, data, sample_rate, format=ext)
fileobj.seek(0)
found, sr = soundfile.read(fileobj, dtype='float32')

assert sr == sample_rate
self.assertEqual(expected, found)

def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
self._test_fileobj('wav')

@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Saving audio via file-like object works"""
self._test_fileobj('flac')

@skipIfFormatNotSupported("NIST")
def test_fileobj_nist(self):
"""Saving audio via file-like object works"""
self._test_fileobj('NIST')

@skipIfFormatNotSupported("OGG")
def test_fileobj_ogg(self):
"""Saving audio via file-like object works"""
self._test_fileobj('OGG')
86 changes: 86 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import itertools

from torchaudio.backend import sox_io_backend
Expand Down Expand Up @@ -417,3 +418,88 @@ def test_tensor_preserve(self, dtype):
sox_io_backend.save(path, data, 8000)

self.assertEqual(data, expected)


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(SaveTestBase):
"""
We campare the result of file-like object input against file path input because
`save` function is rigrously tested for file path inputs to match libsox's result,
"""
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Saving audio to file object returns the same result as via file path."""
sample_rate = 16000
dtype = 'float32'
num_channels = 2
num_frames = 16000
channels_first = True

data = get_wav_data(dtype, num_channels, num_frames=num_frames)

ref_path = self.get_temp_path(f'reference.{ext}')
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression)
with open(res_path, 'wb') as fileobj:
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext)

expected_data, _ = sox_io_backend.load(ref_path)
data, sr = sox_io_backend.load(res_path)

assert sample_rate == sr
self.assertEqual(expected_data, data)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio(self, ext, compression):
"""Saving audio to BytesIO object returns the same result as via file path."""
sample_rate = 16000
dtype = 'float32'
num_channels = 2
num_frames = 16000
channels_first = True

data = get_wav_data(dtype, num_channels, num_frames=num_frames)

ref_path = self.get_temp_path(f'reference.{ext}')
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression)
fileobj = io.BytesIO()
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext)
fileobj.seek(0)
with open(res_path, 'wb') as file_:
file_.write(fileobj.read())

expected_data, _ = sox_io_backend.load(ref_path)
data, sr = sox_io_backend.load(res_path)

assert sample_rate == sr
self.assertEqual(expected_data, data)
18 changes: 13 additions & 5 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def save(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
"""Save audio data to file.

Expand Down Expand Up @@ -168,6 +169,9 @@ def save(
otherwise ``[time, channel]``.
compression (Optional[float]):
Not used. It is here only for interface compatibility reson with "sox_io" backend.
format (str, optional):
Output audio format. This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
Expand All @@ -176,8 +180,13 @@ def save(
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
)
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
ext = format
else:
ext = str(filepath).split(".")[-1].lower()

ext = str(filepath).split(".")[-1].lower()
if ext != "wav":
subtype = None
elif src.dtype == torch.uint8:
Expand All @@ -193,17 +202,16 @@ def save(
else:
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")

format_ = None
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
if ext in ["nis", "nist", "sph"]:
format_ = "NIST"
if ext in ["nis", "nist", "sph"] and format is None:
format = "NIST"

if channels_first:
src = src.t()

soundfile.write(
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format_
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
)


Expand Down
44 changes: 28 additions & 16 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,33 @@ def load(
return signal.get_tensor(), signal.get_sample_rate()


@torch.jit.unused
def _save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format)


@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
):
"""Save audio data to file.

Expand Down Expand Up @@ -184,23 +204,15 @@ def save(
| and lowest quality. Default: ``3``.

See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional):
Output audio format. This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if compression is None:
ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
compression = 0.
elif ext == 'mp3':
compression = -4.5
elif ext == 'flac':
compression = 8.
elif ext in ['ogg', 'vorbis']:
compression = 3.
else:
raise RuntimeError(f'Unsupported file type: "{ext}"')
signal = torch.classes.torchaudio.TensorSignal(src, sample_rate, channels_first)
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format)
return
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format)


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
m.def(
"save_audio_fileobj",
&torchaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
}
8 changes: 4 additions & 4 deletions torchaudio/csrc/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
// Create SoxEffectsChain
const auto dtype = in_tensor.dtype();
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", dtype, 0.),
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*input_encoding=*/get_encodinginfo("wav", dtype),
/*output_encoding=*/get_encodinginfo("wav", dtype));

// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
Expand Down Expand Up @@ -112,7 +112,7 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
// Create and run SoxEffectsChain
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*output_encoding=*/get_encodinginfo("wav", dtype));

chain.addInputFile(sf);
for (const auto& effect : effects) {
Expand Down Expand Up @@ -193,7 +193,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
/*output_encoding=*/get_encodinginfo("wav", dtype));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) {
chain.addEffect(effect);
Expand Down
Loading