Skip to content

Commit

Permalink
Add AMB/AMR-NB/AMR-WB support to "sox_io" backend (#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Dec 4, 2020
1 parent 2a02d7f commit 4406a6b
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 13 deletions.
2 changes: 2 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def _get_extra_objects():
'libvorbisfile.a',
'libvorbis.a',
'libogg.a',
'libopencore-amrnb.a',
'libopencore-amrwb.a',
]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
Expand Down
30 changes: 30 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,36 @@ def test_sphere(self, sample_rate, num_channels):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels

def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels


@skipIfNoExtension
class TestInfoOpus(PytorchTestCase):
Expand Down
61 changes: 61 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,53 @@ def assert_sphere(self, sample_rate, num_channels, duration):
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load amb format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.amb')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate amb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amb with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_amr_nb(self, duration):
"""`sox_io_backend.load` can load amr-nb format.
This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
path = self.get_temp_path('1.original.amr-nb')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate amr-nb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amr-nb with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down Expand Up @@ -260,6 +307,20 @@ def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_sphere(sample_rate, num_channels, duration=1)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1)

def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_amr_nb(duration=1)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down
75 changes: 75 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,68 @@ def assert_sphere(self, sample_rate, num_channels, duration):

self.assertEqual(found, expected)

def assert_amb(self, dtype, sample_rate, num_channels, duration):
"""`sox_io_backend.save` can save amb format.
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
amb_path = self.get_temp_path('2.1.torchaudio.amb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amb_path_sox = self.get_temp_path('3.1.sox.amb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amb with torchaudio
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
# 2.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to amb with SoX
sox_utils.convert_audio_file(src_path, amb_path_sox)
# 3.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

self.assertEqual(found, expected)

def assert_amr_nb(self, duration):
"""`sox_io_backend.save` can save amr_nb format.
This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
src_path = self.get_temp_path('1.reference.wav')
amr_path = self.get_temp_path('2.1.torchaudio.amr-nb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amr_path_sox = self.get_temp_path('3.1.sox.amr-nb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')

# 1. Generate original wav
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]

# 3.1. Convert the original wav to amr_nb with SoX
sox_utils.convert_audio_file(src_path, amr_path_sox)
# 3.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]

self.assertEqual(found, expected)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down Expand Up @@ -302,6 +364,19 @@ def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.save` can save sph format."""
self.assert_sphere(sample_rate, num_channels, duration=1)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save amb format."""
self.assert_amb(dtype, sample_rate, num_channels, duration=1)

def test_amr_nb(self):
"""`sox_io_backend.save` can save amr-nb format."""
self.assert_amr_nb(duration=1)


@skipIfNoExec('sox')
@skipIfNoExtension
Expand Down
12 changes: 10 additions & 2 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ ExternalProject_Add(libmad
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/libmad/configure ${COMMON_ARGS}
)

ExternalProject_Add(amr
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz
URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/amr/configure ${COMMON_ARGS}
)

ExternalProject_Add(libmp3lame
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DOWNLOAD_DIR ${ARCHIVE_DIR}
Expand Down Expand Up @@ -72,11 +80,11 @@ ExternalProject_Add(opusfile

ExternalProject_Add(libsox
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad amr
DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --disable-openmp
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
)
10 changes: 6 additions & 4 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@ def load(
This function can handle all the codecs that underlying libsox can handle,
however it is tested on the following formats;
* WAV
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* 8-bit unsigned integer (WAV only)
* MP3
* FLAC
* OGG/VORBIS
* OPUS
* SPHERE
* AMR-NB
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
Expand Down Expand Up @@ -119,7 +120,7 @@ def save(
Note:
Supported formats are;
* WAV
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
Expand All @@ -130,6 +131,7 @@ def save(
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
Expand Down Expand Up @@ -160,7 +162,7 @@ def save(
filepath = str(filepath)
if compression is None:
ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph']:
if ext in ['wav', 'sph', 'amb', 'amr-nb']:
compression = 0.
elif ext == 'mp3':
compression = -4.5
Expand Down
8 changes: 7 additions & 1 deletion torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,17 @@ void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal,
const double compression) {
const auto tensor = signal->getTensor();
auto tensor = signal->tensor;

validate_input_tensor(tensor);

const auto filetype = get_filetype(file_name);
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(signal->channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio.");
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(signal.get(), filetype);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
Expand Down
22 changes: 16 additions & 6 deletions torchaudio/csrc/sox_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ sox_encoding_t get_encoding(
return SOX_ENCODING_FLAC;
if (filetype == "ogg" || filetype == "vorbis")
return SOX_ENCODING_VORBIS;
if (filetype == "wav") {
if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8)
return SOX_ENCODING_UNSIGNED;
if (dtype == torch::kInt16)
Expand All @@ -236,7 +236,9 @@ sox_encoding_t get_encoding(
}
if (filetype == "sph")
return SOX_ENCODING_SIGN2;
throw std::runtime_error("Unsupported file type.");
if (filetype == "amr-nb")
return SOX_ENCODING_AMR_NB;
throw std::runtime_error("Unsupported file type: " + filetype);
}

unsigned get_precision(
Expand All @@ -248,7 +250,7 @@ unsigned get_precision(
return 24;
if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC;
if (filetype == "wav") {
if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8)
return 8;
if (dtype == torch::kInt16)
Expand All @@ -261,7 +263,13 @@ unsigned get_precision(
}
if (filetype == "sph")
return 32;
throw std::runtime_error("Unsupported file type.");
if (filetype == "amr-nb") {
TORCH_INTERNAL_ASSERT(
dtype == torch::kInt16,
"When saving to AMR-NB format, the input tensor must be int16 type.");
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
}

sox_signalinfo_t get_signalinfo(
Expand All @@ -287,11 +295,13 @@ sox_encodinginfo_t get_encodinginfo(
return compression;
if (filetype == "ogg" || filetype == "vorbis")
return compression;
if (filetype == "wav")
if (filetype == "wav" || filetype == "amb")
return 0.;
if (filetype == "sph")
return 0.;
throw std::runtime_error("Unsupported file type.");
if (filetype == "amr-nb")
return 0.;
throw std::runtime_error("Unsupported file type: " + filetype);
}();

return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
Expand Down

0 comments on commit 4406a6b

Please sign in to comment.