Modify ActivationBalancer for speed #612

Merged 9 commits on Oct 13, 2022
.github/scripts/run-pre-trained-conformer-ctc.sh (1 change: 0 additions & 1 deletion)
@@ -12,7 +12,6 @@ cd egs/librispeech/ASR

 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 git lfs install
-git clone $repo

 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url
egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py (4 changes: 2 additions & 2 deletions)
@@ -932,7 +932,7 @@ def forward(
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ def multi_head_attention_forward(
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
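The only change to conformer.py is flipping the default of `need_weights` to `False` in both the module's `forward` and `multi_head_attention_forward`, so the attention-weight matrix is no longer reshaped and averaged over heads unless a caller explicitly asks for it. A minimal sketch of that pattern, not the icefall code (`attention_output` and its arguments are invented for illustration):

```python
import torch
from torch import Tensor
from typing import Optional, Tuple


def attention_output(
    attn_weights: Tensor,  # (batch * num_heads, tgt_len, src_len), already softmax-ed
    value: Tensor,         # (batch * num_heads, src_len, head_dim)
    batch_size: int,
    num_heads: int,
    need_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    out = torch.bmm(attn_weights, value)
    if need_weights:
        # Only pay for the reshape and the average over heads when asked.
        w = attn_weights.reshape(batch_size, num_heads, *attn_weights.shape[1:])
        return out, w.mean(dim=1)
    # New default path: skip that work and return None for the weights.
    return out, None
```

Callers that ignore the second element of the returned tuple are unaffected; they simply get `None` instead of a tensor they never used.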
egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py (9 changes: 7 additions & 2 deletions)
@@ -16,6 +16,7 @@


 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple

@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
           max_abs: the maximum average-absolute-value per channel, which
             we allow, before we start to modify the derivatives to prevent
             this.
+          balance_prob: the probability to apply the ActivationBalancer.
     """

     def __init__(
@@ -646,6 +648,7 @@ def __init__(
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,17 +657,19 @@ def __init__(
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
                 x,
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
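This is the heart of the speed-up: `ActivationBalancerFunction` now runs only with probability `balance_prob` (default 0.25), and when it does run, `max_factor` is divided by `balance_prob`, so the expected strength of the gradient correction per training step is unchanged while roughly three quarters of the forward/backward passes skip the extra work. A minimal sketch of the idea under assumptions (`correction_fn` is a hypothetical stand-in for `ActivationBalancerFunction.apply`, which takes more arguments in the real code):

```python
import random
from typing import Callable

import torch
from torch import Tensor


def maybe_balance(
    x: Tensor,
    correction_fn: Callable[[Tensor, float], Tensor],  # hypothetical stand-in
    max_factor: float = 0.01,
    balance_prob: float = 0.25,
) -> Tensor:
    assert 0 < balance_prob <= 1, balance_prob
    if random.random() >= balance_prob:
        # Most steps: no extra autograd function, no extra compute.
        return x
    # When the correction fires it is 1 / balance_prob times stronger, so the
    # expected correction per step is balance_prob * (max_factor / balance_prob)
    # == max_factor, the same as applying it on every step.
    return correction_fn(x, max_factor / balance_prob)


# Usage sketch with a dummy correction that just returns its input:
y = maybe_balance(torch.randn(2, 3), lambda t, factor: t)
```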
@@ -30,6 +30,7 @@
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()

     for k, v in d.items():
         if "." in k:
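`ActivationBalancer` only modifies gradients in the backward pass and acts as the identity in the forward pass, so when a trained model is converted for export or inference it can simply be replaced by `nn.Identity()`; the new `elif` branch records that replacement in `d` alongside the other scaled-to-plain conversions. A rough sketch of this kind of by-name module swap (the helper and the name matching are illustrative assumptions, not the icefall implementation):

```python
import torch.nn as nn


def replace_with_identity(model: nn.Module, dotted_name: str) -> None:
    """Swap the submodule at `dotted_name` (e.g. "encoder.layers.0.balancer")
    for nn.Identity(), by walking down to its parent module first."""
    parent = model
    *path, leaf = dotted_name.split(".")
    for part in path:
        parent = getattr(parent, part)
    setattr(parent, leaf, nn.Identity())


# Usage sketch: drop every ActivationBalancer before export
# (in practice, collect the names first, then replace):
# names = [n for n, m in model.named_modules()
#          if type(m).__name__ == "ActivationBalancer"]
# for n in names:
#     replace_with_identity(model, n)
```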