Modify ActivationBalancer for speed (#612)
* add a probability to apply ActivationBalancer

* minor fix

* minor fix
yaozengwei authored Oct 13, 2022
1 parent 1c07d2f commit aa58c2e
Showing 4 changed files with 12 additions and 5 deletions.
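
The speed-up rests on a simple expectation argument: instead of running the balancer's backward-pass correction on every step, apply it with probability `balance_prob` and scale its strength by `1 / balance_prob`, so the expected per-step adjustment is unchanged while most calls become a no-op. A toy numeric check of that reasoning (illustrative only, not icefall code):

```python
import random

max_factor = 0.01  # strength of the always-on balancer
p = 0.25           # balance_prob: fraction of steps that do any work
steps = 100_000
random.seed(0)

# With probability p apply a correction of max_factor / p; otherwise skip.
adjustments = [max_factor / p if random.random() < p else 0.0 for _ in range(steps)]

mean = sum(adjustments) / steps
print(f"expected {max_factor}, observed {mean:.5f}")  # ~0.01, with ~75% of calls skipped
```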
1 change: 0 additions & 1 deletion .github/scripts/run-pre-trained-conformer-ctc.sh
@@ -12,7 +12,6 @@ cd egs/librispeech/ASR
 
 repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
 git lfs install
-git clone $repo
 
 log "Downloading pre-trained model from $repo_url"
 git clone $repo_url
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -932,7 +932,7 @@ def forward(
         value: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -1059,7 +1059,7 @@ def multi_head_attention_forward(
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
+        need_weights: bool = False,
         attn_mask: Optional[Tensor] = None,
         left_context: int = 0,
     ) -> Tuple[Tensor, Optional[Tensor]]:
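
With the default flipped to `need_weights=False`, callers that never inspect the attention weights no longer pay for materializing and averaging them; the second element of the returned tuple is then `None`. Stock `torch.nn.MultiheadAttention` has the same contract, shown below as an illustration (icefall's `RelPositionMultiheadAttention` additionally takes `pos_emb` and other arguments):

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=256, num_heads=4)
x = torch.randn(10, 2, 256)  # (seq_len, batch, embed_dim)

out, weights = mha(x, x, x, need_weights=False)
print(weights)        # None: the averaged weight tensor is never materialized

out, weights = mha(x, x, x, need_weights=True)
print(weights.shape)  # torch.Size([2, 10, 10]): batch x tgt_len x src_len
```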
9 changes: 7 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -16,6 +16,7 @@
 
 
 import collections
+import random
 from itertools import repeat
 from typing import Optional, Tuple
 
@@ -636,6 +637,7 @@ class ActivationBalancer(torch.nn.Module):
       max_abs: the maximum average-absolute-value per channel, which
         we allow, before we start to modify the derivatives to prevent
         this.
+      balance_prob: the probability to apply the ActivationBalancer.
     """
 
     def __init__(
@@ -646,6 +648,7 @@ def __init__(
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        balance_prob: float = 0.25,
     ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
@@ -654,17 +657,19 @@ def __init__(
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert 0 < balance_prob <= 1, balance_prob
+        self.balance_prob = balance_prob
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or is_jit_tracing():
+        if random.random() >= self.balance_prob:
             return x
         else:
             return ActivationBalancerFunction.apply(
                 x,
                 self.channel_dim,
                 self.min_positive,
                 self.max_positive,
-                self.max_factor,
+                self.max_factor / self.balance_prob,
                 self.min_abs,
                 self.max_abs,
             )
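
Note the compensation in the `apply` call: the correction now runs on roughly `balance_prob` of the calls, but with strength `max_factor / balance_prob`, so its expected effect matches the old always-on behavior. A minimal usage sketch (assuming it is run from the recipe directory so `scaling.py` is importable; `channel_dim=-1` is an illustrative choice):

```python
import torch
from scaling import ActivationBalancer

balancer = ActivationBalancer(
    channel_dim=-1,
    balance_prob=0.25,  # new argument: do the balancing work on ~25% of calls
)

x = torch.randn(4, 100, 256, requires_grad=True)
y = balancer(x)     # forward pass is the identity either way
y.sum().backward()  # gradient nudging happens only on the "active" calls
```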
3 changes: 3 additions & 0 deletions
@@ -30,6 +30,7 @@
 import torch
 import torch.nn as nn
 from scaling import (
+    ActivationBalancer,
     BasicNorm,
     ScaledConv1d,
     ScaledConv2d,
@@ -294,6 +295,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = convert_basic_norm(m)
         elif isinstance(m, ScaledLSTM):
             d[name] = scaled_lstm_to_lstm(m)
+        elif isinstance(m, ActivationBalancer):
+            d[name] = nn.Identity()
 
     for k, v in d.items():
         if "." in k: