
[RNN-T] transducer's simple_lm_proj should be predictor.output_size and delay_penalty is 0 when step is less than 2*warmup_step (#1940)

* when using k2, size simple_lm_proj from the predictor's output_size
* fix a delay_penalty bug (inverted warmup comparison)
Zth9730 authored Aug 11, 2023
1 parent 886f88f commit cbd87d1
Showing 2 changed files with 15 additions and 2 deletions.
13 changes: 13 additions & 0 deletions wenet/transducer/predictor.py
@@ -38,6 +38,9 @@ def cache_to_batch(self,
         _ = cache
         raise NotImplementedError("this is a base predictor")
 
+    def output_size(self):
+        raise NotImplementedError("this is a base predictor")
+
     def forward(
         self,
         input: torch.Tensor,
@@ -69,6 +72,7 @@ def __init__(self,
         super().__init__()
         self.n_layers = num_layers
         self.hidden_size = hidden_size
+        self._output_size = output_size
         # disable rnn base out projection
         self.embed = nn.Embedding(voca_size, embed_size)
         self.dropout = nn.Dropout(embed_dropout)
@@ -83,6 +87,9 @@ def __init__(self,
                           dropout=dropout)
         self.projection = nn.Linear(hidden_size, output_size)
 
+    def output_size(self):
+        return self._output_size
+
     def forward(
         self,
         input: torch.Tensor,
@@ -232,6 +239,9 @@ def __init__(self,
         self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
         self.activatoin = get_activation(activation)
 
+    def output_size(self):
+        return self.embed_size
+
     def init_state(self,
                    batch_size: int,
                    device: torch.device,
@@ -390,6 +400,9 @@ def __init__(self,
         self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
         self.activatoin = get_activation(activation)
 
+    def output_size(self):
+        return self.embed_size
+
     def init_state(self,
                    batch_size: int,
                    device: torch.device,
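The hunks above give every predictor an output_size() accessor, so downstream code can query the feature dimension without knowing which predictor class is configured. A minimal sketch of the idea, assuming a toy class (TinyRnnPredictor and every size below are illustrative, not WeNet code):

    import torch
    import torch.nn as nn

    # Toy predictor (hypothetical, not the WeNet class): the RNN output is
    # projected from hidden_size to output_size, so the feature dimension a
    # downstream layer must match is output_size, not embed_size.
    class TinyRnnPredictor(nn.Module):

        def __init__(self, voca_size=10, embed_size=8, hidden_size=16,
                     output_size=12):
            super().__init__()
            self._output_size = output_size
            self.embed = nn.Embedding(voca_size, embed_size)
            self.rnn = nn.LSTM(embed_size, hidden_size, batch_first=True)
            self.projection = nn.Linear(hidden_size, output_size)

        def output_size(self):
            return self._output_size

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            out, _ = self.rnn(self.embed(input))
            return self.projection(out)

    predictor = TinyRnnPredictor()
    # A layer over the predictor output can now size itself generically:
    lm_proj = nn.Linear(predictor.output_size(), 100)  # 100: toy vocab size
    tokens = torch.randint(0, 10, (2, 5))
    print(lm_proj(predictor(tokens)).shape)  # torch.Size([2, 5, 100])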
4 changes: 2 additions & 2 deletions wenet/transducer/transducer.py
@@ -73,7 +73,7 @@ def __init__(
         self.simple_am_proj = torch.nn.Linear(
             self.encoder.output_size(), vocab_size)
         self.simple_lm_proj = torch.nn.Linear(
-            self.predictor.embed_size, vocab_size)
+            self.predictor.output_size(), vocab_size)
 
         # Note(Mddct): decoder also means predictor in transducer,
         # but here decoder is attention decoder
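For the k2 pruned transducer loss, simple_lm_proj consumes the predictor output, so its in_features must equal the predictor's actual output dimension. For RNNPredictor that dimension is output_size (after the final projection), not embed_size, so sizing the projection from embed_size mismatches the predictor output whenever the two differ. A minimal sketch of the shape constraint, with made-up sizes:

    import torch

    # Illustrative sizes only (not from the commit): batch B=2, U=5 label
    # positions, predictor output_size=12, embed_size=8, vocab_size=100.
    predictor_out = torch.randn(2, 5, 12)

    simple_lm_proj = torch.nn.Linear(12, 100)   # in_features = output_size
    print(simple_lm_proj(predictor_out).shape)  # torch.Size([2, 5, 100])

    bad_proj = torch.nn.Linear(8, 100)  # sized from embed_size instead
    try:
        bad_proj(predictor_out)  # 12 != 8: shape mismatch at runtime
    except RuntimeError as err:
        print("mismatch:", err)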
@@ -512,7 +512,7 @@ def compute_loss(model: Transducer,
                                reduction="mean")
     else:
         delay_penalty = model.delay_penalty
-        if steps > 2 * model.warmup_steps:
+        if steps < 2 * model.warmup_steps:
             delay_penalty = 0.00
         ys_in_pad = ys_in_pad.type(torch.int64)
         boundary = torch.zeros((encoder_out.size(0), 4),
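The corrected comparison keeps delay_penalty at 0.00 until steps reaches 2 * warmup_steps and applies the configured penalty only after that; the old > test zeroed the penalty in exactly the opposite regime. A standalone sketch of the gating, under assumed default values:

    # Mirrors the two changed lines above; warmup_steps and delay_penalty
    # defaults are illustrative values, not WeNet's configuration.
    def effective_delay_penalty(steps: int,
                                warmup_steps: int = 25000,
                                delay_penalty: float = 0.01) -> float:
        if steps < 2 * warmup_steps:
            return 0.00  # penalty stays off early in training
        return delay_penalty  # full penalty once past 2x warmup

    assert effective_delay_penalty(steps=1000) == 0.00
    assert effective_delay_penalty(steps=60000) == 0.01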
