-
Notifications
You must be signed in to change notification settings - Fork 664
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This is part of the PRs to add the new "sox_io" backend (#726) and depends on #718 and #728. This PR adds the `load` function to the "sox_io" backend, which is tested on the following audio formats: - `wav` - `mp3` - `flac` - `ogg/vorbis` * By default, the "sox_io" backend returns a Tensor with `float32` dtype and the shape `[channel, time]`. The samples are normalized to fit in the range `[-1.0, 1.0]`. Unlike the existing "sox" backend, the new `load` function can handle WAV files natively: when the input is a WAV file of integer type (such as 32-bit signed integer, 16-bit signed integer, or 8-bit unsigned integer), passing `normalize=False` makes this function return an integer Tensor whose samples are expressed within the whole range of the corresponding dtype, that is, an `int32` Tensor for `32-bit PCM`, `int16` for `16-bit PCM`, and `uint8` for `8-bit PCM`. This behavior follows [scipy.io.wavfile.read](https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html). The `normalize` parameter has no effect for other formats; for those, the load function always returns normalized values as a `float32` Tensor. __* Note__ The current binary distribution of torchaudio does not contain the `ogg/vorbis` and `opus` codecs. To handle these files, one needs to build torchaudio from source with the proper codecs available on the system. __Note 2__ As of this PR, `scipy` becomes a required module for running the tests.
- Loading branch information
Showing
11 changed files
with
772 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,90 @@ | ||
from typing import Optional, Tuple

import scipy.io.wavfile
import torch
|
||
|
||
def get_test_name(func, _, params):
    """Build a test-case name from a function and its parameter set.

    The result is ``func.__name__`` followed by the string form of every
    entry in ``params.args``, all joined with underscores. The second
    positional argument is accepted but not used.
    """
    arg_labels = map(str, params.args)
    suffix = "_".join(arg_labels)
    return f'{func.__name__}_{suffix}'
|
||
|
||
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
    """Scale integer PCM samples to float32 values.

    int32, int16 and uint8 inputs are converted to float32 (uint8 is first
    re-centred by subtracting 128). Positive samples are then divided by the
    largest positive value of the source type and negative samples by the
    magnitude of its most negative value. Tensors of any other dtype,
    including float32, are returned unchanged.
    """
    # dtype -> (offset to subtract, divisor for positives, divisor for negatives)
    pcm_params = {
        torch.int32: (None, 2147483647., 2147483648.),
        torch.int16: (None, 32767., 32768.),
        torch.uint8: (128, 127., 128.),
    }
    if tensor.dtype not in pcm_params:
        return tensor

    offset, positive_peak, negative_peak = pcm_params[tensor.dtype]
    samples = tensor.to(torch.float32)
    if offset is not None:
        samples = samples - offset
    # The two halves use different divisors because the integer ranges are
    # asymmetric (e.g. -32768..32767).
    samples[samples > 0] /= positive_peak
    samples[samples < 0] /= negative_peak
    return samples
|
||
|
||
def get_wav_data(
        dtype: str,
        num_channels: int,
        *,
        num_frames: Optional[int] = None,
        normalize: bool = True,
        channels_first: bool = True,
):
    """Generate linear signal of the given dtype and num_channels

    Data range is
        [-1.0, 1.0] for float32,
        [-2147483648, 2147483647] for int32
        [-32768, 32767] for int16
        [0, 255] for uint8

    num_frames allow to change the linear interpolation parameter.
    Default values are 256 for uint8, else 1 << 16.
    1 << 16 as default is so that int16 value range is completely covered.

    Args:
        dtype: Name of the sample type; one of 'uint8', 'int16', 'int32'
            or 'float32'.
        num_channels: Number of channels; every channel holds the same ramp.
        num_frames: Number of samples per channel.
        normalize: When true, the generated data is passed through
            ``normalize_wav`` before being returned.
        channels_first: When true the result is laid out as
            [channel, time], otherwise as [time, channel].

    Returns:
        The generated tensor.

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    value_ranges = {
        'uint8': (0, 255),
        'float32': (-1., 1.),
        'int32': (-2147483648, 2147483647),
        'int16': (-32768, 32767),
    }
    # Reject unknown names up front; otherwise a dtype that exists on torch
    # but has no range defined here would fail later with an unbound local.
    if dtype not in value_ranges:
        raise ValueError(
            f'Unsupported dtype {dtype!r}; expected one of {sorted(value_ranges)}')
    dtype_ = getattr(torch, dtype)

    if num_frames is None:
        num_frames = 256 if dtype == 'uint8' else 1 << 16

    low, high = value_ranges[dtype]
    base = torch.linspace(low, high, num_frames, dtype=dtype_)
    data = base.repeat([num_channels, 1])
    if not channels_first:
        data = data.transpose(1, 0)
    if normalize:
        data = normalize_wav(data)
    return data
|
||
|
||
def load_wav(path: str, normalize=True, channels_first=True) -> Tuple[torch.Tensor, int]:
    """Load wav file without torchaudio

    The file is read with ``scipy.io.wavfile.read``.

    Args:
        path: Path of the wav file to read.
        normalize: When true, the samples are passed through
            ``normalize_wav`` before being returned.
        channels_first: When true the tensor is returned as
            [channel, time], otherwise as [time, channel].

    Returns:
        A ``(data, sample_rate)`` pair: the samples as a 2D tensor and the
        sample rate reported by scipy.
    """
    sample_rate, samples = scipy.io.wavfile.read(path)
    # NOTE(review): the copy presumably gives torch a writable array that is
    # independent of the buffer scipy returned - confirm.
    data = torch.from_numpy(samples.copy())
    if data.ndim == 1:
        # A 1D array has no channel axis; add one so the result is
        # always 2D, laid out as [time, channel] at this point.
        data = data.unsqueeze(1)
    if normalize:
        data = normalize_wav(data)
    if channels_first:
        data = data.transpose(1, 0)
    return data, sample_rate
|
||
|
||
def save_wav(path, data, sample_rate, channels_first=True):
    """Save wav file without torchaudio

    The tensor is converted to a NumPy array and written with
    ``scipy.io.wavfile.write``. When ``channels_first`` is true the two axes
    of ``data`` are swapped before writing; otherwise it is written as given.
    """
    frames = data.transpose(1, 0) if channels_first else data
    scipy.io.wavfile.write(path, sample_rate, frames.numpy())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.