Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformer] add gated mlp #2395

Merged
merged 2 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import logging

from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import (
WENET_EMB_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_ACTIVATION_CLASSES,
WENET_MLP_CLASSES,
)
from wenet.utils.common import mask_to_bias
from wenet.utils.mask import (subsequent_mask, make_pad_mask)
Expand Down Expand Up @@ -80,6 +80,7 @@ def __init__(
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
):
super().__init__()
attention_dim = encoder_output_size
Expand All @@ -100,6 +101,8 @@ def __init__(
else:
self.output_layer = torch.nn.Identity()
self.num_blocks = num_blocks

mlp_class = WENET_MLP_CLASSES[mlp_type]
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
Expand All @@ -111,8 +114,8 @@ def __init__(
attention_heads, attention_dim, src_attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa)
if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate, activation, mlp_bias),
mlp_class(attention_dim, linear_units, dropout_rate,
activation, mlp_bias),
dropout_rate,
normalize_before,
) for _ in range(self.num_blocks)
Expand Down
17 changes: 10 additions & 7 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import (
WENET_EMB_CLASSES,
WENET_MLP_CLASSES,
WENET_SUBSAMPLE_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_ACTIVATION_CLASSES,
Expand Down Expand Up @@ -367,6 +367,7 @@ def __init__(
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
):
""" Construct TransformerEncoder

Expand All @@ -380,6 +381,7 @@ def __init__(
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa)
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
Expand All @@ -388,9 +390,9 @@ def __init__(
attention_dropout_rate,
query_bias, key_bias,
value_bias, use_sdpa),
PositionwiseFeedForward(output_size, linear_units,
dropout_rate, activation, mlp_bias),
dropout_rate, normalize_before) for _ in range(num_blocks)
mlp_class(output_size, linear_units, dropout_rate, activation,
mlp_bias), dropout_rate, normalize_before)
for _ in range(num_blocks)
])


Expand Down Expand Up @@ -429,6 +431,7 @@ def __init__(
conv_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
):
"""Construct ConformerEncoder

Expand Down Expand Up @@ -478,14 +481,14 @@ def __init__(
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal, conv_bias)

mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args),
PositionwiseFeedForward(*positionwise_layer_args),
PositionwiseFeedForward(
*positionwise_layer_args) if macaron_style else None,
mlp_class(*positionwise_layer_args),
mlp_class(*positionwise_layer_args) if macaron_style else None,
ConvolutionModule(
*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
Expand Down
36 changes: 36 additions & 0 deletions wenet/transformer/positionwise_feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,39 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
xs[batch_idx])
return output.view(B, L, D)


class GatedVariantsMLP(torch.nn.Module):
""" https://arxiv.org/pdf/2002.05202.pdf
"""

def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.GELU(),
bias: bool = True,
):
"""Construct a PositionwiseFeedForward object."""
super(GatedVariantsMLP, self).__init__()
self.gate = torch.nn.Linear(idim, hidden_units, bias=False)
self.activation = activation
# w_1 as up proj
self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias)
self.dropout = torch.nn.Dropout(dropout_rate)
# w_2 as down proj
self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)

def forward(self, x):
"""Foward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)

"""
gate = self.activation(self.gate(x))
up = self.w_1(x)
fuse = gate * up
return self.w_2(self.dropout(fuse))
11 changes: 10 additions & 1 deletion wenet/utils/class_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Copyright [2023-11-28] <[email protected], Xingchen Song>
import torch
from wenet.paraformer.embedding import ParaformerPositinoalEncoding
from wenet.transformer.positionwise_feed_forward import (
GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)

from wenet.transformer.swish import Swish
from wenet.transformer.subsampling import (
Expand All @@ -23,7 +25,8 @@
from wenet.transformer.attention import (MultiHeadedAttention,
MultiHeadedCrossAttention,
RelPositionMultiHeadedAttention)
from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention
from wenet.efficient_conformer.attention import (
GroupedRelPositionMultiHeadedAttention)

WENET_ACTIVATION_CLASSES = {
"hardtanh": torch.nn.Hardtanh,
Expand Down Expand Up @@ -68,3 +71,9 @@
"grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
"crossattn": MultiHeadedCrossAttention,
}

WENET_MLP_CLASSES = {
'position_wise_feed_forward': PositionwiseFeedForward,
'moe': MoEFFNLayer,
'gated': GatedVariantsMLP
}
Loading