Skip to content

Commit

Permalink
Refactor backend switching
Browse files Browse the repository at this point in the history
1. Do not rely on global variables for backend switch
   So that load/save/info/load_wav functions will be torchscript-able
2. Add no_backend module to for the case there is no backend module available
   [bonus] This allows the whole codebase importable on systems that do not have torchaudio C++ extension nor soundfile.
  • Loading branch information
mthrok committed Jun 11, 2020
1 parent b822580 commit e2643cd
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 161 deletions.
34 changes: 34 additions & 0 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

import torchaudio
from torchaudio._internal.module_utils import is_module_available


class TestBackendSwitch(unittest.TestCase):
def test_no_backend(self):
torchaudio.set_audio_backend(None)
assert torchaudio.load == torchaudio.backend.no_backend.load
assert torchaudio.load_wav == torchaudio.backend.no_backend.load_wav
assert torchaudio.save == torchaudio.backend.no_backend.save
assert torchaudio.info == torchaudio.backend.no_backend.info
assert torchaudio.get_audio_backend() is None

@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
def test_sox(self):
torchaudio.set_audio_backend('sox')
assert torchaudio.load == torchaudio.backend.sox_backend.load
assert torchaudio.load_wav == torchaudio.backend.sox_backend.load_wav
assert torchaudio.save == torchaudio.backend.sox_backend.save
assert torchaudio.info == torchaudio.backend.sox_backend.info
assert torchaudio.get_audio_backend() == 'sox'

@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
def test_soundfile(self):
torchaudio.set_audio_backend('soundfile')
assert torchaudio.load == torchaudio.backend.soundfile_backend.load
assert torchaudio.load_wav == torchaudio.backend.soundfile_backend.load_wav
assert torchaudio.save == torchaudio.backend.soundfile_backend.save
assert torchaudio.info == torchaudio.backend.soundfile_backend.info
assert torchaudio.get_audio_backend() == 'soundfile'
118 changes: 0 additions & 118 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
Expand All @@ -11,7 +7,6 @@
transforms
)
from torchaudio.backend import (
_get_audio_backend_module,
list_audio_backends,
get_audio_backend,
set_audio_backend,
Expand Down Expand Up @@ -57,116 +52,3 @@ def shutdown_sox():
This function is deprecated. See ``torchaudio.sox_effects.shutdown_sox_effects``
"""
_shutdown_sox_effects()


def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: Optional[SignalInfo] = None,
encodinginfo: Optional[EncodingInfo] = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""Loads an audio file from disk into a tensor
Args:
filepath (str or Path): Path to audio file
out (Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
normalization (bool, float, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `float`, then output is divided by that number
If `Callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result. (Default: ``True``)
channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
num_frames (int, optional): Number of frames to load. 0 to load everything after the offset.
(Default: ``0``)
offset (int, optional): Number of frames from the start of the file to begin data loading.
(Default: ``0``)
signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``)
encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``)
filetype (str, optional): A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``)
Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
Example
>>> data, sample_rate = torchaudio.load('foo.mp3')
>>> print(data.size())
torch.Size([2, 278756])
>>> print(sample_rate)
44100
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
>>> print(data_vol_normalized.abs().max())
1.
"""
return _get_audio_backend_module().load(
filepath,
out=out,
normalization=normalization,
channels_first=channels_first,
num_frames=num_frames,
offset=offset,
signalinfo=signalinfo,
encodinginfo=encodinginfo,
filetype=filetype,
)


def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting
the input right by 16 bits.
Args:
filepath (str or Path): Path to audio file
Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
"""
kwargs['normalization'] = 1 << 16
return load(filepath, **kwargs)


def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""Convenience function for `save_encinfo`.
Args:
filepath (str): Path to audio file
src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate (int): An integer which is the sample rate of the
audio (as listed in the metadata of the file)
precision (int, optional): Bit precision (Default: ``16``)
channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``)
"""

return _get_audio_backend_module().save(
filepath, src, sample_rate, precision=precision, channels_first=channels_first
)


def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""Gets metadata from an audio file without loading the signal.
Args:
filepath (str): Path to audio file
Returns:
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info
Example
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""
return _get_audio_backend_module().info(filepath)
1 change: 0 additions & 1 deletion torchaudio/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from . import utils
from .utils import (
_get_audio_backend_module,
list_audio_backends,
get_audio_backend,
set_audio_backend,
Expand Down
114 changes: 113 additions & 1 deletion torchaudio/backend/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, Optional


class SignalInfo:
Expand Down Expand Up @@ -29,3 +29,115 @@ def __init__(self,
self.reverse_nibbles = reverse_nibbles
self.reverse_bits = reverse_bits
self.opposite_endian = opposite_endian


_LOAD_DOCSTRING = r"""Loads an audio file from disk into a tensor
Args:
filepath: Path to audio file
out: An optional output tensor to use instead of creating one. (Default: ``None``)
normalization: Optional normalization.
If boolean `True`, then output is divided by `1 << 31`.
Assuming the input is signed 32-bit audio, this normalizes to `[-1, 1]`.
If `float`, then output is divided by that number.
If `Callable`, then the output is passed as a paramete to the given function,
then the output is divided by the result. (Default: ``True``)
channels_first: Set channels first or length first in result. (Default: ``True``)
num_frames: Number of frames to load. 0 to load everything after the offset.
(Default: ``0``)
offset: Number of frames from the start of the file to begin data loading.
(Default: ``0``)
signalinfo: A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``)
encodinginfo: A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``)
filetype: A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``)
Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where
L is the number of audio frames and
C is the number of channels.
An integer which is the sample rate of the audio (as listed in the metadata of the file)
Example
>>> data, sample_rate = torchaudio.load('foo.mp3')
>>> print(data.size())
torch.Size([2, 278756])
>>> print(sample_rate)
44100
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
>>> print(data_vol_normalized.abs().max())
1.
"""


_LOAD_WAV_DOCSTRING = r""" Loads a wave file.
It assumes that the wav file uses 16 bit per sample that needs normalization by
shifting the input right by 16 bits.
Args:
filepath: Path to audio file
Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
"""

_SAVE_DOCSTRING = r"""Saves a Tensor on file as an audio file
Args:
filepath: Path to audio file
src: An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate: An integer which is the sample rate of the
audio (as listed in the metadata of the file)
precision Bit precision (Default: ``16``)
channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``)
"""


_INFO_DOCSTRING = r"""Gets metadata from an audio file without loading the signal.
Args:
filepath: Path to audio file
Returns:
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info
Example
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""


def _impl_load(func):
setattr(func, '__doc__', _LOAD_DOCSTRING)
return func


def _impl_load_wav(func):
setattr(func, '__doc__', _LOAD_WAV_DOCSTRING)
return func


def _impl_save(func):
setattr(func, '__doc__', _SAVE_DOCSTRING)
return func


def _impl_info(func):
setattr(func, '__doc__', _INFO_DOCSTRING)
return func
35 changes: 35 additions & 0 deletions torchaudio/backend/no_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

from torch import Tensor

from . import common
from .common import SignalInfo, EncodingInfo


@common._impl_load
def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: Optional[SignalInfo] = None,
encodinginfo: Optional[EncodingInfo] = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
raise RuntimeError('No audio I/O backend is available.')


@common._impl_load_wav
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
raise RuntimeError('No audio I/O backend is available.')


@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
raise RuntimeError('No audio I/O backend is available.')


@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
raise RuntimeError('No audio I/O backend is available.')
11 changes: 11 additions & 0 deletions torchaudio/backend/soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from . import common
from .common import SignalInfo, EncodingInfo

if _mod_utils.is_module_available('soundfile'):
Expand All @@ -24,6 +25,7 @@


@_mod_utils.requires_module('soundfile')
@common._impl_load
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
Expand Down Expand Up @@ -71,6 +73,14 @@ def load(filepath: str,


@_mod_utils.requires_module('soundfile')
@common._impl_load_wav
def load_wav(filepath, **kwargs):
# kwargs['normalization'] = 1 << 16
return load(filepath, **kwargs)


@_mod_utils.requires_module('soundfile')
@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""

Expand Down Expand Up @@ -104,6 +114,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan


@_mod_utils.requires_module('soundfile')
@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""

Expand Down
Loading

0 comments on commit e2643cd

Please sign in to comment.