support mqa gradient ckpt sdpa
Mddct committed Apr 16, 2024
1 parent 5b1005d commit a88f4db
Showing 1 changed file with 12 additions and 2 deletions.
wenet/branchformer/encoder.py (14 changes: 12 additions & 2 deletions)
@@ -16,7 +16,7 @@
"""Encoder definition."""

import torch
from typing import List, Union
from typing import List, Optional, Union

from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
@@ -55,10 +55,15 @@ def __init__(
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
causal: bool = False,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
super().__init__(input_size, output_size, attention_heads,
cgmlp_linear_units, num_blocks, dropout_rate,
@@ -67,12 +72,17 @@ def __init__(
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa, layer_norm_type, norm_eps)
# super().__init__()

encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
)

cgmlp_layer = ConvolutionalGatingMLP
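For context, a minimal usage sketch of the options this commit threads through the Branchformer encoder, assuming the remaining constructor arguments keep their existing defaults; the feature dimension, the example values, and the BaseEncoder-style forward call below are illustrative assumptions, not part of the change:

import torch
from wenet.branchformer.encoder import BranchformerEncoder

# Setting n_kv_head below attention_heads requests multi/grouped-query attention;
# head_dim fixes the per-head width used for the key/value projections.
encoder = BranchformerEncoder(
    input_size=80,                # assumed 80-dim fbank features
    output_size=256,
    attention_heads=8,
    n_kv_head=2,                  # 2 shared key/value heads for 8 query heads
    head_dim=64,                  # assumed explicit per-head dimension
    use_sdpa=True,                # route attention through torch SDPA
    gradient_checkpointing=True,  # recompute activations to save training memory
)

xs = torch.randn(4, 100, 80)              # (batch, time, feature), assumed layout
xs_lens = torch.tensor([100, 96, 80, 64])
out, out_mask = encoder(xs, xs_lens)      # assumed encoder forward signature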
