Skip to content

Commit

Permalink
Add bits_per_sample to info (#1177)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Jan 25, 2021
1 parent 2703175 commit 99ed718
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 20 deletions.
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/soundfile_backend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def skipIfFormatNotSupported(fmt):
import soundfile

fmts = soundfile.available_formats()
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by sondfile')
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile')
return skipIf(True, '"soundfile" not available.')


Expand Down
46 changes: 39 additions & 7 deletions test/torchaudio_unittest/soundfile_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from unittest.mock import patch
import warnings

import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
Expand All @@ -18,10 +21,11 @@
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
[("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
def test_wav(self, dtype_and_bit_depth, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
dtype, bits_per_sample = dtype_and_bit_depth
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
Expand All @@ -32,12 +36,14 @@ def test_wav(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [4, 8, 16, 32],
[("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
def test_wav_multiple_channels(self, dtype_and_bit_depth, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
dtype, bits_per_sample = dtype_and_bit_depth
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
Expand All @@ -48,6 +54,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("FLAC")
Expand All @@ -63,6 +70,7 @@ def test_flac(self, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == 16

@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("OGG")
Expand All @@ -78,18 +86,42 @@ def test_ogg(self, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0

@parameterize([8000, 16000], [1, 2])
@parameterize([8000, 16000], [1, 2], [('PCM_24', 24), ('PCM_32', 32)])
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels):
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`soundfile_backend.info` can check sph file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.nist")
soundfile.write(path, data, sample_rate)
subtype, bits_per_sample = subtype_and_bit_depth
soundfile.write(path, data, sample_rate, subtype=subtype)

info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

def test_unknown_subtype_warning(self):
"""soundfile_backend.info issues a warning when the subtype is unknown
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated.
"""
def _mock_info_func(_):
class MockSoundFileInfo:
samplerate = 8000
frames = 356
channels = 2
subtype = 'UNSEEN_SUBTYPE'
return MockSoundFileInfo()

with patch("soundfile.info", _mock_info_func):
with warnings.catch_warnings(record=True) as w:
info = soundfile_backend.info("foo")
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0
21 changes: 17 additions & 4 deletions test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand All @@ -52,6 +53,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -71,6 +73,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
# mp3 does not preserve the number of samples
# assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -89,6 +92,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 24 # FLAC standard

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -107,20 +111,23 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[16, 32],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
def test_sphere(self, sample_rate, num_channels, bits_per_sample):
"""`sox_io_backend.info` can check sph file correctly"""
duration = 1
path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration)
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand All @@ -131,13 +138,15 @@ def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
bits_per_sample = sox_utils.get_bit_depth(dtype)
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
bit_depth=bits_per_sample, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
Expand All @@ -146,11 +155,13 @@ def test_amr_nb(self):
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16,
duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0


@skipIfNoExtension
Expand All @@ -167,6 +178,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats


@skipIfNoExtension
Expand All @@ -184,3 +196,4 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
49 changes: 48 additions & 1 deletion torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,45 @@
import soundfile


# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristical and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired from
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
_SUBTYPE_TO_BITS_PER_SAMPLE = {
'PCM_S8': 8, # Signed 8 bit data
'PCM_16': 16, # Signed 16 bit data
'PCM_24': 24, # Signed 24 bit data
'PCM_32': 32, # Signed 32 bit data
'PCM_U8': 8, # Unsigned 8 bit data (WAV and RAW only)
'FLOAT': 32, # 32 bit float data
'DOUBLE': 64, # 64 bit float data
'ULAW': 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'ALAW': 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'IMA_ADPCM': 0, # IMA ADPCM.
'MS_ADPCM': 0, # Microsoft ADPCM.
'GSM610': 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
'VOX_ADPCM': 0, # OKI / Dialogix ADPCM
'G721_32': 0, # 32kbs G721 ADPCM encoding.
'G723_24': 0, # 24kbs G723 ADPCM encoding.
'G723_40': 0, # 40kbs G723 ADPCM encoding.
'DWVW_12': 12, # 12 bit Delta Width Variable Word encoding.
'DWVW_16': 16, # 16 bit Delta Width Variable Word encoding.
'DWVW_24': 24, # 24 bit Delta Width Variable Word encoding.
'DWVW_N': 0, # N bit Delta Width Variable Word encoding.
'DPCM_8': 8, # 8 bit differential PCM (XI only)
'DPCM_16': 16, # 16 bit differential PCM (XI only)
'VORBIS': 0, # Xiph Vorbis encoding. (lossy)
'ALAC_16': 16, # Apple Lossless Audio Codec (16 bit).
'ALAC_20': 20, # Apple Lossless Audio Codec (20 bit).
'ALAC_24': 24, # Apple Lossless Audio Codec (24 bit).
'ALAC_32': 32, # Apple Lossless Audio Codec (32 bit).
}


@_mod_utils.requires_module("soundfile")
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Expand All @@ -27,7 +66,15 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
AudioMetaData: meta data of the given audio.
"""
sinfo = soundfile.info(filepath)
return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels)
if sinfo.subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
warnings.warn(
f"The {sinfo.subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
bits_per_sample = _SUBTYPE_TO_BITS_PER_SAMPLE.get(sinfo.subtype, 0)
return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels, bits_per_sample=bits_per_sample)


_SUBTYPE2DTYPE = {
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/backend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ class AudioMetaData:
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
"""
def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
def __init__(self, sample_rate: int, num_frames: int, num_channels: int, bits_per_sample: int):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample


@_mod_utils.deprecated('Please migrate to `AudioMetaData`.', '0.9.0')
Expand Down
3 changes: 2 additions & 1 deletion torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def info(
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(),
sinfo.get_bits_per_sample())


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
13 changes: 10 additions & 3 deletions torchaudio/csrc/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ namespace sox_io {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_)
const int64_t num_frames_,
const int64_t bits_per_sample_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_){};
num_frames(num_frames_),
bits_per_sample(bits_per_sample_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
Expand All @@ -30,6 +32,10 @@ int64_t SignalInfo::getNumFrames() const {
return num_frames;
}

int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample;
}

c10::intrusive_ptr<SignalInfo> get_info(
const std::string& path,
c10::optional<std::string>& format) {
Expand All @@ -46,7 +52,8 @@ c10::intrusive_ptr<SignalInfo> get_info(
return c10::make_intrusive<SignalInfo>(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample));
}

namespace {
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/csrc/sox/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
int64_t bits_per_sample;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_);
const int64_t num_frames_,
const int64_t bits_per_sample_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
int64_t getBitsPerSample() const;
};

c10::intrusive_ptr<SignalInfo> get_info(
Expand Down
3 changes: 2 additions & 1 deletion torchaudio/csrc/sox/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def("get_bits_per_sample", &torchaudio::sox_io::SignalInfo::getBitsPerSample);

m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def(
Expand Down

0 comments on commit 99ed718

Please sign in to comment.