Skip to content

Commit

Permalink
Support file-like object in info (#1108)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 27, 2021
1 parent 22e7e87 commit 41c76a1
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 15 deletions.
63 changes: 63 additions & 0 deletions test/torchaudio_unittest/soundfile_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import patch
import warnings
import tarfile

import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
Expand Down Expand Up @@ -125,3 +126,65 @@ class MockSoundFileInfo:
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Check that ``soundfile_backend.info`` accepts file-like objects."""

    # Properties of the audio generated for every test in this class.
    _SAMPLE_RATE = 16000
    _NUM_CHANNELS = 2
    _DURATION = 2  # seconds

    def _write_audio(self, path, subtype):
        """Write random audio to ``path`` and return its number of frames."""
        num_frames = self._SAMPLE_RATE * self._DURATION
        data = torch.randn(num_frames, self._NUM_CHANNELS).numpy()
        soundfile.write(path, data, self._SAMPLE_RATE, subtype=subtype)
        return num_frames

    def _assert_info(self, info, num_frames, bits_per_sample):
        """Compare the queried metadata against the generated audio."""
        assert info.sample_rate == self._SAMPLE_RATE
        assert info.num_frames == num_frames
        assert info.num_channels == self._NUM_CHANNELS
        assert info.bits_per_sample == bits_per_sample

    def _test_fileobj(self, ext, subtype, bits_per_sample):
        """Querying audio via a file object opened in binary mode works"""
        path = self.get_temp_path(f'test.{ext}')
        num_frames = self._write_audio(path, subtype)

        with open(path, 'rb') as fileobj:
            info = soundfile_backend.info(fileobj)
        self._assert_info(info, num_frames, bits_per_sample)

    def test_fileobj_wav(self):
        """Querying WAV audio via file-like object works"""
        self._test_fileobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Querying FLAC audio via file-like object works"""
        self._test_fileobj('flac', 'PCM_16', 16)

    def _test_tarobj(self, ext, subtype, bits_per_sample):
        """Querying audio via a file-like object extracted from a tar archive works"""
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        # ``tarfile.TarFile`` in plain 'w' mode writes an uncompressed archive,
        # so the file is named ``.tar`` rather than ``.tar.gz``.
        archive_path = self.get_temp_path('archive.tar')
        num_frames = self._write_audio(audio_path, subtype)

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # Close the extracted member explicitly instead of leaking it.
            with tarobj.extractfile(audio_file) as fileobj:
                info = soundfile_backend.info(fileobj)
        self._assert_info(info, num_frames, bits_per_sample)

    def test_tarobj_wav(self):
        """Querying archived WAV audio via file-like object works"""
        self._test_tarobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Querying archived FLAC audio via file-like object works"""
        self._test_tarobj('flac', 'PCM_16', 16)
151 changes: 150 additions & 1 deletion test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import io
import itertools
from parameterized import parameterized
import tarfile

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils

from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
get_asset_path,
get_wav_data,
save_wav,
Expand All @@ -18,6 +23,10 @@
)


if _mod_utils.is_module_available("requests"):
import requests


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
Expand Down Expand Up @@ -197,3 +206,143 @@ def test_mp3(self):
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Check that ``sox_io_backend.info`` accepts file-like objects."""

    # Properties of the audio generated for every test in this class.
    _SAMPLE_RATE = 16000
    _NUM_CHANNELS = 2

    # (extension, expected ``bits_per_sample``) pairs shared by all tests.
    _FORMATS = [
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ]

    def _gen_audio(self, ext, duration):
        """Generate ``test.<ext>`` with sox.

        Returns:
            tuple: ``(file name, full path, format hint)``. The hint is
            ``'mp3'`` for mp3 and ``None`` for every other extension.
        """
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        # NOTE(review): only mp3 is queried with an explicit format hint;
        # presumably libsox cannot infer it from the stream -- confirm.
        format_ = ext if ext in ['mp3'] else None
        sox_utils.gen_audio_file(
            audio_path, self._SAMPLE_RATE,
            num_channels=self._NUM_CHANNELS, duration=duration)
        return audio_file, audio_path, format_

    def _assert_sinfo(self, sinfo, ext, bits_per_sample, duration):
        """Compare the queried metadata against the generated audio."""
        assert sinfo.sample_rate == self._SAMPLE_RATE
        assert sinfo.num_channels == self._NUM_CHANNELS
        if ext not in ['mp3', 'vorbis']:  # these container formats do not have length info
            assert sinfo.num_frames == self._SAMPLE_RATE * duration
        assert sinfo.bits_per_sample == bits_per_sample

    @parameterized.expand(_FORMATS)
    def test_fileobj(self, ext, bits_per_sample):
        """Querying audio via file object works"""
        duration = 3
        _, path, format_ = self._gen_audio(ext, duration)

        with open(path, 'rb') as fileobj:
            sinfo = sox_io_backend.info(fileobj, format=format_)

        self._assert_sinfo(sinfo, ext, bits_per_sample, duration)

    def _test_bytesio(self, ext, bits_per_sample, duration):
        """Query audio that has been fully read into an in-memory buffer."""
        _, path, format_ = self._gen_audio(ext, duration)

        with open(path, 'rb') as file_:
            fileobj = io.BytesIO(file_.read())
        sinfo = sox_io_backend.info(fileobj, format=format_)

        self._assert_sinfo(sinfo, ext, bits_per_sample, duration)

    @parameterized.expand(_FORMATS)
    def test_bytesio(self, ext, bits_per_sample):
        """Querying audio via BytesIO object works"""
        self._test_bytesio(ext, bits_per_sample, duration=3)

    @parameterized.expand(_FORMATS)
    def test_bytesio_tiny(self, ext, bits_per_sample):
        """Querying audio via BytesIO object works for small data"""
        self._test_bytesio(ext, bits_per_sample, duration=1 / 1600)

    @parameterized.expand(_FORMATS)
    def test_tarfile(self, ext, bits_per_sample):
        """Querying audio via a file-like object extracted from a tar archive works"""
        duration = 3
        audio_file, audio_path, format_ = self._gen_audio(ext, duration)
        # ``tarfile.TarFile`` in plain 'w' mode writes an uncompressed archive,
        # so the file is named ``.tar`` rather than ``.tar.gz``.
        archive_path = self.get_temp_path('archive.tar')

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # Close the extracted member explicitly instead of leaking it.
            with tarobj.extractfile(audio_file) as fileobj:
                sinfo = sox_io_backend.info(fileobj, format=format_)

        self._assert_sinfo(sinfo, ext, bits_per_sample, duration)


@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Exercise ``sox_io_backend.info`` on the raw body of an HTTP response."""

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_requests(self, ext, bits_per_sample):
        """Metadata can be queried from the raw stream of a ``requests`` response"""
        expected_rate, expected_channels, seconds = 16000, 2, 3
        file_name = f'test.{ext}'
        # Only mp3 is queried with an explicit format hint.
        hint = None if ext not in ['mp3'] else ext

        # Generate the audio in the temp directory; it is then fetched back
        # through the URL that the mixin provides for the same file name.
        sox_utils.gen_audio_file(
            self.get_temp_path(file_name),
            expected_rate,
            num_channels=expected_channels,
            duration=seconds,
        )

        with requests.get(self.get_url(file_name), stream=True) as response:
            sinfo = sox_io_backend.info(response.raw, format=hint)

        assert sinfo.sample_rate == expected_rate
        assert sinfo.num_channels == expected_channels
        # these container formats do not have length info
        has_length_info = ext not in ['mp3', 'vorbis']
        if has_length_info:
            assert sinfo.num_frames == expected_rate * seconds
        assert sinfo.bits_per_sample == bits_per_sample
10 changes: 6 additions & 4 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path): Path to audio file.
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str``
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatiblity.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatibility.
format (str, optional):
Not used. PySoundFile does not accept format hint.
Expand Down
49 changes: 42 additions & 7 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
from .common import AudioMetaData


@torch.jit.unused
def _info(
        filepath: str,
        format: Optional[str] = None,
) -> AudioMetaData:
    """Query audio metadata from either a path or a file-like object.

    Args:
        filepath: Path-like object, or any object exposing a ``read``
            attribute (treated as a file-like object).
        format: Format hint forwarded unchanged to the underlying call.

    Returns:
        AudioMetaData: Sample rate, number of frames, number of channels
        and bits per sample of the audio.
    """
    if not hasattr(filepath, 'read'):
        # Path-like input: normalize with ``os.fspath`` and use the registered op.
        path_info = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
        return AudioMetaData(
            path_info.get_sample_rate(),
            path_info.get_num_frames(),
            path_info.get_num_channels(),
            path_info.get_bits_per_sample(),
        )

    # File-like input: the binding returns a 4-tuple whose order differs from
    # the ``AudioMetaData`` argument order, so unpack and reorder explicitly.
    sample_rate, num_channels, num_frames, bits_per_sample = (
        torchaudio._torchaudio.get_info_fileobj(filepath, format)
    )
    return AudioMetaData(
        sample_rate, num_frames, num_channels, bits_per_sample)


@_mod_utils.requires_module('torchaudio._torchaudio')
def info(
filepath: str,
Expand All @@ -18,9 +38,21 @@ def info(
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects,
but is annotated as ``str`` for TorchScript compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* When the input type is a file-like object, this function cannot
get the correct length (``num_frames``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_frames`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
Expand All @@ -29,11 +61,14 @@ def info(
Returns:
AudioMetaData: Metadata of the given audio.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
return _info(filepath, format)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(),
sinfo.get_bits_per_sample())
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample())


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def(
"get_info_fileobj",
&torchaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
Expand Down
Loading

0 comments on commit 41c76a1

Please sign in to comment.