From cbd87d1fd416f09bd4a065816bd0d248a83f88b2 Mon Sep 17 00:00:00 2001 From: TianHao Zhang <32243340+Zth9730@users.noreply.github.com> Date: Fri, 11 Aug 2023 17:27:27 +0800 Subject: [PATCH] [RNN-T] transducer's simple_lm_proj should be predictor.output_size and delay_penalty is 0 when step is less than 2*warmup_step (#1940) * when using k2, use the predictor's output_size * fix some delay_penalty bugs --- wenet/transducer/predictor.py | 13 +++++++++++++ wenet/transducer/transducer.py | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/wenet/transducer/predictor.py b/wenet/transducer/predictor.py index 3d1090246..141a3ec68 100644 --- a/wenet/transducer/predictor.py +++ b/wenet/transducer/predictor.py @@ -38,6 +38,9 @@ def cache_to_batch(self, _ = cache raise NotImplementedError("this is a base precictor") + def output_size(self): + raise NotImplementedError("this is a base precictor") + def forward( self, input: torch.Tensor, @@ -69,6 +72,7 @@ def __init__(self, super().__init__() self.n_layers = num_layers self.hidden_size = hidden_size + self._output_size = output_size # disable rnn base out projection self.embed = nn.Embedding(voca_size, embed_size) self.dropout = nn.Dropout(embed_dropout) @@ -83,6 +87,9 @@ def __init__(self, dropout=dropout) self.projection = nn.Linear(hidden_size, output_size) + def output_size(self): + return self._output_size + def forward( self, input: torch.Tensor, @@ -232,6 +239,9 @@ def __init__(self, self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon) self.activatoin = get_activation(activation) + def output_size(self): + return self.embed_size + def init_state(self, batch_size: int, device: torch.device, @@ -390,6 +400,9 @@ def __init__(self, self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon) self.activatoin = get_activation(activation) + def output_size(self): + return self.embed_size + def init_state(self, batch_size: int, device: torch.device, diff --git a/wenet/transducer/transducer.py 
b/wenet/transducer/transducer.py index c6b3b1f71..b57b21fdd 100644 --- a/wenet/transducer/transducer.py +++ b/wenet/transducer/transducer.py @@ -73,7 +73,7 @@ def __init__( self.simple_am_proj = torch.nn.Linear( self.encoder.output_size(), vocab_size) self.simple_lm_proj = torch.nn.Linear( - self.predictor.embed_size, vocab_size) + self.predictor.output_size(), vocab_size) # Note(Mddct): decoder also means predictor in transducer, # but here decoder is attention decoder @@ -512,7 +512,7 @@ def compute_loss(model: Transducer, reduction="mean") else: delay_penalty = model.delay_penalty - if steps > 2 * model.warmup_steps: + if steps < 2 * model.warmup_steps: delay_penalty = 0.00 ys_in_pad = ys_in_pad.type(torch.int64) boundary = torch.zeros((encoder_out.size(0), 4),