From bfaa8a3e61a724f0122f14c66f33583063ee5b68 Mon Sep 17 00:00:00 2001
From: "di.wu"
Date: Thu, 7 Dec 2023 23:11:14 +0800
Subject: [PATCH] [train] u2++-lite training support

[train] add instructions for use
---
 examples/aishell/s0/README.md                      | 13 +++
 .../s0/conf/train_u2++_lite_conformer.yaml         | 91 +++++++++++++++++++
 examples/aishell/s0/run.sh                         |  3 +
 wenet/k2/model.py                                  | 10 +-
 wenet/paraformer/paraformer.py                     | 13 ++-
 wenet/transformer/asr_model.py                     | 64 ++++++++++---
 wenet/transformer/ctc.py                           |  9 +-
 wenet/utils/executor.py                            |  1 -
 wenet/utils/train_utils.py                         | 26 +++++-
 9 files changed, 201 insertions(+), 29 deletions(-)
 create mode 100644 examples/aishell/s0/conf/train_u2++_lite_conformer.yaml

diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md
index 85ba82b33c..e3ba79f20c 100644
--- a/examples/aishell/s0/README.md
+++ b/examples/aishell/s0/README.md
@@ -32,6 +32,19 @@
 | HLG(k2 LM) + attention rescoring         | 4.32 | 4.70 |
 | HLG(k2 LM) + attention rescoring + LFMMI | 4.11 | 4.47 |
 
+## U2++ lite Conformer Result (uio shard)
+
+* Feature info: using fbank feature, dither=1.0, cmvn, online speed perturb
+* Training info: lr 0.001, batch size 16, 8 GPUs, acc_grad 1, load a well trained model and continue training for 80 epochs with the u2++ lite config
+* Decoding info: ctc_weight 0.3, reverse_weight 0.5, average_num 30
+* Git hash: 73185808fa1463b0163a922dc722513b7baabe9e
+
+| decoding mode/chunk size | full | 16   |
+|--------------------------|------|------|
+| ctc greedy search        | 5.21 | 5.91 |
+| ctc prefix beam search   | 5.20 | 5.91 |
+| attention rescoring      | 4.67 | 5.10 |
+
 ## Unified Conformer Result
 
 * Feature info: using fbank feature, dither=0, cmvn, oneline speed perturb
diff --git a/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml
new file mode 100644
index 0000000000..1eb280de2b
--- /dev/null
+++ b/examples/aishell/s0/conf/train_u2++_lite_conformer.yaml
@@ -0,0 +1,91 @@
+# network architecture
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+    cnn_module_kernel: 8
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rel_pos'
+    selfattention_layer_type: 'rel_selfattn'
+    causal: true
+    use_dynamic_chunk: true
+    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
+    use_dynamic_left_chunk: false
+
+# decoder related
+decoder: bitransformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 1024
+    num_blocks: 3
+    r_num_blocks: 3
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+    reverse_weight: 0.3
+    apply_non_blank_embedding: true # warning: it is best to use a well trained model as the init model
+
+dataset_conf:
+    filter_conf:
+        max_length: 40960
+        min_length: 0
+        token_max_length: 200
+        token_min_length: 1
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: true
+    fbank_conf:
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 50
+        max_f: 10
+    spec_sub: true
+    spec_sub_conf:
+        num_t_sub: 3
+        max_t: 30
+    spec_trim: false
+    spec_trim_conf:
+        max_t: 50
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1500
+    sort: true
+    sort_conf:
+        sort_size: 500  # sort_size should be less than shuffle_size
+    batch_conf:
+        batch_type: 'static' # static or dynamic
+        batch_size: 16
+
+grad_clip: 5
+accum_grad: 1
+max_epoch: 360
+log_interval: 100
+
+optim: adam
+optim_conf:
+    lr: 0.001
+scheduler: warmuplr     # pytorch v1.1.0+ required
+scheduler_conf:
+    warmup_steps: 25000
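The `ctc_weight` and `reverse_weight` entries in `model_conf` above control how the training loss is mixed. A minimal illustrative sketch of that combination (the loss values below are dummies, not real training numbers):

```python
# Illustrative sketch of how ctc_weight and reverse_weight are typically
# combined in the hybrid CTC/attention loss; all values below are dummies.
import torch

ctc_weight = 0.3      # model_conf: ctc_weight
reverse_weight = 0.3  # model_conf: reverse_weight

loss_ctc = torch.tensor(2.0)  # CTC branch loss (dummy)
loss_l2r = torch.tensor(1.5)  # left-to-right attention decoder loss (dummy)
loss_r2l = torch.tensor(1.7)  # right-to-left attention decoder loss (dummy)

# the bitransformer decoder mixes the two directions with reverse_weight
loss_att = (1.0 - reverse_weight) * loss_l2r + reverse_weight * loss_r2l
# the final training loss mixes the CTC and attention branches with ctc_weight
loss = ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
print(loss.item())
```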
diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh
index 4abce735f2..3f940d1ea5 100644
--- a/examples/aishell/s0/run.sh
+++ b/examples/aishell/s0/run.sh
@@ -38,6 +38,9 @@ train_set=train
 # 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer
 # 5. conf/train_u2++_conformer.yaml: U2++ conformer
 # 6. conf/train_u2++_transformer.yaml: U2++ transformer
+# 7. conf/train_u2++_lite_conformer.yaml: U2++ lite conformer, must load a well
+#    trained model and freeze the encoder module, otherwise there will be an
+#    autograd error
 train_config=conf/train_conformer.yaml
 cmvn=true
 dir=exp/conformer
diff --git a/wenet/k2/model.py b/wenet/k2/model.py
index f38c0c6df4..271e450da2 100644
--- a/wenet/k2/model.py
+++ b/wenet/k2/model.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
@@ -49,9 +49,9 @@ def __init__(
     @torch.jit.ignore(drop=True)
     def _forward_ctc(self, encoder_out: torch.Tensor,
                      encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> torch.Tensor:
-        loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
-        return loss_ctc
+                     text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
+        return loss_ctc, ctc_probs
 
     @torch.jit.ignore(drop=True)
     def load_lfmmi_resource(self):
@@ -106,7 +106,7 @@ def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
             for i in text
         ]
         loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
-        return loss
+        return loss, ctc_probs
 
     def load_hlg_resource_if_necessary(self, hlg, word):
         try:
diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py
index d36972a472..32df2dde10 100644
--- a/wenet/paraformer/paraformer.py
+++ b/wenet/paraformer/paraformer.py
@@ -136,8 +136,9 @@ def forward(
         # 3.1 ctc branhch
         loss_ctc: Optional[torch.Tensor] = None
         if self.ctc_weight != 0.0:
-            loss_ctc = self._forward_ctc(encoder_out, encoder_out_mask, text,
-                                         text_lengths)
+            loss_ctc, ctc_probs = self._forward_ctc(encoder_out,
+                                                    encoder_out_mask,
+                                                    text, text_lengths)
         # TODO(Mddc): thu acc
         loss_decoder = self.criterion_att(decoder_out, ys_pad)
         loss = loss_decoder
@@ -152,10 +153,12 @@
     @torch.jit.ignore(drop=True)
     def _forward_ctc(self, encoder_out: torch.Tensor,
                      encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> torch.Tensor:
+                     text_lengths: torch.Tensor,
+                     ) -> Tuple[torch.Tensor, torch.Tensor]:
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
-        return loss_ctc
+        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
+                                       text, text_lengths)
+        return loss_ctc, ctc_probs
 
     @torch.jit.ignore(drop=True)
     def _sampler(self, encoder_out, encoder_out_mask,
                  ys_pad, ys_pad_lens,
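As the run.sh note above says, u2++ lite training continues from an already well trained (non-lite) model. A minimal sketch of that initialization step, assuming the checkpoint stores a plain PyTorch state_dict; the path and model object are placeholders (in the recipe this is driven by the `checkpoint` variable passed to training):

```python
# Minimal sketch of initializing the u2++ lite model from a well trained
# checkpoint before continuing training; ckpt_path and model are placeholders.
import torch

def init_from_pretrained(model: torch.nn.Module, ckpt_path: str) -> None:
    # load on CPU so distributed workers do not all touch the same GPU
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # strict=False: parameter names shared with the source model (encoder,
    # decoder, ctc) are copied; anything not present is simply reported
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print('missing keys:', missing)
    print('unexpected keys:', unexpected)
```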
diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index 193cb46a2a..bac9fafeaa 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -16,6 +16,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from torch.nn.utils.rnn import pad_sequence
 
 from wenet.transformer.ctc import CTC
 from wenet.transformer.decoder import TransformerDecoder
@@ -25,6 +26,7 @@
                                         ctc_prefix_beam_search,
                                         attention_beam_search,
                                         attention_rescoring, DecodeResult)
+from wenet.utils.mask import make_pad_mask
 from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy,
                                 reverse_pad_list)
 from wenet.utils.context_graph import ContextGraph
@@ -45,6 +47,7 @@ def __init__(
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
         special_tokens: dict = None,
+        apply_non_blank_embedding: bool = False,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
@@ -65,6 +68,7 @@ def __init__(
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.reverse_weight = reverse_weight
+        self.apply_non_blank_embedding = apply_non_blank_embedding
 
         self.encoder = encoder
         self.decoder = decoder
@@ -102,19 +106,25 @@ def forward(
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
 
-        # 2a. Attention-decoder branch
+        # 2a. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
+                                           text_lengths)
+        else:
+            loss_ctc = None
+
+        # 2b. Attention-decoder branch
+        # use non blank (token level) embedding for decoder
+        if self.apply_non_blank_embedding:
+            assert self.ctc_weight != 0
+            encoder_out, encoder_mask = self.filter_blank_embedding(
+                ctc_probs, encoder_out)
         if self.ctc_weight != 1.0:
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                     text, text_lengths)
         else:
             loss_att = None
-
-        # 2b. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc = self._forward_ctc(encoder_out, encoder_mask, text,
-                                         text_lengths)
-        else:
-            loss_ctc = None
+            acc_att = None
 
         if loss_ctc is None:
             loss = loss_att
@@ -128,10 +138,39 @@ def forward(
     @torch.jit.ignore(drop=True)
     def _forward_ctc(self, encoder_out: torch.Tensor,
                      encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> torch.Tensor:
+                     text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         encoder_out_lens = encoder_mask.squeeze(1).sum(1)
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
-        return loss_ctc
+        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
+                                       text, text_lengths)
+        return loss_ctc, ctc_probs
+
+    def filter_blank_embedding(
+            self, ctc_probs: torch.Tensor,
+            encoder_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = encoder_out.size(0)
+        maxlen = encoder_out.size(1)
+        top1_index = torch.argmax(ctc_probs, dim=2)
+        indices = []
+        for j in range(batch_size):
+            indices.append(
+                torch.tensor(
+                    [i for i in range(maxlen) if top1_index[j][i] != 0]))
+
+        select_encoder_out = [
+            torch.index_select(encoder_out[i, :, :], 0,
+                               indices[i].to(encoder_out.device))
+            for i in range(batch_size)
+        ]
+        select_encoder_out = pad_sequence(select_encoder_out,
+                                          batch_first=True,
+                                          padding_value=0).to(
+                                              encoder_out.device)
+        xs_lens = torch.tensor([len(indices[i]) for i in range(batch_size)
+                                ]).to(encoder_out.device)
+        T = select_encoder_out.size(1)
+        encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+        encoder_out = select_encoder_out
+        return encoder_out, encoder_mask
 
     def _calc_att_loss(
         self,
@@ -257,6 +296,9 @@ def decode(
             else:
                 ctc_prefix_result = ctc_prefix_beam_search(
                     ctc_probs, encoder_lens, beam_size, context_graph)
+            if self.apply_non_blank_embedding:
+                encoder_out, _ = self.filter_blank_embedding(
+                    ctc_probs, encoder_out)
             results['attention_rescoring'] = attention_rescoring(
                 self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
                 reverse_weight)
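For reference, a small self-contained toy run of the blank-filtering idea implemented by `filter_blank_embedding` above; all tensors are random placeholders and blank is assumed to be id 0, as in the patch:

```python
# Toy demonstration of the non-blank frame filtering used by u2++ lite:
# frames whose CTC top-1 token is blank (id 0) are dropped before the decoder.
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
B, T, D, V = 2, 6, 4, 5                       # batch, frames, encoder dim, vocab
encoder_out = torch.randn(B, T, D)
ctc_probs = torch.randn(B, T, V).log_softmax(dim=2)

top1 = ctc_probs.argmax(dim=2)                # (B, T) best token id per frame
keep = [torch.nonzero(top1[b] != 0).squeeze(1) for b in range(B)]
filtered = [encoder_out[b, keep[b], :] for b in range(B)]
filtered = pad_sequence(filtered, batch_first=True)   # (B, T', D)
lengths = torch.tensor([len(k) for k in keep])
print(filtered.shape, lengths)                # only non-blank frames survive
```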
diff --git a/wenet/transformer/ctc.py b/wenet/transformer/ctc.py
index 08ec50630f..e1e2d2d896 100644
--- a/wenet/transformer/ctc.py
+++ b/wenet/transformer/ctc.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # Modified from ESPnet(https://github.com/espnet/espnet)
 
+from typing import Tuple
+
 import torch
 import torch.nn.functional as F
 
@@ -46,7 +48,9 @@ def __init__(
                                         reduction=reduction_type)
 
     def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
-                ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor:
+                ys_pad: torch.Tensor,
+                ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
         """Calculate CTC loss.
 
         Args:
@@ -63,7 +67,8 @@ def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
         loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
         # Batch-size average
         loss = loss / ys_hat.size(1)
-        return loss
+        ys_hat = ys_hat.transpose(0, 1)
+        return loss, ys_hat
 
     def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
         """log_softmax of frame activations
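The transpose before the new `return` matters: `torch.nn.CTCLoss` consumes `(T, B, V)` log-probs, while the blank filtering in `asr_model.py` indexes `(B, T, V)`. A small shape sketch with dummy sizes:

```python
# Shape sketch for the (loss, probs) return contract; all sizes are dummies.
import torch

B, T, V, L = 2, 50, 10, 7
ys_hat = torch.randn(T, B, V).log_softmax(2)   # (T, B, V), as CTCLoss expects
ys_pad = torch.randint(1, V, (B, L))           # non-blank target ids
hlens = torch.full((B,), T, dtype=torch.long)
ys_lens = torch.full((B,), L, dtype=torch.long)

ctc_loss = torch.nn.CTCLoss(blank=0, reduction='sum')
loss = ctc_loss(ys_hat, ys_pad, hlens, ys_lens) / ys_hat.size(1)  # batch average
ctc_probs = ys_hat.transpose(0, 1)             # back to (B, T, V) for callers
print(loss.item(), ctc_probs.shape)
```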
diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index cfe1084b0f..4a61fe952e 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -38,7 +38,6 @@ def train(self, model, optimizer, scheduler, data_loader, writer, configs,
         info_dict["tag"] = "TRAIN"
         logging.info('using accumulate grad, new batch size is {} times'
                      ' larger than before'.format(info_dict['accum_grad']))
-
         # A context manager to be used in conjunction with an instance of
         # torch.nn.parallel.DistributedDataParallel to be able to train
         # with uneven inputs across participating processes.
diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index 7c3abcf06c..93a8b012bb 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -66,12 +66,15 @@ def add_model_args(parser):
                         default=None,
                         type=str,
                         help="Pre-trained model to initialize encoder")
-    parser.add_argument(
-        "--enc_init_mods",
-        default="encoder.",
-        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
-        help="List of encoder modules \
+    parser.add_argument('--enc_init_mods',
+                        default="encoder.",
+                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+                        help="List of encoder modules \
                         to initialize ,separated by a comma")
+    parser.add_argument('--freeze_modules',
+                        default="",
+                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+                        help='freeze module names',)
     parser.add_argument('--lfmmi_dir',
                         default='',
                         required=False,
@@ -239,6 +242,12 @@ def check_modify_and_save_config(args, configs, symbol_table):
         data = yaml.dump(configs)
         fout.write(data)
 
+    if configs["model_conf"].get("apply_non_blank_embedding", False):
+        logging.warning(
+            'It is recommended to load a well trained model '
+            'if apply_non_blank_embedding is true !!!'
+        )
+
     return configs
@@ -601,3 +610,10 @@ def log_per_epoch(writer, info_dict):
     if int(os.environ.get('RANK', 0)) == 0:
         writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
         writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
+
+
+def freeze_modules(model, args):
+    for name, param in model.named_parameters():
+        for module_name in args.freeze_modules:
+            if module_name in name:
+                param.requires_grad = False
+                logging.debug("{} module is frozen".format(name))
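A quick usage sketch of the new `--freeze_modules` / `freeze_modules` pair on a toy model; the module names and argument namespace below are stand-ins, and the loop simply mirrors the helper added above:

```python
# Toy usage sketch mirroring freeze_modules(model, args): any parameter whose
# name contains an entry of args.freeze_modules stops receiving gradients.
import argparse
import torch

model = torch.nn.ModuleDict({
    'encoder': torch.nn.Linear(8, 8),
    'decoder': torch.nn.Linear(8, 8),
})
args = argparse.Namespace(freeze_modules=['encoder.'])

for name, param in model.named_parameters():
    for module_name in args.freeze_modules:
        if module_name in name:
            param.requires_grad = False

print([(n, p.requires_grad) for n, p in model.named_parameters()])
# encoder.* -> False (frozen), decoder.* -> True (still trained)
```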