diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py new file mode 100644 index 0000000000..0746d0036a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py @@ -0,0 +1,287 @@ +import math +import random +from typing import Dict, Optional, Tuple + +import numpy as np +import torch + + +class SpecAugment(torch.nn.Module): + """ + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 2, + features_mask_size: int = 27, + num_frame_masks: int = 10, + frames_mask_size: int = 100, + max_frames_mask_fraction: float = 0.15, + p=0.9, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. + :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks > 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features`` -- there may be either + less or more than the batch size. + The second dimension encoder three kinds of information: + the sequence index of the corresponding feature matrix in `features`, + the start frame index, and the number of frames for each segment. + :return: an augmented tensor of shape ``(B, T, F)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " + "single-channel feature matrices." + ) + features = features.clone() + # 1 (True) represents masked area; + # 0 (False) represents original un-masked area. + time_masked_area = torch.zeros_like(features) + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + ( + features[sequence_idx], + time_masked_area[sequence_idx], + ) = self._forward_single(features[sequence_idx]) + + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + end_frame = start_frame + num_frames + ( + features[sequence_idx, start_frame:end_frame], + time_masked_area[sequence_idx, start_frame:end_frame], + ) = self._forward_single( + features[sequence_idx, start_frame:end_frame], + warp=True, + mask=False, + ) + # ... and then time-mask the full feature matrices. Note that in this mode, + # it might happen that masks are applied to different sequences/examples + # than the time warping. + for sequence_idx in range(features.size(0)): + ( + features[sequence_idx], + time_masked_area[sequence_idx], + ) = self._forward_single( + features[sequence_idx], warp=False, mask=True + ) + + return features, time_masked_area + + def _forward_single( + self, features: torch.Tensor, warp: bool = True, mask: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply SpecAugment to a single feature matrix of shape (T, F). + """ + time_masked_area = torch.zeros_like(features) + if random.random() > self.p: + # Randomly choose whether this transform is applied + # No augmentation, no masked area. + return features, time_masked_area + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp(features, factor=self.time_warp_factor) + if mask: + mean = features.mean() + # Frequency masking + features, _ = mask_along_axis_optimized( + features, + mask_size=self.features_mask_size, + mask_times=self.num_feature_masks, + mask_value=mean, + axis=2, + ) + # Time masking + max_tot_mask_frames = self.max_frames_mask_fraction * features.size( + 0 + ) + num_frame_masks = min( + self.num_frame_masks, + math.ceil(max_tot_mask_frames / self.frames_mask_size), + ) + max_mask_frames = min( + self.frames_mask_size, max_tot_mask_frames // num_frame_masks + ) + features, time_masked_area = mask_along_axis_optimized( + features, + mask_size=max_mask_frames, + mask_times=num_frame_masks, + mask_value=mean, + axis=1, + ) + + return features, time_masked_area + + def state_dict(self) -> Dict: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get( + "num_frame_masks", self.num_frame_masks + ) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def mask_along_axis_optimized( + features: torch.Tensor, + mask_size: int, + mask_times: int, + mask_value: float, + axis: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply Frequency and Time masking along axis. + Frequency and Time masking as described in the SpecAugment paper. + + :param features: input tensor of shape ``(T, F)`` + :mask_size: the width size for masking. + :mask_times: the number of masking regions. + :mask_value: Value to assign to the masked regions. + :axis: Axis to apply masking on (1 -> time, 2 -> frequency) + """ + if axis not in [1, 2]: + raise ValueError("Only Frequency and Time masking are supported!") + + # 1 (True) represents masked area; + # 0 (False) represents original un-masked area. + masked_area = torch.zeros_like(features) + features = features.unsqueeze(0) + masked_area = masked_area.unsqueeze(0) + features = features.reshape([-1] + list(features.size()[-2:])) + + values = torch.randint(int(0), int(mask_size), (1, mask_times)) + min_values = torch.rand(1, mask_times) * (features.size(axis) - values) + mask_starts = (min_values.long()).squeeze() + mask_ends = (min_values.long() + values.long()).squeeze() + + if axis == 1: + if mask_times == 1: + features[:, mask_starts:mask_ends] = mask_value + return features.squeeze(0), masked_area + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, mask_start:mask_end] = mask_value + masked_area[:, mask_start:mask_end] = 1 + else: + if mask_times == 1: + features[:, :, mask_starts:mask_ends] = mask_value + masked_area[:, :, mask_starts:mask_ends] = 1 + return features.squeeze(0), masked_area + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, :, mask_start:mask_end] = mask_value + masked_area[:, :, mask_start:mask_end] = 1 + + features = features.squeeze(0) + masked_area = masked_area.squeeze(0) + return features, masked_area + + +def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + Time warping as described in the SpecAugment paper. + Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. + :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = np.random.randint(factor + 1, t - factor) + warped = np.random.randint(center - factor, center + factor + 1) + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 66bb33e8d0..305049d692 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ from icefall.utils import add_sos -from quantization.prediction import JointCodebookLoss +from multi_quantization.prediction import JointCodebookLoss class Transducer(nn.Module): @@ -41,6 +41,8 @@ def __init__( joiner_dim: int, vocab_size: int, num_codebooks: int = 0, + masked_scale: float = 1.0, + unmasked_scale: float = 1.0, ): """ Args: @@ -60,6 +62,10 @@ def __init__( contains unnormalized probs, i.e., not processed by log-softmax. num_codebooks: Used by distillation loss. + masked_scale: + scale of codebook loss of masked area. + unmasked_scale: + scale of codebook loss of unmasked area. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) @@ -75,8 +81,12 @@ def __init__( self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if num_codebooks > 0: self.codebook_loss_net = JointCodebookLoss( - predictor_channels=encoder_dim, num_codebooks=num_codebooks + predictor_channels=encoder_dim, + num_codebooks=num_codebooks, + reduction="none", ) + self.masked_scale = masked_scale + self.unmasked_scale = unmasked_scale def forward( self, @@ -88,6 +98,7 @@ def forward( lm_scale: float = 0.0, warmup: float = 1.0, codebook_indexes: torch.Tensor = None, + time_masked_area: torch.Tensor = None, ) -> torch.Tensor: """ Args: @@ -113,6 +124,8 @@ def forward( warmup > 1 "are fully warmed up" and all modules will be active. codebook_indexes: codebook_indexes extracted from a teacher model. + time_masked_area: + masked area by SpecAugment, 1 represents masked. Returns: Return the transducer loss. @@ -140,6 +153,22 @@ def forward( codebook_loss = self.codebook_loss_net( middle_layer_output, codebook_indexes ) + codebook_loss = codebook_loss.reshape(codebook_indexes.shape) + target_t = codebook_loss.shape[1] + time_masked_area = time_masked_area.bool() + time_masked_area = time_masked_area[ + :, : target_t * 4 : 4, 0 # noqa E203 + ] + assert time_masked_area.shape == codebook_loss.shape[:-1] + time_masked_area = time_masked_area.unsqueeze(2).to( + codebook_loss.device + ) + masked_loss = (time_masked_area * codebook_loss).sum() + unmasked_loss = (~time_masked_area * codebook_loss).sum() + codebook_loss = ( + self.masked_scale * masked_loss + + self.unmasked_scale * unmasked_loss + ) else: # when codebook index is not available. codebook_loss = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index feb58f457b..3cec4326e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -177,6 +177,18 @@ def get_parser(): changed.""", ) + parser.add_argument( + "--masked-scale", + type=float, + default=1.0, + ) + + parser.add_argument( + "--unmasked-scale", + type=float, + default=1.0, + ) + parser.add_argument( "--lr-batches", type=float, @@ -378,6 +390,8 @@ def get_params() -> AttributeDict: # two successive codebook_index are concatenated together. # Detailed in function Transducer::concat_sucessive_codebook_indexes. "num_codebooks": 16, # used to construct distillation loss + "masked_scale": 1.0, + "unmasked_scale": 1.0, } ) @@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: num_codebooks=params.num_codebooks if params.enable_distiallation else 0, + masked_scale=params.masked_scale, + unmasked_scale=params.unmasked_scale, ) return model @@ -602,7 +618,7 @@ def compute_loss( if isinstance(model, DDP) else next(model.parameters()).device ) - feature = batch["inputs"] + feature, time_masked_area = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) @@ -631,6 +647,7 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, codebook_indexes=codebook_indexes, + time_masked_area=time_masked_area, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid @@ -1089,7 +1106,9 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) + args.exp_dir = Path( + f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}" + ) world_size = args.world_size assert world_size >= 1 diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e83009d4a0..b391b28150 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -32,7 +32,6 @@ K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, - SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -41,6 +40,7 @@ from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader +from aug import SpecAugment from icefall.utils import str2bool @@ -183,6 +183,12 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="When enabled, use SpecAugment for training dataset.", ) + group.add_argument( + "--spec-aug-max-frames-mask-fraction", + type=float, + default=0.15, + ) + group.add_argument( "--spec-aug-time-warp-factor", type=int, @@ -272,6 +278,7 @@ def train_dataloaders( features_mask_size=27, num_feature_masks=2, frames_mask_size=100, + max_frames_mask_fraction=self.args.spec_aug_max_frames_mask_fraction, ) ) else: