format code and update lora attention && encoder
fclearner committed Mar 28, 2024
1 parent 2d61280 commit 24dcb27
Showing 3 changed files with 29 additions and 10 deletions.
10 changes: 7 additions & 3 deletions wenet/finetune/lora/attention.py
@@ -16,7 +16,7 @@
# limitations under the License.
"""Multi-Head Attention layer definition with lora."""

-from typing import List
+from typing import Optional, List

import torch
from torch import nn
@@ -43,10 +43,12 @@ def __init__(self,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False,
+n_kv_head: Optional[int] = None,
+head_dim: Optional[int] = None,
lora_rank: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.0,
-lora_list: List[str] = ['q', 'k', 'v', 'o']):
+lora_list: Optional[List[str]] = None):
"""Construct an MultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa)
@@ -90,10 +92,12 @@ def __init__(self,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False,
+n_kv_head: Optional[int] = None,
+head_dim: Optional[int] = None,
lora_rank: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.0,
-lora_list: List[str] = ['q', 'k', 'v', 'o']):
+lora_list: Optional[List[str]] = None):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa, lora_rank, lora_alpha,
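Note on the lora_list change above: a list literal used as a default argument is created once at definition time and shared across every call, so switching to Optional[List[str]] = None avoids the classic mutable-default pitfall. A minimal sketch of the usual pattern, assuming the constructor falls back to ['q', 'k', 'v', 'o'] when None is passed (the fallback itself is not shown in this diff):

from typing import List, Optional

class LoRAAttentionSketch:  # illustrative stand-in, not the wenet class
    def __init__(self, lora_list: Optional[List[str]] = None):
        # Assumed fallback; the diff only shows the signature change.
        self.lora_list = lora_list if lora_list is not None else ['q', 'k', 'v', 'o']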
28 changes: 21 additions & 7 deletions wenet/finetune/lora/encoder.py
@@ -16,7 +16,7 @@
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition with lora."""

-from typing import List
+from typing import Optional, List

import torch

@@ -61,10 +61,12 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
+n_kv_head: Optional[int] = None,
+head_dim: Optional[int] = None,
lora_rank: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.0,
-lora_list: List[str] = ['q', 'k', 'v', 'o'],
+lora_list: Optional[List[str]] = None,
):
""" Construct TransformerEncoder
@@ -75,8 +77,10 @@
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
-use_dynamic_left_chunk, gradient_checkpointing,
-use_sdpa, layer_norm_type, norm_eps)
+use_dynamic_left_chunk, query_bias, key_bias,
+value_bias, mlp_bias, activation_type,
+gradient_checkpointing, use_sdpa, mlp_type,
+layer_norm_type, norm_eps, n_kv_head, head_dim)
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
@@ -87,6 +91,7 @@ def __init__(
attention_dropout_rate,
query_bias, key_bias,
value_bias, use_sdpa,
+n_kv_head, head_dim,
lora_rank, lora_alpha,
lora_dropout,
lora_list),
@@ -138,10 +143,12 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
+n_kv_head: Optional[int] = None,
+head_dim: Optional[int] = None,
lora_rank: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.0,
-lora_list: List[str] = ['q', 'k', 'v', 'o'],
+lora_list: Optional[List[str]] = None,
):
"""Construct ConformerEncoder
@@ -165,8 +172,13 @@
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
-use_dynamic_left_chunk, gradient_checkpointing,
-use_sdpa, layer_norm_type, norm_eps)
+use_dynamic_left_chunk, positionwise_conv_kernel_size,
+macaron_style, selfattention_layer_type,
+activation_type, use_cnn_module, cnn_module_kernel,
+causal, cnn_module_norm, query_bias, key_bias,
+value_bias, mlp_bias, conv_bias,
+gradient_checkpointing, use_sdpa, mlp_type,
+layer_norm_type, norm_eps, n_kv_head, head_dim)
activation = WENET_ACTIVATION_CLASSES[activation_type]()

# self-attention module definition
@@ -178,6 +190,8 @@
key_bias,
value_bias,
use_sdpa,
+n_kv_head,
+head_dim,
lora_rank,
lora_alpha,
lora_dropout,
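Note on the super().__init__ changes above: the old calls forwarded only part of the base-class argument list positionally, so once the base encoder gained extra parameters the trailing values could bind to the wrong slots; the commit fixes this by forwarding the full list in the base class's declared order. A minimal sketch of that bug class, with keyword forwarding shown as a defensive alternative (names here are illustrative, not the wenet signatures):

class BaseEncoderSketch:
    def __init__(self, size: int, query_bias: bool = True, use_sdpa: bool = False):
        self.size, self.query_bias, self.use_sdpa = size, query_bias, use_sdpa

class LoRAEncoderSketch(BaseEncoderSketch):
    def __init__(self, size: int, query_bias: bool = True, use_sdpa: bool = False):
        # Keyword arguments cannot silently shift into the wrong parameter
        # if the base signature is extended or reordered later.
        super().__init__(size, query_bias=query_bias, use_sdpa=use_sdpa)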
1 change: 1 addition & 0 deletions wenet/utils/train_utils.py
@@ -130,6 +130,7 @@ def add_lora_args(parser):
default=0,
type=float,
help="lora dropout param.")
+return parser


def add_ddp_args(parser):
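Note on the return parser addition above: returning the parser lets add_lora_args be composed like the other add_*_args helpers instead of relying purely on in-place mutation. A hypothetical usage sketch (the flag name and chaining style are assumptions; only the dropout default, type, and help text appear in this diff):

import argparse

def add_lora_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Assumed flag name for illustration; the diff shows only these keyword arguments.
    parser.add_argument('--lora_dropout', default=0, type=float,
                        help='lora dropout param.')
    return parser

parser = add_lora_args(argparse.ArgumentParser(description='training args'))
args = parser.parse_args(['--lora_dropout', '0.1'])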
