Skip to content

Commit

Permalink
add exportable mel spec (NVIDIA#5512)
Browse files Browse the repository at this point in the history
* add exportable mel spec

Signed-off-by: shane carroll <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: shane carroll <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
2 people authored and Hainan Xu committed Nov 29, 2022
1 parent a91a57d commit 6a7e3d5
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 7 deletions.
27 changes: 20 additions & 7 deletions nemo/collections/asr/modules/audio_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from packaging import version

from nemo.collections.asr.parts.numba.spec_augment import SpecAugmentNumba, spec_augment_launch_heuristics
from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, FilterbankFeaturesTA
from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.classes import Exportable, NeuralModule, typecheck
from nemo.core.neural_types import (
AudioSignal,
LengthsType,
Expand Down Expand Up @@ -92,11 +92,8 @@ def get_features(self, input_signal, length):
pass


class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
class AudioToMelSpectrogramPreprocessor(AudioPreprocessor, Exportable):
"""Featurizer module that converts wavs to mel spectrograms.
We don't use torchaudio's implementation here because the original
implementation is not the same, so for the sake of backwards-compatibility
this will use the old FilterbankFeatures for now.
Args:
sample_rate (int): Sample rate of the input audio data.
Expand Down Expand Up @@ -158,6 +155,7 @@ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
Defaults to 0.0
nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
Defaults to 4000
use_torchaudio: Whether to use the `torchaudio` implementation.
stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
"""
Expand Down Expand Up @@ -222,6 +220,7 @@ def __init__(
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
use_torchaudio: bool = False,
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
Expand All @@ -239,7 +238,12 @@ def __init__(
if window_stride:
n_window_stride = int(window_stride * self._sample_rate)

self.featurizer = FilterbankFeatures(
# Given the long and similar argument list, point to the class and instantiate it by reference
if not use_torchaudio:
featurizer_class = FilterbankFeatures
else:
featurizer_class = FilterbankFeaturesTA
self.featurizer = featurizer_class(
sample_rate=self._sample_rate,
n_window_size=n_window_size,
n_window_stride=n_window_stride,
Expand All @@ -266,6 +270,14 @@ def __init__(
stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
)

def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
batch_size = torch.randint(low=1, high=max_batch, size=[1]).item()
max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item()
signals = torch.rand(size=[batch_size, max_length]) * 2 - 1
lengths = torch.randint(low=min_length, high=max_dim, size=[batch_size])
lengths[0] = max_length
return signals, lengths

def get_features(self, input_signal, length):
return self.featurizer(input_signal, length)

Expand Down Expand Up @@ -699,6 +711,7 @@ class AudioToMelSpectrogramPreprocessorConfig:
rng: Optional[str] = None
nb_augmentation_prob: float = 0.0
nb_max_freq: int = 4000
use_torchaudio: bool = False
stft_exact_pad: bool = False # Deprecated argument, kept for compatibility with older checkpoints.
stft_conv: bool = False # Deprecated argument, kept for compatibility with older checkpoints.

Expand Down
225 changes: 225 additions & 0 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
# This file contains code artifacts adapted from https://github.com/ryanleary/patter
import math
import random
from typing import Optional, Tuple, Union

import librosa
import numpy as np
Expand All @@ -44,6 +45,14 @@
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.utils import logging

try:
import torchaudio

HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
HAVE_TORCHAUDIO = False


CONSTANT = 1e-5


Expand Down Expand Up @@ -99,6 +108,39 @@ def splice_frames(x, frame_splicing):
return torch.cat(seq, dim=1)


@torch.jit.script_if_tracing
def make_seq_mask_like(
lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True
) -> torch.Tensor:
"""
Args:
lengths: Tensor with shape [B] containing the sequence length of each batch element
like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
length in the time dimension of this Tensor.
time_dim: Time dimension of the `shape_tensor` and the resulting mask. Zero-based.
valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.
Returns:
A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
`time_dim == -1', mask will have shape `[3, 1, 5]`.
"""
# Mask with shape [B, T]
mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1))
# [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor
for _ in range(like.dim() - mask.dim()):
mask = mask.unsqueeze(1)
# If needed, transpose time dim
if time_dim != -1 and time_dim != mask.dim() - 1:
mask = mask.transpose(-1, time_dim)
# Maybe invert the padded vs. valid token values
if not valid_ones:
mask = ~mask
return mask


class WaveformFeaturizer(object):
def __init__(self, sample_rate=16000, int_values=False, augmentor=None):
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
Expand Down Expand Up @@ -401,3 +443,186 @@ def forward(self, x, seq_len):
if pad_amt != 0:
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
return x, seq_len


class FilterbankFeaturesTA(nn.Module):
"""
Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction.
See `AudioToMelSpectrogramPreprocessor` for args.
"""

def __init__(
self,
sample_rate: int = 16000,
n_window_size: int = 320,
n_window_stride: int = 160,
normalize: Optional[str] = "per_feature",
nfilt: int = 64,
n_fft: Optional[int] = None,
preemph: float = 0.97,
lowfreq: float = 0,
highfreq: Optional[float] = None,
log: bool = True,
log_zero_guard_type: str = "add",
log_zero_guard_value: Union[float, str] = 2 ** -24,
dither: float = 1e-5,
window: str = "hann",
pad_to: int = 0,
pad_value: float = 0.0,
# Seems like no one uses these options anymore. Don't convolute the code by supporting thm.
use_grads: bool = False, # Deprecated arguments; kept for config compatibility
max_duration: float = 16.7, # Deprecated arguments; kept for config compatibility
frame_splicing: int = 1, # Deprecated arguments; kept for config compatibility
exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
nb_augmentation_prob: float = 0.0, # Deprecated arguments; kept for config compatibility
nb_max_freq: int = 4000, # Deprecated arguments; kept for config compatibility
mag_power: float = 2.0, # Deprecated arguments; kept for config compatibility
rng: Optional[random.Random] = None, # Deprecated arguments; kept for config compatibility
stft_exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
stft_conv: bool = False, # Deprecated arguments; kept for config compatibility
):
super().__init__()
if not HAVE_TORCHAUDIO:
raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}")

# Make sure log zero guard is supported, if given as a string
supported_log_zero_guard_strings = {"eps", "tiny"}
if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings:
raise ValueError(
f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}"
)

# Copied from `AudioPreprocessor` due to the ad-hoc structuring of the Mel Spec extractor class
self.torch_windows = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'ones': torch.ones,
None: torch.ones,
}

# Ensure we can look up the window function
if window not in self.torch_windows:
raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")

self._win_length = n_window_size
self._hop_length = n_window_stride
self._sample_rate = sample_rate
self._normalize_strategy = normalize
self._use_log = log
self._preemphasis_value = preemph
self._log_zero_guard_type = log_zero_guard_type
self._log_zero_guard_value: Union[str, float] = log_zero_guard_value
self._dither_value = dither
self._pad_to = pad_to
self._pad_value = pad_value
self._num_fft = n_fft
self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self._sample_rate,
win_length=self._win_length,
hop_length=self._hop_length,
n_mels=nfilt,
window_fn=self.torch_windows[window],
mel_scale="slaney",
norm="slaney",
n_fft=n_fft,
f_max=highfreq,
f_min=lowfreq,
wkwargs={"periodic": False},
)

@property
def filter_banks(self):
""" Matches the analogous class """
return self._mel_spec_extractor.mel_scale.fb

def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
if isinstance(self._log_zero_guard_value, float):
return self._log_zero_guard_value
return getattr(torch.finfo(dtype), self._log_zero_guard_value)

def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
if self.training and self._dither_value > 0.0:
noise = torch.randn_like(signals) * self._dither_value
signals = signals + noise
return signals

def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
if self._preemphasis_value is not None:
padded = torch.nn.functional.pad(signals, (1, 0))
signals = signals - self._preemphasis_value * padded[:, :-1]
return signals

def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
out_lengths = input_lengths.div(self._hop_length, rounding_mode="floor").add(1).long()
return out_lengths

def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
# Only apply during training; else need to capture dynamic shape for exported models
if not self.training or self._pad_to == 0 or features.shape[-1] % self._pad_to == 0:
return features
pad_length = self._pad_to - (features.shape[-1] % self._pad_to)
return torch.nn.functional.pad(features, pad=(0, pad_length), value=self._pad_value)

def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
if self._use_log:
zero_guard = self._resolve_log_zero_guard_value(features.dtype)
if self._log_zero_guard_type == "add":
features = features + zero_guard
elif self._log_zero_guard_type == "clamp":
features = features.clamp(min=zero_guard)
else:
raise ValueError(f"Unsupported log zero guard type: '{self._log_zero_guard_type}'")
features = features.log()
return features

def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor:
# Complex FFT needs to be done in single precision
with torch.cuda.amp.autocast(enabled=False):
features = self._mel_spec_extractor(waveform=signals)
return features

def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
# For consistency, this function always does a masked fill even if not normalizing.
mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False)
features = features.masked_fill(mask, 0.0)
# Maybe don't normalize
if self._normalize_strategy is None:
return features
# Use the log zero guard for the sqrt zero guard
guard_value = self._resolve_log_zero_guard_value(features.dtype)
if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features":
# 'all_features' reduces over each sample; 'per_feature' reduces over each channel
reduce_dim = 2
if self._normalize_strategy == "all_features":
reduce_dim = [1, 2]
# [B, D, T] -> [B, D, 1] or [B, 1, 1]
means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1))
stds = (
features.sub(means)
.masked_fill(mask, 0.0)
.pow(2.0)
.sum(dim=reduce_dim, keepdim=True) # [B, D, T] -> [B, D, 1] or [B, 1, 1]
.div(lengths.view(-1, 1, 1) - 1) # assume biased estimator
.clamp(min=guard_value) # avoid sqrt(0)
.sqrt()
)
features = (features - means) / (stds + eps)
else:
# Deprecating constant std/mean
raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}")
features = features.masked_fill(mask, 0.0)
return features

def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
feature_lengths = self._compute_output_lengths(input_lengths=length)
signals = self._apply_dithering(signals=input_signal)
signals = self._apply_preemphasis(signals=signals)
features = self._extract_spectrograms(signals=signals)
features = self._apply_log(features=features)
features = self._apply_normalization(features=features, lengths=feature_lengths)
features = self._apply_pad_to(features=features)
return features, feature_lengths

0 comments on commit 6a7e3d5

Please sign in to comment.