[CIF] Deal With Replicate Codes #1749

Merged 7 commits on Mar 15, 2023.

Changes from all commits
wenet/cif/asr_cif_model.py (6 changes: 3 additions, 3 deletions)
@@ -25,7 +25,7 @@
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.cif.predictor import MAELoss
-from wenet.cif.utils import make_pad_mask
+from wenet.utils.mask import make_pad_mask
from wenet.cif.search.beam_search import Hypothesis
from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add,
remove_duplicates_and_blank, th_accuracy)
@@ -139,7 +139,7 @@ def _calc_att_loss(
) -> Tuple[torch.Tensor, float, torch.Tensor]:
encoder_out_mask = (~make_pad_mask(
encoder_out_lens,
-maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
+max_len=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
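A hedged aside on why the target length grows by predictor_bias here: add_sos_eos is assumed to return a (sos-prefixed, eos-suffixed) pair of target sequences, and the code keeps the second output, so each target gains one token. The helper below is a simplified illustration of that behavior, not wenet.utils.common.add_sos_eos itself.

```python
from typing import List, Tuple
import torch

def add_sos_eos_sketch(ys: List[torch.Tensor], sos: int,
                       eos: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Simplified stand-in (no padding handling): prepend sos / append eos."""
    ys_in = [torch.cat([torch.tensor([sos]), y]) for y in ys]
    ys_out = [torch.cat([y, torch.tensor([eos])]) for y in ys]
    return ys_in, ys_out

ys = [torch.tensor([7, 8, 9]), torch.tensor([3, 4])]
_, ys_with_eos = add_sos_eos_sketch(ys, sos=1, eos=2)
print([t.tolist() for t in ys_with_eos])  # [[7, 8, 9, 2], [3, 4, 2]]
# Hence ys_pad_lens is increased by predictor_bias (= 1) in the hunk above.
```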
@@ -167,7 +167,7 @@ def _calc_att_loss(
def calc_predictor(self, encoder_out, encoder_out_lens):

encoder_out_mask = (
-~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
+~make_pad_mask(encoder_out_lens, max_len=encoder_out.size(1))
[:, None, :]).to(encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \
self.predictor(
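For context, a minimal sketch of the mask helper behind this rename. It assumes wenet.utils.mask.make_pad_mask follows the usual convention (returns True at padded positions and takes a max_len keyword), which is why every call site switches from maxlen= to max_len=. The helper below is a simplified stand-in, not the library code.

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Simplified stand-in: True marks padded positions."""
    batch = lengths.size(0)
    max_len = max_len if max_len > 0 else int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0).expand(batch, max_len) >= lengths.unsqueeze(1)

encoder_out_lens = torch.tensor([3, 1])
# Same pattern as the call sites above: invert and add a broadcast dimension.
encoder_out_mask = (~make_pad_mask(encoder_out_lens, max_len=4))[:, None, :]
print(encoder_out_mask)
# tensor([[[ True,  True,  True, False]],
#         [[ True, False, False, False]]])
```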
wenet/cif/attention.py (107 changes: 1 addition, 106 deletions)
@@ -13,117 +13,12 @@
# under the License. Modified from
# FunASR(https://github.com/alibaba-damo-academy/FunASR)

import math
-from typing import Optional, Tuple
+from typing import Optional

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.

Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.

"""

def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.

Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).

Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)

return q, k, v

def forward_attention(
self, value: torch.Tensor, scores: torch.Tensor,
mask: Optional[torch.Tensor]
) -> torch.Tensor:
"""Compute attention context vector.

Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k)
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2)
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2)

Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).

"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)

def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, mask: Optional[torch.Tensor],
) -> torch.Tensor:
"""Compute scaled dot product attention.

Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).

Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).

"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)


class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.

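A quick usage sketch of the MultiHeadedAttention class removed above, with shapes taken from its docstrings; it assumes that class definition is in scope. After this PR, CIF code imports the presumably equivalent class from wenet/transformer/attention.py instead.

```python
import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
query = torch.randn(2, 10, 256)                # (#batch, time1, size)
memory = torch.randn(2, 20, 256)               # (#batch, time2, size)
mask = torch.ones(2, 1, 20, dtype=torch.bool)  # (#batch, 1, time2), all frames valid
out = mha(query, memory, memory, mask)
print(out.shape)  # torch.Size([2, 10, 256]) -> (#batch, time1, d_model)
```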
wenet/cif/cif_decoder.py (17 changes: 10 additions, 7 deletions)
@@ -19,11 +19,14 @@

from typeguard import check_argument_types

-from wenet.cif.utils import make_pad_mask, sequence_mask
-from wenet.cif.attention import MultiHeadedAttention, \
-    MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from wenet.cif.decoder_layer import DecoderLayer, DecoderLayerSANM
-from wenet.cif.embedding import PositionalEncoding
+from wenet.utils.mask import make_pad_mask
+from wenet.cif.utils import sequence_mask
+from wenet.transformer.attention import MultiHeadedAttention
+from wenet.cif.attention import MultiHeadedAttentionSANMDecoder,\
+    MultiHeadedAttentionCrossAtt
+from wenet.transformer.decoder_layer import DecoderLayer
+from wenet.cif.decoder_layer import DecoderLayerSANM
+from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.cif.positionwise_feed_forward import \
    PositionwiseFeedForwardDecoderSANM

Collaborator (review comment on the new imports): please sort the import order alphabetically.
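Following that review comment, one possible alphabetically sorted arrangement of the new import block (illustrative only; grouping conventions may differ per project):

```python
from wenet.cif.attention import (MultiHeadedAttentionCrossAtt,
                                 MultiHeadedAttentionSANMDecoder)
from wenet.cif.decoder_layer import DecoderLayerSANM
from wenet.cif.positionwise_feed_forward import \
    PositionwiseFeedForwardDecoderSANM
from wenet.cif.utils import sequence_mask
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import make_pad_mask
```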
@@ -130,7 +133,7 @@ def forward(

memory = hs_pad
memory_mask = (~make_pad_mask(
-hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
+hlens, max_len=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
@@ -239,7 +242,7 @@ def forward(

memory = hs_pad
memory_mask = (~make_pad_mask(hlens,
-maxlen=memory.size(1)))[:, None, :] \
+max_len=memory.size(1)))[:, None, :] \
.to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
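The "Padding for Longformer" branch is truncated in this view. Below is a hedged sketch of what such a branch presumably does: extend memory_mask along the time axis so it matches memory.shape[1], marking the padded frames as invalid. This is an assumption for illustration, not the PR's exact code.

```python
import torch

memory = torch.randn(2, 12, 256)                      # (#batch, maxlen_in, size)
memory_mask = torch.ones(2, 1, 10, dtype=torch.bool)  # mask shorter than memory
if memory_mask.shape[-1] != memory.shape[1]:
    padlen = memory.shape[1] - memory_mask.shape[-1]
    pad = torch.zeros(memory_mask.shape[0], 1, padlen,
                      dtype=torch.bool, device=memory_mask.device)  # padded frames -> False
    memory_mask = torch.cat([memory_mask, pad], dim=-1)
print(memory_mask.shape)  # torch.Size([2, 1, 12])
```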
wenet/cif/decoder_layer.py (137 changes: 0 additions, 137 deletions)
@@ -19,143 +19,6 @@
import torch.nn as nn


class DecoderLayer(nn.Module):
"""Single decoder layer module.

Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear`
instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first
block.
concat_after (bool): Whether to concat attention layer's input and
output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)


"""

def __init__(
self,
size: int,
self_attn: nn.Module,
src_attn: nn.Module,
feed_forward: nn.Module,
dropout_rate: float,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-12)
self.norm2 = nn.LayerNorm(size, eps=1e-12)
self.norm3 = nn.LayerNorm(size, eps=1e-12)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
else:
self.concat_linear1 = nn.Identity()
self.concat_linear2 = nn.Identity()

def forward(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.

Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in,
size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).

Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).

"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)

if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]

if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt,
tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)

residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory,
memory_mask))
if not self.normalize_before:
x = self.norm2(x)

residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)

if cache is not None:
x = torch.cat([cache, x], dim=1)

return x, tgt_mask, memory, memory_mask


class DecoderLayerSANM(nn.Module):
"""Single decoder layer module.

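For reference, a usage sketch of the DecoderLayer removed above, with shapes from its docstring. nn.Sequential stands in for PositionwiseFeedForward purely to keep the snippet self-contained, and MultiHeadedAttention refers to the class deleted from wenet/cif/attention.py in this same PR; after the change, the CIF decoder uses wenet.transformer.decoder_layer.DecoderLayer instead.

```python
import torch
import torch.nn as nn

size, n_head, ffn_dim, dropout = 256, 4, 1024, 0.1
layer = DecoderLayer(
    size=size,
    self_attn=MultiHeadedAttention(n_head, size, dropout),
    src_attn=MultiHeadedAttention(n_head, size, dropout),
    feed_forward=nn.Sequential(  # stand-in for PositionwiseFeedForward
        nn.Linear(size, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, size)),
    dropout_rate=dropout,
)
tgt = torch.randn(2, 7, size)                     # (#batch, maxlen_out, size)
tgt_mask = torch.ones(2, 7, 7, dtype=torch.bool)  # causal masking omitted for brevity
memory = torch.randn(2, 20, size)                 # (#batch, maxlen_in, size)
memory_mask = torch.ones(2, 1, 20, dtype=torch.bool)
x, tgt_mask, memory, memory_mask = layer(tgt, tgt_mask, memory, memory_mask)
print(x.shape)  # torch.Size([2, 7, 256])
```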