From 872e24306eece0e190c1b6d821dec1fd7ecfe8ab Mon Sep 17 00:00:00 2001
From: Mddct
Date: Fri, 8 Mar 2024 14:42:13 +0800
Subject: [PATCH 1/2] [transformer] add gated mlp

---
 wenet/transformer/decoder.py                  |  9 +++--
 wenet/transformer/encoder.py                  | 17 +++++----
 .../transformer/positionwise_feed_forward.py  | 36 +++++++++++++++++++
 wenet/utils/class_utils.py                    | 11 +++++-
 4 files changed, 62 insertions(+), 11 deletions(-)

diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index a179814ec..ff4d932f3 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -20,11 +20,11 @@
 import logging

 from wenet.transformer.decoder_layer import DecoderLayer
-from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
+    WENET_MLP_CLASSES,
 )
 from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -80,6 +80,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -100,6 +101,8 @@ def __init__(
         else:
             self.output_layer = torch.nn.Identity()
         self.num_blocks = num_blocks
+
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.decoders = torch.nn.ModuleList([
             DecoderLayer(
                 attention_dim,
@@ -111,8 +114,8 @@ def __init__(
                     attention_heads, attention_dim,
                     src_attention_dropout_rate, query_bias, key_bias,
                     value_bias, use_sdpa) if src_attention else None,
-                PositionwiseFeedForward(attention_dim, linear_units,
-                                        dropout_rate, activation, mlp_bias),
+                mlp_class(attention_dim, linear_units, dropout_rate,
+                          activation, mlp_bias),
                 dropout_rate,
                 normalize_before,
             ) for _ in range(self.num_blocks)
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index fb4d81e98..705609b4a 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -22,9 +22,9 @@
 from wenet.transformer.convolution import ConvolutionModule
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
-from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
+    WENET_MLP_CLASSES,
     WENET_SUBSAMPLE_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
@@ -367,6 +367,7 @@ def __init__(
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         """ Construct TransformerEncoder

@@ -380,6 +381,7 @@ def __init__(
                          use_dynamic_left_chunk, gradient_checkpointing,
                          use_sdpa)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
                 output_size,
@@ -388,9 +390,9 @@ def __init__(
                                                     attention_dropout_rate,
                                                     query_bias, key_bias,
                                                     value_bias, use_sdpa),
-                PositionwiseFeedForward(output_size, linear_units,
-                                        dropout_rate, activation, mlp_bias),
-                dropout_rate, normalize_before) for _ in range(num_blocks)
+                mlp_class(output_size, linear_units, dropout_rate, activation,
+                          mlp_bias), dropout_rate, normalize_before)
+            for _ in range(num_blocks)
         ])


@@ -429,6 +431,7 @@ def __init__(
         conv_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         """Construct ConformerEncoder

@@ -478,14 +481,14 @@ def __init__(
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                   cnn_module_norm, causal, conv_bias)

+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
             ConformerEncoderLayer(
                 output_size,
                 WENET_ATTENTION_CLASSES[selfattention_layer_type](
                     *encoder_selfattn_layer_args),
-                PositionwiseFeedForward(*positionwise_layer_args),
-                PositionwiseFeedForward(
-                    *positionwise_layer_args) if macaron_style else None,
+                mlp_class(*positionwise_layer_args),
+                mlp_class(*positionwise_layer_args) if macaron_style else None,
                 ConvolutionModule(
                     *convolution_layer_args) if use_cnn_module else None,
                 dropout_rate,
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index fda25eba5..7d6ab3251 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -116,3 +116,39 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
             output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
                 xs[batch_idx])
         return output.view(B, L, D)
+
+
+class GatedVariantsMLP(torch.nn.Module):
+    """ https://arxiv.org/pdf/2002.05202.pdf
+    """
+
+    def __init__(
+        self,
+        idim: int,
+        hidden_units: int,
+        dropout_rate: float,
+        activation: torch.nn.Module = torch.nn.GELU(),
+        bias: bool = True,
+    ):
+        """Construct a GatedVariantsMLP object."""
+        super(GatedVariantsMLP, self).__init__()
+        self.gate = torch.nn.Linear(idim, hidden_units, bias=False)
+        self.activation = activation
+        # w_1 as up proj
+        self.w_1 = torch.nn.Linear(idim, hidden_units, bias=bias)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        # w_2 as down proj
+        self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)
+
+    def forward(self, x):
+        """Forward function.
+        Args:
+            x: input tensor (B, L, D)
+        Returns:
+            output tensor, (B, L, D)
+
+        """
+        gate = self.activation(self.gate(x))
+        up = self.w_1(x)
+        fuse = gate * up
+        return self.w_2(self.dropout(fuse))
diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py
index 51f22df58..3123dc4f7 100644
--- a/wenet/utils/class_utils.py
+++ b/wenet/utils/class_utils.py
@@ -3,6 +3,8 @@
 # Copyright [2023-11-28]
 import torch
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
+from wenet.transformer.positionwise_feed_forward import (GatedVariantsMLP,
+                                                         MoEFFNLayer)
 from wenet.transformer.swish import Swish

 from wenet.transformer.subsampling import (
@@ -23,7 +25,8 @@
 from wenet.transformer.attention import (MultiHeadedAttention,
                                          MultiHeadedCrossAttention,
                                          RelPositionMultiHeadedAttention)
-from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention
+from wenet.efficient_conformer.attention import (
+    GroupedRelPositionMultiHeadedAttention)

 WENET_ACTIVATION_CLASSES = {
     "hardtanh": torch.nn.Hardtanh,
@@ -68,3 +71,9 @@
     "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
     "crossattn": MultiHeadedCrossAttention,
 }
+
+WENET_MLP_CLASSES = {
+    'position_wise_feed_forward': PositionalEncoding,
+    'moe': MoEFFNLayer,
+    'gated': GatedVariantsMLP
+}

From ccde44cedb90b937b98a36485dc53b371570abbd Mon Sep 17 00:00:00 2001
From: Mddct
Date: Fri, 8 Mar 2024 15:19:43 +0800
Subject: [PATCH 2/2] fix ut

---
 wenet/utils/class_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py
index 3123dc4f7..314e1b826 100644
--- a/wenet/utils/class_utils.py
+++ b/wenet/utils/class_utils.py
@@ -3,8 +3,8 @@
 # Copyright [2023-11-28]
 import torch
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
-from wenet.transformer.positionwise_feed_forward import (GatedVariantsMLP,
-                                                         MoEFFNLayer)
+from wenet.transformer.positionwise_feed_forward import (
+    GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
 from wenet.transformer.swish import Swish

 from wenet.transformer.subsampling import (
@@ -73,7 +73,7 @@
 }

 WENET_MLP_CLASSES = {
-    'position_wise_feed_forward': PositionalEncoding,
+    'position_wise_feed_forward': PositionwiseFeedForward,
     'moe': MoEFFNLayer,
     'gated': GatedVariantsMLP
 }
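
Usage sketch (assumes both patches above are applied; the 256/2048 sizes and the 0.1 dropout are illustrative values, not project defaults): select the gated MLP through the new WENET_MLP_CLASSES registry, the same lookup the encoder and decoder constructors now perform via their mlp_type argument.

    import torch

    from wenet.utils.class_utils import WENET_MLP_CLASSES

    # Resolve the feed-forward implementation from the registry, as the
    # encoder/decoder do with `mlp_class = WENET_MLP_CLASSES[mlp_type]`.
    mlp_class = WENET_MLP_CLASSES['gated']   # -> GatedVariantsMLP
    mlp = mlp_class(256, 2048, 0.1)          # idim, hidden_units, dropout_rate

    xs = torch.randn(2, 10, 256)             # (batch, length, idim)
    ys = mlp(xs)                             # w_2(dropout(GELU(gate(xs)) * w_1(xs)))
    assert ys.shape == xs.shape

With the default torch.nn.GELU() activation this is the GEGLU variant from the linked paper; passing torch.nn.SiLU() as the activation instead would give SwiGLU.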