
Commit

gated mlp works
Mddct committed Feb 22, 2024
1 parent bbedafb commit 8c576bb
Showing 3 changed files with 48 additions and 34 deletions.
15 changes: 10 additions & 5 deletions wenet/transformer/decoder.py
@@ -20,11 +20,11 @@
 import logging
 
 from wenet.transformer.decoder_layer import DecoderLayer
-from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
+    WENET_MLP_CLASSES,
 )
 from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -75,6 +75,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -95,6 +96,7 @@ def __init__(
         else:
             self.output_layer = torch.nn.Identity()
         self.num_blocks = num_blocks
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.decoders = torch.nn.ModuleList([
             DecoderLayer(
                 attention_dim,
@@ -104,8 +106,8 @@ def __init__(
                 WENET_ATTENTION_CLASSES["selfattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
                     key_bias, use_sdpa) if src_attention else None,
-                PositionwiseFeedForward(attention_dim, linear_units,
-                                        dropout_rate, activation),
+                mlp_class(attention_dim, linear_units, dropout_rate,
+                          activation),
                 dropout_rate,
                 normalize_before,
             ) for _ in range(self.num_blocks)
@@ -298,6 +300,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
 
         super().__init__()
@@ -318,7 +321,8 @@ def __init__(
             key_bias=key_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            mlp_type=mlp_type)
 
         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -336,7 +340,8 @@ def __init__(
             key_bias=key_bias,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
-            use_sdpa=use_sdpa)
+            use_sdpa=use_sdpa,
+            mlp_type=mlp_type)
 
     def forward(
         self,
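
The decoder change boils down to one lookup: the string mlp_type selects a feed-forward class from WENET_MLP_CLASSES, which is then called with the same positional arguments the old PositionwiseFeedForward received. A minimal sketch of that pattern in isolation; the concrete sizes and the SiLU activation are illustrative assumptions, not values from this commit:

import torch

from wenet.utils.class_utils import WENET_MLP_CLASSES

# Select the feed-forward variant by name, exactly as the decoder now does.
mlp_type = 'gated'  # other registry keys: 'position_wise_feed_forward', 'moe'
mlp_class = WENET_MLP_CLASSES[mlp_type]

# Build one feed-forward block with illustrative sizes (256-d model, 2048 hidden units).
attention_dim, linear_units, dropout_rate = 256, 2048, 0.1
activation = torch.nn.SiLU()  # stand-in for WENET_ACTIVATION_CLASSES[activation_type]()
mlp = mlp_class(attention_dim, linear_units, dropout_rate, activation)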
59 changes: 30 additions & 29 deletions wenet/transformer/encoder.py
@@ -22,9 +22,9 @@
 from wenet.transformer.convolution import ConvolutionModule
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
-from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
+    WENET_MLP_CLASSES,
     WENET_SUBSAMPLE_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
@@ -341,28 +341,27 @@ def forward_chunk_by_chunk(
 class TransformerEncoder(BaseEncoder):
     """Transformer encoder module."""
 
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: str = "conv2d",
-        pos_enc_layer_type: str = "abs_pos",
-        normalize_before: bool = True,
-        static_chunk_size: int = 0,
-        use_dynamic_chunk: bool = False,
-        global_cmvn: torch.nn.Module = None,
-        use_dynamic_left_chunk: bool = False,
-        key_bias: bool = True,
-        activation_type: str = "relu",
-        gradient_checkpointing: bool = False,
-        use_sdpa: bool = False,
-    ):
+    def __init__(self,
+                 input_size: int,
+                 output_size: int = 256,
+                 attention_heads: int = 4,
+                 linear_units: int = 2048,
+                 num_blocks: int = 6,
+                 dropout_rate: float = 0.1,
+                 positional_dropout_rate: float = 0.1,
+                 attention_dropout_rate: float = 0.0,
+                 input_layer: str = "conv2d",
+                 pos_enc_layer_type: str = "abs_pos",
+                 normalize_before: bool = True,
+                 static_chunk_size: int = 0,
+                 use_dynamic_chunk: bool = False,
+                 global_cmvn: torch.nn.Module = None,
+                 use_dynamic_left_chunk: bool = False,
+                 key_bias: bool = True,
+                 activation_type: str = "relu",
+                 gradient_checkpointing: bool = False,
+                 use_sdpa: bool = False,
+                 mlp_type: str = 'position_wise_feed_forward'):
         """ Construct TransformerEncoder
 
         See Encoder for the meaning of each parameter.
@@ -375,16 +374,17 @@ def __init__(
                          use_dynamic_left_chunk, gradient_checkpointing,
                          use_sdpa)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
                 output_size,
                 WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
                                                     output_size,
                                                     attention_dropout_rate,
                                                     key_bias, use_sdpa),
-                PositionwiseFeedForward(output_size, linear_units,
-                                        dropout_rate, activation),
-                dropout_rate, normalize_before) for _ in range(num_blocks)
+                mlp_class(output_size, linear_units, dropout_rate,
+                          activation), dropout_rate, normalize_before)
+            for _ in range(num_blocks)
         ])
 
 
@@ -419,6 +419,7 @@ def __init__(
         key_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        mlp_type: str = 'position_wise_feed_forward',
     ):
         """Construct ConformerEncoder
@@ -465,14 +466,14 @@ def __init__(
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
 
+        mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
             ConformerEncoderLayer(
                 output_size,
                 WENET_ATTENTION_CLASSES[selfattention_layer_type](
                     *encoder_selfattn_layer_args),
-                PositionwiseFeedForward(*positionwise_layer_args),
-                PositionwiseFeedForward(
-                    *positionwise_layer_args) if macaron_style else None,
+                mlp_class(*positionwise_layer_args),
+                mlp_class(*positionwise_layer_args) if macaron_style else None,
                 ConvolutionModule(
                     *convolution_layer_args) if use_cnn_module else None,
                 dropout_rate,
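
For model construction nothing else changes; callers just pass the new keyword. A usage sketch based on the TransformerEncoder signature shown above; the 80-dim fbank features, batch shapes, and the forward call itself are illustrative assumptions rather than part of this diff:

import torch

from wenet.transformer.encoder import TransformerEncoder

# Same constructor as before, plus the new mlp_type keyword selecting the gated MLP.
encoder = TransformerEncoder(input_size=80, mlp_type='gated')

feats = torch.randn(2, 100, 80)       # (batch, frames, feature_dim), illustrative values
feats_lens = torch.tensor([100, 80])  # number of valid frames per utterance
out, out_mask = encoder(feats, feats_lens)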
8 changes: 8 additions & 0 deletions wenet/utils/class_utils.py
@@ -3,6 +3,8 @@
 # Copyright [2023-11-28] <[email protected], Xingchen Song>
 import torch
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
+from wenet.transformer.positionwise_feed_forward import (
+    GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
 
 from wenet.transformer.swish import Swish
 from wenet.transformer.subsampling import (
@@ -66,3 +68,9 @@
     "rel_selfattn": RelPositionMultiHeadedAttention,
     "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
 }
+
+WENET_MLP_CLASSES = {
+    'position_wise_feed_forward': PositionwiseFeedForward,
+    'moe': MoEFFNLayer,
+    'gated': GatedVariantsMLP
+}
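
Because the encoder and decoder now resolve their feed-forward block through this table, adding another variant only requires a class that accepts the same four positional constructor arguments (model dim, hidden units, dropout rate, activation) and has a shape-preserving forward, registered under a new key. A hypothetical example; IdentityMLP and the 'identity' key are not part of WeNet:

import torch

from wenet.utils.class_utils import WENET_MLP_CLASSES


class IdentityMLP(torch.nn.Module):
    """Hypothetical no-op feed-forward, only to illustrate the expected interface."""

    def __init__(self, idim, hidden_units, dropout_rate, activation):
        super().__init__()

    def forward(self, xs):
        return xs  # (..., idim) in, (..., idim) out


WENET_MLP_CLASSES['identity'] = IdentityMLP  # selectable via mlp_type='identity'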
