From aa58c2ee02062fde5f4c455aa880babadb2fa17e Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Thu, 13 Oct 2022 15:14:28 +0800
Subject: [PATCH] Modify ActivationBalancer for speed (#612)

* add a probability to apply ActivationBalancer

* minor fix

* minor fix
---
 .github/scripts/run-pre-trained-conformer-ctc.sh      | 1 -
 .../ASR/pruned_transducer_stateless2/conformer.py     | 4 ++--
 .../ASR/pruned_transducer_stateless2/scaling.py       | 9 +++++++--
 .../pruned_transducer_stateless3/scaling_converter.py | 3 +++
 4 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh
index 6368b0bbdd..96c3206161 100755
--- a/.github/scripts/run-pre-trained-conformer-ctc.sh
+++ b/.github/scripts/run-pre-trained-conformer-ctc.sh
@@ -12,7 +12,6 @@ cd egs/librispeech/ASR
 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 
 git lfs install
-git clone $repo
 
 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index bce8a6bd19..b04a74a19f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -932,7 +932,7 @@ def forward(
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ def multi_head_attention_forward(
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 9f839cbe00..8c572a9ef1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -16,6 +16,7 @@
 
 
 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple
 
@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
          max_abs: the maximum average-absolute-value per channel, which
             we allow, before we start to modify the derivatives to
             prevent this.
+         balance_prob: the probability to apply the ActivationBalancer.
     """
 
     def __init__(
@@ -646,6 +648,7 @@ def __init__(
         max_factor: float = 0.01,
        min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,9 +657,11 @@ def __init__(
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
@@ -664,7 +669,7 @@ def forward(self, x: Tensor) -> Tensor:
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
index 43f5d409c6..f2f691eb1a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
@@ -30,6 +30,7 @@
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()
 
     for k, v in d.items():
         if "." in k: