Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformer] Add moe_noisy_gate #2495

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
super().__init__()
attention_dim = encoder_output_size
Expand Down Expand Up @@ -131,7 +132,8 @@ def __init__(
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
n_expert_activated=n_expert_activated,
gate_type=gate_type),
dropout_rate,
normalize_before,
layer_norm_type,
Expand Down Expand Up @@ -360,6 +362,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal'
):

super().__init__()
Expand Down Expand Up @@ -393,7 +396,8 @@ def __init__(
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)
n_expert_activated=n_expert_activated,
gate_type=gate_type)

self.right_decoder = TransformerDecoder(
vocab_size,
Expand Down Expand Up @@ -423,7 +427,8 @@ def __init__(
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)
n_expert_activated=n_expert_activated,
gate_type=gate_type)

def forward(
self,
Expand Down
6 changes: 5 additions & 1 deletion wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
""" Construct TransformerEncoder

Expand Down Expand Up @@ -423,7 +424,8 @@ def __init__(
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
n_expert_activated=n_expert_activated,
gate_type=gate_type),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
Expand Down Expand Up @@ -474,6 +476,7 @@ def __init__(
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
"""Construct ConformerEncoder

Expand Down Expand Up @@ -522,6 +525,7 @@ def __init__(
mlp_bias,
n_expert,
n_expert_activated,
gate_type,
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
Expand Down
14 changes: 13 additions & 1 deletion wenet/transformer/positionwise_feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Positionwise feed forward layer definition."""

import torch

import torch.nn.functional as F

class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
Expand Down Expand Up @@ -66,6 +66,8 @@ class MoEFFNLayer(torch.nn.Module):

Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219

Noisy-gate reference from https://arxiv.org/pdf/1701.06538.pdf
Args:
n_expert: number of expert.
n_expert_activated: The actual number of experts used for each frame
Expand All @@ -84,6 +86,7 @@ def __init__(
bias: bool = False,
n_expert: int = 8,
n_expert_activated: int = 2,
gate_type: str = 'normal',
):
super(MoEFFNLayer, self).__init__()
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
Expand All @@ -93,6 +96,9 @@ def __init__(
for _ in range(n_expert))
self.n_expert = n_expert
self.n_expert_activated = n_expert_activated
self.gate_type = gate_type
if self.gate_type == 'noisy':
self.noisy_gate = torch.nn.Linear(idim, n_expert, bias=False)

def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Foward function.
Expand All @@ -106,6 +112,12 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
) # batch size, sequence length, embedding dimension (idim)
xs = xs.view(-1, D) # (B*L, D)
router = self.gate(xs) # (B*L, n_expert)
if self.gate_type == 'noisy':
noisy_router = self.noisy_gate(xs)
noisy_router = (
torch.randn_like(router) * F.softplus(noisy_router)
) * self.training
router = router + noisy_router
logits, selected_experts = torch.topk(
router, self.n_expert_activated
) # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
Expand Down
Loading