diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index 9756760743f40..f133426faf33a 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -22,9 +22,9 @@ from packaging import version from nemo.collections.asr.parts.numba.spec_augment import SpecAugmentNumba, spec_augment_launch_heuristics -from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, FilterbankFeaturesTA from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout -from nemo.core.classes import NeuralModule, typecheck +from nemo.core.classes import Exportable, NeuralModule, typecheck from nemo.core.neural_types import ( AudioSignal, LengthsType, @@ -92,11 +92,8 @@ def get_features(self, input_signal, length): pass -class AudioToMelSpectrogramPreprocessor(AudioPreprocessor): +class AudioToMelSpectrogramPreprocessor(AudioPreprocessor, Exportable): """Featurizer module that converts wavs to mel spectrograms. - We don't use torchaudio's implementation here because the original - implementation is not the same, so for the sake of backwards-compatibility - this will use the old FilterbankFeatures for now. Args: sample_rate (int): Sample rate of the input audio data. @@ -158,6 +155,7 @@ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor): Defaults to 0.0 nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation. Defaults to 4000 + use_torchaudio: Whether to use the `torchaudio` implementation. stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints. stft_conv: Deprecated argument, kept for compatibility with older checkpoints. """ @@ -222,6 +220,7 @@ def __init__( rng=None, nb_augmentation_prob=0.0, nb_max_freq=4000, + use_torchaudio: bool = False, stft_exact_pad=False, # Deprecated arguments; kept for config compatibility stft_conv=False, # Deprecated arguments; kept for config compatibility ): @@ -239,7 +238,12 @@ def __init__( if window_stride: n_window_stride = int(window_stride * self._sample_rate) - self.featurizer = FilterbankFeatures( + # Given the long and similar argument list, point to the class and instantiate it by reference + if not use_torchaudio: + featurizer_class = FilterbankFeatures + else: + featurizer_class = FilterbankFeaturesTA + self.featurizer = featurizer_class( sample_rate=self._sample_rate, n_window_size=n_window_size, n_window_stride=n_window_stride, @@ -266,6 +270,14 @@ def __init__( stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility ) + def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200): + batch_size = torch.randint(low=1, high=max_batch, size=[1]).item() + max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item() + signals = torch.rand(size=[batch_size, max_length]) * 2 - 1 + lengths = torch.randint(low=min_length, high=max_dim, size=[batch_size]) + lengths[0] = max_length + return signals, lengths + def get_features(self, input_signal, length): return self.featurizer(input_signal, length) @@ -699,6 +711,7 @@ class AudioToMelSpectrogramPreprocessorConfig: rng: Optional[str] = None nb_augmentation_prob: float = 0.0 nb_max_freq: int = 4000 + use_torchaudio: bool = False stft_exact_pad: bool = False # Deprecated argument, kept for compatibility with older checkpoints. stft_conv: bool = False # Deprecated argument, kept for compatibility with older checkpoints. diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index a42255f1a999a..b433144677e8a 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -34,6 +34,7 @@ # This file contains code artifacts adapted from https://github.com/ryanleary/patter import math import random +from typing import Optional, Tuple, Union import librosa import numpy as np @@ -44,6 +45,14 @@ from nemo.collections.asr.parts.preprocessing.segment import AudioSegment from nemo.utils import logging +try: + import torchaudio + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False + + CONSTANT = 1e-5 @@ -99,6 +108,39 @@ def splice_frames(x, frame_splicing): return torch.cat(seq, dim=1) +@torch.jit.script_if_tracing +def make_seq_mask_like( + lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True +) -> torch.Tensor: + """ + + Args: + lengths: Tensor with shape [B] containing the sequence length of each batch element + like: The mask will contain the same number of dimensions as this Tensor, and will have the same max + length in the time dimension of this Tensor. + time_dim: Time dimension of the `shape_tensor` and the resulting mask. Zero-based. + valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert. + + Returns: + A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else + vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match + the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and + `time_dim == -1', mask will have shape `[3, 1, 5]`. + """ + # Mask with shape [B, T] + mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1)) + # [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor + for _ in range(like.dim() - mask.dim()): + mask = mask.unsqueeze(1) + # If needed, transpose time dim + if time_dim != -1 and time_dim != mask.dim() - 1: + mask = mask.transpose(-1, time_dim) + # Maybe invert the padded vs. valid token values + if not valid_ones: + mask = ~mask + return mask + + class WaveformFeaturizer(object): def __init__(self, sample_rate=16000, int_values=False, augmentor=None): self.augmentor = augmentor if augmentor is not None else AudioAugmentor() @@ -401,3 +443,186 @@ def forward(self, x, seq_len): if pad_amt != 0: x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) return x, seq_len + + +class FilterbankFeaturesTA(nn.Module): + """ + Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction. + + See `AudioToMelSpectrogramPreprocessor` for args. + + """ + + def __init__( + self, + sample_rate: int = 16000, + n_window_size: int = 320, + n_window_stride: int = 160, + normalize: Optional[str] = "per_feature", + nfilt: int = 64, + n_fft: Optional[int] = None, + preemph: float = 0.97, + lowfreq: float = 0, + highfreq: Optional[float] = None, + log: bool = True, + log_zero_guard_type: str = "add", + log_zero_guard_value: Union[float, str] = 2 ** -24, + dither: float = 1e-5, + window: str = "hann", + pad_to: int = 0, + pad_value: float = 0.0, + # Seems like no one uses these options anymore. Don't convolute the code by supporting thm. + use_grads: bool = False, # Deprecated arguments; kept for config compatibility + max_duration: float = 16.7, # Deprecated arguments; kept for config compatibility + frame_splicing: int = 1, # Deprecated arguments; kept for config compatibility + exact_pad: bool = False, # Deprecated arguments; kept for config compatibility + nb_augmentation_prob: float = 0.0, # Deprecated arguments; kept for config compatibility + nb_max_freq: int = 4000, # Deprecated arguments; kept for config compatibility + mag_power: float = 2.0, # Deprecated arguments; kept for config compatibility + rng: Optional[random.Random] = None, # Deprecated arguments; kept for config compatibility + stft_exact_pad: bool = False, # Deprecated arguments; kept for config compatibility + stft_conv: bool = False, # Deprecated arguments; kept for config compatibility + ): + super().__init__() + if not HAVE_TORCHAUDIO: + raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}") + + # Make sure log zero guard is supported, if given as a string + supported_log_zero_guard_strings = {"eps", "tiny"} + if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings: + raise ValueError( + f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}" + ) + + # Copied from `AudioPreprocessor` due to the ad-hoc structuring of the Mel Spec extractor class + self.torch_windows = { + 'hann': torch.hann_window, + 'hamming': torch.hamming_window, + 'blackman': torch.blackman_window, + 'bartlett': torch.bartlett_window, + 'ones': torch.ones, + None: torch.ones, + } + + # Ensure we can look up the window function + if window not in self.torch_windows: + raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}") + + self._win_length = n_window_size + self._hop_length = n_window_stride + self._sample_rate = sample_rate + self._normalize_strategy = normalize + self._use_log = log + self._preemphasis_value = preemph + self._log_zero_guard_type = log_zero_guard_type + self._log_zero_guard_value: Union[str, float] = log_zero_guard_value + self._dither_value = dither + self._pad_to = pad_to + self._pad_value = pad_value + self._num_fft = n_fft + self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=self._sample_rate, + win_length=self._win_length, + hop_length=self._hop_length, + n_mels=nfilt, + window_fn=self.torch_windows[window], + mel_scale="slaney", + norm="slaney", + n_fft=n_fft, + f_max=highfreq, + f_min=lowfreq, + wkwargs={"periodic": False}, + ) + + @property + def filter_banks(self): + """ Matches the analogous class """ + return self._mel_spec_extractor.mel_scale.fb + + def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float: + if isinstance(self._log_zero_guard_value, float): + return self._log_zero_guard_value + return getattr(torch.finfo(dtype), self._log_zero_guard_value) + + def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor: + if self.training and self._dither_value > 0.0: + noise = torch.randn_like(signals) * self._dither_value + signals = signals + noise + return signals + + def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor: + if self._preemphasis_value is not None: + padded = torch.nn.functional.pad(signals, (1, 0)) + signals = signals - self._preemphasis_value * padded[:, :-1] + return signals + + def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor: + out_lengths = input_lengths.div(self._hop_length, rounding_mode="floor").add(1).long() + return out_lengths + + def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor: + # Only apply during training; else need to capture dynamic shape for exported models + if not self.training or self._pad_to == 0 or features.shape[-1] % self._pad_to == 0: + return features + pad_length = self._pad_to - (features.shape[-1] % self._pad_to) + return torch.nn.functional.pad(features, pad=(0, pad_length), value=self._pad_value) + + def _apply_log(self, features: torch.Tensor) -> torch.Tensor: + if self._use_log: + zero_guard = self._resolve_log_zero_guard_value(features.dtype) + if self._log_zero_guard_type == "add": + features = features + zero_guard + elif self._log_zero_guard_type == "clamp": + features = features.clamp(min=zero_guard) + else: + raise ValueError(f"Unsupported log zero guard type: '{self._log_zero_guard_type}'") + features = features.log() + return features + + def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor: + # Complex FFT needs to be done in single precision + with torch.cuda.amp.autocast(enabled=False): + features = self._mel_spec_extractor(waveform=signals) + return features + + def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + # For consistency, this function always does a masked fill even if not normalizing. + mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False) + features = features.masked_fill(mask, 0.0) + # Maybe don't normalize + if self._normalize_strategy is None: + return features + # Use the log zero guard for the sqrt zero guard + guard_value = self._resolve_log_zero_guard_value(features.dtype) + if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features": + # 'all_features' reduces over each sample; 'per_feature' reduces over each channel + reduce_dim = 2 + if self._normalize_strategy == "all_features": + reduce_dim = [1, 2] + # [B, D, T] -> [B, D, 1] or [B, 1, 1] + means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1)) + stds = ( + features.sub(means) + .masked_fill(mask, 0.0) + .pow(2.0) + .sum(dim=reduce_dim, keepdim=True) # [B, D, T] -> [B, D, 1] or [B, 1, 1] + .div(lengths.view(-1, 1, 1) - 1) # assume biased estimator + .clamp(min=guard_value) # avoid sqrt(0) + .sqrt() + ) + features = (features - means) / (stds + eps) + else: + # Deprecating constant std/mean + raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}") + features = features.masked_fill(mask, 0.0) + return features + + def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + feature_lengths = self._compute_output_lengths(input_lengths=length) + signals = self._apply_dithering(signals=input_signal) + signals = self._apply_preemphasis(signals=signals) + features = self._extract_spectrograms(signals=signals) + features = self._apply_log(features=features) + features = self._apply_normalization(features=features, lengths=feature_lengths) + features = self._apply_pad_to(features=features) + return features, feature_lengths