[cif] Deal With Replicate Codes (#1749)
* cif related

* solve flake8 warnings

* [cif] add copyright

* [cif] modify some hints

* [cif] deal with replicate codes

---------

Co-authored-by: root <root@asr-wanghe9-sbyesj-0.asr-wanghe9-sbyesj.lizr-a.svc.dev00-bcebj.local>
MrSupW and root authored Mar 15, 2023
1 parent f047f8c commit 0c9e40f
Showing 7 changed files with 19 additions and 393 deletions.
6 changes: 3 additions & 3 deletions wenet/cif/asr_cif_model.py
@@ -25,7 +25,7 @@
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.cif.predictor import MAELoss
- from wenet.cif.utils import make_pad_mask
+ from wenet.utils.mask import make_pad_mask
from wenet.cif.search.beam_search import Hypothesis
from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add,
remove_duplicates_and_blank, th_accuracy)
@@ -139,7 +139,7 @@ def _calc_att_loss(
) -> Tuple[torch.Tensor, float, torch.Tensor]:
encoder_out_mask = (~make_pad_mask(
encoder_out_lens,
- maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
+ max_len=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
@@ -167,7 +167,7 @@ def _calc_att_loss(
def calc_predictor(self, encoder_out, encoder_out_lens):

encoder_out_mask = (
- ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
+ ~make_pad_mask(encoder_out_lens, max_len=encoder_out.size(1))
[:, None, :]).to(encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \
self.predictor(
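
The hunks above only switch to the shared helper in wenet.utils.mask and rename its keyword argument from maxlen to max_len. Below is a minimal, self-contained sketch of the semantics these call sites rely on; the make_pad_mask here is an illustrative stand-in, not the real wenet.utils.mask implementation, which may differ in details.

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Illustrative stand-in: True at padded positions, shape (batch, max_len)."""
    max_len = max_len if max_len > 0 else int(lengths.max())
    seq_range = torch.arange(max_len, device=lengths.device)
    return seq_range.unsqueeze(0) >= lengths.unsqueeze(1)


# Mirror the call sites in asr_cif_model.py: keep only the valid frames.
encoder_out = torch.randn(2, 5, 8)        # (batch, time, feature)
encoder_out_lens = torch.tensor([5, 3])   # valid lengths per utterance

encoder_out_mask = (~make_pad_mask(
    encoder_out_lens,
    max_len=encoder_out.size(1))[:, None, :]).to(encoder_out.device)

print(encoder_out_mask.shape)  # torch.Size([2, 1, 5])
print(encoder_out_mask[1, 0])  # tensor([ True,  True,  True, False, False])
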
107 changes: 1 addition & 106 deletions wenet/cif/attention.py
@@ -13,117 +13,12 @@
# under the License. Modified from
# FunASR(https://github.com/alibaba-damo-academy/FunASR)

import math
- from typing import Optional, Tuple
+ from typing import Optional

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""

def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)

return q, k, v

def forward_attention(
self, value: torch.Tensor, scores: torch.Tensor,
mask: Optional[torch.Tensor]
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k)
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2)
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2)
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)

def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, mask: Optional[torch.Tensor],
) -> torch.Tensor:
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)


class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.
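
The duplicated MultiHeadedAttention above is dropped in favour of the existing implementation in wenet.transformer.attention. As a reminder of what its forward_attention method does, here is a small self-contained sketch of the masking logic (masked scores filled with -inf before the softmax, masked weights zeroed afterwards), using toy tensors rather than the WeNet classes.

import math
import torch

batch, head, time1, time2, d_k = 2, 4, 3, 5, 16
q = torch.randn(batch, head, time1, d_k)
k = torch.randn(batch, head, time2, d_k)
v = torch.randn(batch, head, time2, d_k)

# (batch, 1, time2): nonzero marks valid key positions, as in the removed class.
mask = torch.tensor([[[1, 1, 1, 1, 1]],
                     [[1, 1, 1, 0, 0]]], dtype=torch.bool)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

# Same pattern as forward_attention(): fill masked scores with -inf,
# softmax over the key axis, then zero out the masked weights.
m = mask.unsqueeze(1).eq(0)                      # (batch, 1, 1, time2)
scores = scores.masked_fill(m, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(m, 0.0)

context = torch.matmul(attn, v)                  # (batch, head, time1, d_k)
print(attn[1, 0, 0])    # last two key positions carry zero weight
print(context.shape)    # torch.Size([2, 4, 3, 16])
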
17 changes: 10 additions & 7 deletions wenet/cif/cif_decoder.py
@@ -19,11 +19,14 @@

from typeguard import check_argument_types

- from wenet.cif.utils import make_pad_mask, sequence_mask
- from wenet.cif.attention import MultiHeadedAttention, \
-     MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
- from wenet.cif.decoder_layer import DecoderLayer, DecoderLayerSANM
- from wenet.cif.embedding import PositionalEncoding
+ from wenet.utils.mask import make_pad_mask
+ from wenet.cif.utils import sequence_mask
+ from wenet.transformer.attention import MultiHeadedAttention
+ from wenet.cif.attention import MultiHeadedAttentionSANMDecoder,\
+     MultiHeadedAttentionCrossAtt
+ from wenet.transformer.decoder_layer import DecoderLayer
+ from wenet.cif.decoder_layer import DecoderLayerSANM
+ from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.cif.positionwise_feed_forward import \
PositionwiseFeedForwardDecoderSANM
@@ -130,7 +133,7 @@ def forward(

memory = hs_pad
memory_mask = (~make_pad_mask(
- hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
+ hlens, max_len=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
@@ -239,7 +242,7 @@ def forward(

memory = hs_pad
memory_mask = (~make_pad_mask(hlens,
- maxlen=memory.size(1)))[:, None, :] \
+ max_len=memory.size(1)))[:, None, :] \
.to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
137 changes: 0 additions & 137 deletions wenet/cif/decoder_layer.py
@@ -19,143 +19,6 @@
import torch.nn as nn


class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear`
instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first
block.
concat_after (bool): Whether to concat attention layer's input and
output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""

def __init__(
self,
size: int,
self_attn: nn.Module,
src_attn: nn.Module,
feed_forward: nn.Module,
dropout_rate: float,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-12)
self.norm2 = nn.LayerNorm(size, eps=1e-12)
self.norm3 = nn.LayerNorm(size, eps=1e-12)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
else:
self.concat_linear1 = nn.Identity()
self.concat_linear2 = nn.Identity()

def forward(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in,
size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)

if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]

if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt,
tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)

residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory,
memory_mask))
if not self.normalize_before:
x = self.norm2(x)

residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)

if cache is not None:
x = torch.cat([cache, x], dim=1)

return x, tgt_mask, memory, memory_mask


class DecoderLayerSANM(nn.Module):
"""Single decoder layer module.
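
The DecoderLayer deleted above duplicates the pre-norm residual pattern (x + dropout(sublayer(norm(x)))) used by the shared DecoderLayer in wenet.transformer.decoder_layer that cif_decoder.py now imports. A compact sketch of that pattern for the feed-forward branch, written with plain torch modules rather than the WeNet classes:

import torch
import torch.nn as nn


class PreNormFeedForwardBlock(nn.Module):
    """Illustrative pre-norm residual block, mirroring the feed-forward
    branch of the removed DecoderLayer (normalize_before=True)."""

    def __init__(self, size: int, hidden: int, dropout_rate: float):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-12)
        self.feed_forward = nn.Sequential(
            nn.Linear(size, hidden), nn.ReLU(), nn.Linear(hidden, size))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm(x)                   # pre-norm, applied before the sublayer
        return residual + self.dropout(self.feed_forward(x))


block = PreNormFeedForwardBlock(size=256, hidden=1024, dropout_rate=0.1)
tgt = torch.randn(2, 7, 256)               # (batch, maxlen_out, size)
print(block(tgt).shape)                     # torch.Size([2, 7, 256])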