Skip to content

Commit

Permalink
Support file-like object in apply_effects_file
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 27, 2021
1 parent 47d97e3 commit 30fc490
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 5 deletions.
161 changes: 160 additions & 1 deletion test/torchaudio_unittest/sox_effect/sox_effect_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import io
import itertools
from pathlib import Path
import tarfile

from torchaudio import sox_effects
from parameterized import parameterized
from torchaudio import sox_effects
from torchaudio._internal import module_utils as _mod_utils

from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExtension,
skipIfNoModule,
skipIfNoExec,
get_asset_path,
get_sinusoid,
get_wav_data,
Expand All @@ -21,6 +27,10 @@
)


if _mod_utils.is_module_available("requests"):
import requests


@skipIfNoExtension
class TestSoxEffects(PytorchTestCase):
def test_init(self):
Expand Down Expand Up @@ -262,3 +272,152 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
_, sr = sox_effects.apply_effects_file(path, effects, format="mp3")
assert sr == 16000


@skipIfNoExec('sox')
@skipIfNoExtension
class TestFileObject(TempDirMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Applying effects via file object works"""
sample_rate = 16000
channels_first = True
effects = [['band', '300', '10']]
format_ = ext if ext in ['mp3'] else None
input_path = self.get_temp_path(f'input.{ext}')
reference_path = self.get_temp_path('reference.wav')

sox_utils.gen_audio_file(
input_path, sample_rate, num_channels=2, compression=compression)
sox_utils.run_sox_effect(
input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)

with open(input_path, 'rb') as fileobj:
found, sr = sox_effects.apply_effects_file(
fileobj, effects, channels_first=channels_first, format=format_)
save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
assert sr == expected_sr
self.assertEqual(found, expected)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio(self, ext, compression):
"""Applying effects via BytesIO object works"""
sample_rate = 16000
channels_first = True
effects = [['band', '300', '10']]
format_ = ext if ext in ['mp3'] else None
input_path = self.get_temp_path(f'input.{ext}')
reference_path = self.get_temp_path('reference.wav')

sox_utils.gen_audio_file(
input_path, sample_rate, num_channels=2, compression=compression)
sox_utils.run_sox_effect(
input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)

with open(input_path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_effects.apply_effects_file(
fileobj, effects, channels_first=channels_first, format=format_)
save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
assert sr == expected_sr
self.assertEqual(found, expected)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_tarfile(self, ext, compression):
"""Applying effects to compressed audio via file-like file works"""
sample_rate = 16000
channels_first = True
effects = [['band', '300', '10']]
format_ = ext if ext in ['mp3'] else None
audio_file = f'input.{ext}'

input_path = self.get_temp_path(audio_file)
reference_path = self.get_temp_path('reference.wav')
archive_path = self.get_temp_path('archive.tar.gz')

sox_utils.gen_audio_file(
input_path, sample_rate, num_channels=2, compression=compression)
sox_utils.run_sox_effect(
input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)

with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(input_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_effects.apply_effects_file(
fileobj, effects, channels_first=channels_first, format=format_)
save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
assert sr == expected_sr
self.assertEqual(found, expected)


@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_requests(self, ext, compression):
sample_rate = 16000
channels_first = True
effects = [['band', '300', '10']]
format_ = ext if ext in ['mp3'] else None
audio_file = f'input.{ext}'
input_path = self.get_temp_path(audio_file)
reference_path = self.get_temp_path('reference.wav')

sox_utils.gen_audio_file(
input_path, sample_rate, num_channels=2, compression=compression)
sox_utils.run_sox_effect(
input_path, reference_path, effects, output_bitdepth=32)
expected, expected_sr = load_wav(reference_path)

url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_effects.apply_effects_file(
resp.raw, effects, channels_first=channels_first, format=format_)
save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
assert sr == expected_sr
self.assertEqual(found, expected)
5 changes: 5 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>

Expand Down Expand Up @@ -112,4 +113,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"save_audio_fileobj",
&torchaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
m.def(
"apply_effects_fileobj",
&torchaudio::sox_effects::apply_effects_fileobj,
"Decode audio data from file-like obj and apply effects.");
}
24 changes: 20 additions & 4 deletions torchaudio/sox_effects/sox_effects.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
from pathlib import Path
from typing import List, Tuple, Optional

import torch

import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.utils.sox_utils import list_effects

Expand Down Expand Up @@ -170,8 +173,16 @@ def apply_effects_file(
rate and leave samples untouched.
Args:
path (str or pathlib.Path): Path to the audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
path (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* This argument is intentionally annotated as ``str`` only for
TorchScript compiler compatibility.
effects (List[List[str]]): List of effects.
normalize (bool):
When ``True``, this function always return ``float32``, and sample values are
Expand Down Expand Up @@ -252,8 +263,13 @@ def apply_effects_file(
>>> for batch in loader:
>>> pass
"""
# Get string representation of 'path' in case Path object is passed
path = str(path)
if not torch.jit.is_scripting():
if hasattr(path, 'read'):
return torchaudio._torchaudio.apply_effects_fileobj(
path, effects, normalize, channels_first, format)
signal = torch.ops.torchaudio.sox_effects_apply_effects_file(
os.fspath(path), effects, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
signal = torch.ops.torchaudio.sox_effects_apply_effects_file(
path, effects, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()

0 comments on commit 30fc490

Please sign in to comment.