diff --git a/wenet/branchformer/encoder.py b/wenet/branchformer/encoder.py
index 5c909f3fd..57ebe7cf0 100644
--- a/wenet/branchformer/encoder.py
+++ b/wenet/branchformer/encoder.py
@@ -16,7 +16,7 @@
 
 """Encoder definition."""
 import torch
-from typing import List, Union
+from typing import List, Optional, Union
 
 from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
 from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
@@ -55,10 +55,15 @@ def __init__(
         global_cmvn: torch.nn.Module = None,
         use_dynamic_left_chunk: bool = False,
         causal: bool = False,
+        query_bias: bool = True,
+        key_bias: bool = True,
+        value_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
     ):
         super().__init__(input_size, output_size, attention_heads,
                          cgmlp_linear_units, num_blocks, dropout_rate,
@@ -67,12 +72,17 @@ def __init__(
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
                          use_sdpa, layer_norm_type, norm_eps)
-        # super().__init__()
         encoder_selfattn_layer_args = (
             attention_heads,
             output_size,
             attention_dropout_rate,
+            query_bias,
+            key_bias,
+            value_bias,
+            use_sdpa,
+            n_kv_head,
+            head_dim,
         )
 
         cgmlp_layer = ConvolutionalGatingMLP
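
A minimal construction sketch (not part of the patch; the numeric values below are illustrative assumptions) showing how the new keyword arguments would be passed. Leaving n_kv_head and head_dim as None keeps standard multi-head attention; the three bias flags toggle the bias terms of the query/key/value projections, and use_sdpa is now forwarded into encoder_selfattn_layer_args as well.

    from wenet.branchformer.encoder import BranchformerEncoder

    # Defaults: all projection biases enabled, standard multi-head attention.
    encoder = BranchformerEncoder(input_size=80)

    # Hypothetical configuration exercising the new arguments:
    # fewer key/value heads than attention heads (grouped-query style)
    # plus an explicit per-head dimension.
    gqa_encoder = BranchformerEncoder(
        input_size=80,      # feature dimension (assumed value)
        attention_heads=8,
        key_bias=False,     # drop only the key projection bias
        n_kv_head=2,        # 2 KV heads shared across the 8 query heads
        head_dim=64,        # per-head dimension override (assumed value)
    )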