From 7a83dd1b3c57b93761db059d431b60d820b8914e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Sep 2021 15:48:27 +0800 Subject: [PATCH 01/12] Use new APIs with k2.RaggedTensor --- .gitignore | 2 +- egs/librispeech/ASR/conformer_ctc/decode.py | 19 ++ egs/librispeech/ASR/local/compile_hlg.py | 6 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 5 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 0 egs/yesno/ASR/local/compile_hlg.py | 6 +- egs/yesno/ASR/tdnn/decode.py | 1 + icefall/decode.py | 191 ++++++++++-------- icefall/lexicon.py | 11 +- icefall/utils.py | 25 ++- test/test_utils.py | 4 +- 11 files changed, 155 insertions(+), 115 deletions(-) mode change 100644 => 100755 egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py diff --git a/.gitignore b/.gitignore index 839a1c34a3..e6c84ca5e6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ path.sh exp exp*/ *.pt -download/ +download diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index ff6374d73b..cfdcff7565 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -45,6 +45,7 @@ get_texts, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -116,6 +117,17 @@ def get_parser(): """, ) + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + return parser @@ -541,6 +553,13 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 19a1ddd238..407fb7d888 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 - assert isinstance(LG.aux_labels, k2.RaggedInt) - LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) - LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) logging.info("Arc sorting LG") LG = k2.arc_sort(LG) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index afdebd12b2..87e9cddb42 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -99,8 +99,10 @@ def get_params() -> AttributeDict: # - nbest-rescoring # - whole-lattice-rescoring "method": "whole-lattice-rescoring", + # "method": "1best", + # "method": "nbest", # num_paths is used when method is "nbest" and "nbest-rescoring" - "num_paths": 30, + "num_paths": 100, } ) return params @@ -424,6 +426,7 @@ def main(): torch.save( {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" ) + return model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py old mode 100644 new mode 100755 diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index f2fafd0136..41a9274553 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 - assert isinstance(LG.aux_labels, k2.RaggedInt) - LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) - LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) logging.info("Arc sorting LG") LG = k2.arc_sort(LG) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index aa7b07b981..54fdbb3cc3 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -296,6 +296,7 @@ def main(): torch.save( {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" ) + return model.to(device) model.eval() diff --git a/icefall/decode.py b/icefall/decode.py index de32194018..3f6e5fc848 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -84,8 +84,8 @@ def _intersect_device( for start, end in splits: indexes = torch.arange(start, end).to(b_to_a_map) - fsas = k2.index(b_fsas, indexes) - b_to_a = k2.index(b_to_a_map, indexes) + fsas = k2.index_fsa(b_fsas, indexes) + b_to_a = k2.index_select(b_to_a_map, indexes) path_lattice = k2.intersect_device( a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a ) @@ -215,18 +215,16 @@ def nbest_decoding( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) - # Note: the above operation supports also the case when - # lattice.aux_labels is a ragged tensor. In that case, - # `remove_axis=True` is used inside the pybind11 binding code, - # so the resulting `word_seq` still has 3 axes, like `path`. - # The 3 axes are [seq][path][word_id] + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove 0 (epsilon) and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove sequences with identical word sequences. # @@ -234,12 +232,12 @@ def nbest_decoding( # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, _, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=False, need_new2old_indexes=True + unique_word_seq, _, new2old = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=True ) # Note: unique_word_seq still has the same axes as word_seq - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path belongs @@ -247,7 +245,7 @@ def nbest_decoding( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -275,35 +273,35 @@ def nbest_decoding( use_double_scores=use_double_scores, log_semiring=False ) - # RaggedFloat currently supports float32 only. - # If Ragged is wrapped, we can use k2.RaggedDouble here - ragged_tot_scores = k2.RaggedFloat( - seq_to_path_shape, tot_scores.to(torch.float32) - ) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + argmax_indexes = ragged_tot_scores.argmax() # Since we invoked `k2.ragged.unique_sequences`, which reorders # the index from `path`, we use `new2old` here to convert argmax_indexes # to the indexes into `path`. # # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos] + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][token_id] + # labels is a k2.RaggedTensor with 2 axes [path][token_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so + # aux_labels is also a k2.RaggedTensor with 2 axes + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels @@ -426,33 +424,36 @@ def rescore_with_n_best_list( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove epsilons and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove paths that has identical word sequences. # - # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] # except that there are no repeated paths with the same word_seq # within a sequence. # - # num_repeats is also a k2.RaggedInt with 2 axes containing the + # num_repeats is also a k2.RaggedTensor with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.tot_size(1) + # num_repeats.numel() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=True, need_new2old_indexes=True + unique_word_seq, num_repeats, new2old = word_seq.unique( + need_num_repeats=True, need_new2old_indexes=True ) - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -461,7 +462,7 @@ def rescore_with_n_best_list( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -485,39 +486,42 @@ def rescore_with_n_best_list( use_double_scores=True, log_semiring=False ) - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) ans = dict() for lm_scale in lm_scale_list: tot_scores = am_scores / lm_scale + lm_scores - # Remember that we used `k2.ragged.unique_sequences` to remove repeated + # Remember that we used `k2.RaggedTensor.unique` to remove repeated # paths to avoid redundant computation in `k2.intersect_device`. # Now we use `num_repeats` to correct the scores for each path. # # NOTE(fangjun): It is commented out as it leads to a worse WER # tot_scores = tot_scores * num_repeats.values() - ragged_tot_scores = k2.RaggedFloat( - seq_to_path_shape, tot_scores.to(torch.float32) - ) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) + argmax_indexes = ragged_tot_scores.argmax() # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # labels is a k2.RaggedTensor with 2 axes [path][phone_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so + # aux_labels is also a k2.RaggedTensor with 2 axes + + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels @@ -659,12 +663,16 @@ def nbest_oracle( scale=scale, ) - word_seq = k2.index(lattice.aux_labels, path) - word_seq = k2.ragged.remove_values_leq(word_seq, 0) - unique_word_seq, _, _ = k2.ragged.unique_sequences( - word_seq, need_num_repeats=False, need_new2old_indexes=False + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) + + word_seq = word_seq.remove_values_leq(0) + unique_word_seq, _, _ = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=False ) - unique_word_ids = k2.ragged.to_list(unique_word_seq) + unique_word_ids = unique_word_seq.tolist() assert len(unique_word_ids) == len(ref_texts) # unique_word_ids[i] contains all hypotheses of the i-th utterance @@ -743,33 +751,36 @@ def rescore_with_attention_decoder( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove epsilons and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove paths that has identical word sequences. # - # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] # except that there are no repeated paths with the same word_seq # within a sequence. # - # num_repeats is also a k2.RaggedInt with 2 axes containing the + # num_repeats is also a k2.RaggedTensor with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.tot_size(1) + # num_repeats.numel() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seq.tot_size(1) - unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=True, need_new2old_indexes=True + unique_word_seq, num_repeats, new2old = word_seq.unique( + need_num_repeats=True, need_new2old_indexes=True ) - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -778,7 +789,7 @@ def rescore_with_attention_decoder( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -796,20 +807,23 @@ def rescore_with_attention_decoder( # CAUTION: The "tokens" attribute is set in the file # local/compile_hlg.py - token_seq = k2.index(lattice.tokens, path) + if isinstance(lattice.tokens, torch.Tensor): + token_seq = k2.ragged.index(lattice.tokens, path) + else: + token_seq = lattice.tokens.index(path, remove_axis=True) # Remove epsilons and -1 from token_seq - token_seq = k2.ragged.remove_values_leq(token_seq, 0) + token_seq = token_seq.remove_values_leq(0) # Remove the seq axis. - token_seq = k2.ragged.remove_axis(token_seq, 0) + token_seq = token_seq.remove_axis(0) - token_seq, _ = k2.ragged.index( - token_seq, indexes=new2old, axis=0, need_value_indexes=False + token_seq, _ = token_seq.index( + indexes=new2old, axis=0, need_value_indexes=False ) # Now word in unique_word_seq has its corresponding token IDs. - token_ids = k2.ragged.to_list(token_seq) + token_ids = token_seq.tolist() num_word_seqs = new2old.numel() @@ -849,7 +863,7 @@ def rescore_with_attention_decoder( else: attention_scale_list = [attention_scale] - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) ans = dict() for n_scale in ngram_lm_scale_list: @@ -859,23 +873,28 @@ def rescore_with_attention_decoder( + n_scale * ngram_lm_scores + a_scale * attention_scores ) - ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) + argmax_indexes = ragged_tot_scores.argmax() - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][token_id] + # labels is a k2.RaggedTensor with 2 axes [path][token_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + if isinstance(lattice.aux_labels, torch.Tensor): + aux_labels = k2.index_select(lattice.aux_labels, best_path.data) + else: + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels diff --git a/icefall/lexicon.py b/icefall/lexicon.py index f1127c7cf8..6730bac493 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -157,7 +157,7 @@ def __init__( lang_dir / "lexicon.txt" ) - def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: + def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: """Read a BPE lexicon from file and convert it to a k2 ragged tensor. @@ -200,19 +200,18 @@ def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: ) values = torch.tensor(token_ids, dtype=torch.int32) - return k2.RaggedInt(shape, values) + return k2.RaggedTensor(shape, values) - def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt: + def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor: """Convert a list of words to a ragged tensor contained word piece IDs. """ word_ids = [self.word_table[w] for w in words] word_ids = torch.tensor(word_ids, dtype=torch.int32) - ragged, _ = k2.ragged.index( - self.ragged_lexicon, + ragged, _ = self.ragged_lexicon.index( indexes=word_ids, - need_value_indexes=False, axis=0, + need_value_indexes=False, ) return ragged diff --git a/icefall/utils.py b/icefall/utils.py index 2994c2d475..b78dfe8ad8 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -199,26 +199,25 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: Returns a list of lists of int, containing the label sequences we decoded. """ - if isinstance(best_paths.aux_labels, k2.RaggedInt): + if isinstance(best_paths.aux_labels, k2.RaggedTensor): # remove 0's and -1's. - aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0) - aux_shape = k2r.compose_ragged_shapes( - best_paths.arcs.shape(), aux_labels.shape() - ) + aux_labels = best_paths.aux_labels.remove_values_leq(0) + # TODO: change arcs.shape() to arcs.shape + aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) # remove the states and arcs axes. - aux_shape = k2r.remove_axis(aux_shape, 1) - aux_shape = k2r.remove_axis(aux_shape, 1) - aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) + aux_shape = aux_shape.remove_axis(1) + aux_shape = aux_shape.remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data) else: # remove axis corresponding to states. - aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1) - aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) + aux_shape = best_paths.arcs.shape().remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) # remove 0's and -1's. - aux_labels = k2r.remove_values_leq(aux_labels, 0) + aux_labels = aux_labels.remove_values_leq(0) - assert aux_labels.num_axes() == 2 - return k2r.to_list(aux_labels) + assert aux_labels.num_axes == 2 + return aux_labels.tolist() def store_transcripts( diff --git a/test/test_utils.py b/test/test_utils.py index 2dd79689f6..b4c9358fd1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -60,7 +60,7 @@ def test_get_texts_ragged(): 4 """ ) - fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]") + fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]") fsa2 = k2.Fsa.from_str( """ @@ -70,7 +70,7 @@ def test_get_texts_ragged(): 3 """ ) - fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]") + fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]") fsas = k2.Fsa.from_fsas([fsa1, fsa2]) texts = get_texts(fsas) assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]] From 4d06ca4d4506a42b1008ad0d3f4644764adf5a83 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Sep 2021 16:57:55 +0800 Subject: [PATCH 02/12] Fix style issues. --- docs/source/conf.py | 1 - .../images/device-CPU_CUDA-orange.svg | 2 +- .../images/os-Linux_macOS-ff69b4.svg | 2 +- .../images/python-3.6_3.7_3.8_3.9-blue.svg | 2 +- ...6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg | 2 +- docs/source/recipes/index.rst | 1 - .../recipes/librispeech/tdnn_lstm_ctc.rst | 18 +++++++++--------- egs/librispeech/ASR/RESULTS.md | 1 - .../ASR/conformer_ctc/test_subsampling.py | 3 +-- .../ASR/conformer_ctc/test_transformer.py | 9 ++++----- icefall/utils.py | 1 - test/test_bpe_graph_compiler.py | 3 ++- 12 files changed, 20 insertions(+), 25 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index f97f72d54f..599df8b3e4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,6 @@ import sphinx_rtd_theme - # -- Project information ----------------------------------------------------- project = "icefall" diff --git a/docs/source/installation/images/device-CPU_CUDA-orange.svg b/docs/source/installation/images/device-CPU_CUDA-orange.svg index b760102e39..a023a1283c 100644 --- a/docs/source/installation/images/device-CPU_CUDA-orange.svg +++ b/docs/source/installation/images/device-CPU_CUDA-orange.svg @@ -1 +1 @@ -device: CPU | CUDAdeviceCPU | CUDA \ No newline at end of file +device: CPU | CUDAdeviceCPU | CUDA diff --git a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg index 44c1127477..178813ed47 100644 --- a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg +++ b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg @@ -1 +1 @@ -os: Linux | macOSosLinux | macOS \ No newline at end of file +os: Linux | macOSosLinux | macOS diff --git a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg index 676feba2c6..befc1e19ea 100644 --- a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg +++ b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg @@ -1 +1 @@ -python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 \ No newline at end of file +python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 diff --git a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg index d9b0efe1a7..496e5a9efc 100644 --- a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg +++ b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg @@ -1 +1 @@ -torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0torch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0 \ No newline at end of file +torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0torch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0 diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index db34fdca5e..36f8dfc394 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future. yesno librispeech - diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst index a59f34db76..64f0a6a08f 100644 --- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -209,7 +209,7 @@ After downloading, you will have the following files: |-- 1221-135766-0001.flac |-- 1221-135766-0002.flac `-- trans.txt - + 6 directories, 10 files @@ -256,14 +256,14 @@ The output is: 2021-08-24 16:57:28,098 INFO [pretrained.py:266] ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - + + 2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done @@ -297,14 +297,14 @@ The decoding output is: 2021-08-24 16:39:54,010 INFO [pretrained.py:266] ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - + + 2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index dfc412672b..d4acf92068 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -43,4 +43,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE |--|--| |test-clean|0.8| |test-other|0.9| - diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py index e3361d0c98..81fa234dd5 100755 --- a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py index b90215274b..667057c513 100644 --- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -17,17 +17,16 @@ import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/icefall/utils.py b/icefall/utils.py index b78dfe8ad8..1130d8947a 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -26,7 +26,6 @@ from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 -import k2.ragged as k2r import kaldialign import torch import torch.distributed as dist diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py index 67d300b7d2..e58c4f1c63 100755 --- a/test/test_bpe_graph_compiler.py +++ b/test/test_bpe_graph_compiler.py @@ -16,9 +16,10 @@ # limitations under the License. +from pathlib import Path + from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.lexicon import BpeLexicon -from pathlib import Path def test(): From c43dc893f59f3e32d9ff84cb5713a6f041806d75 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 8 Sep 2021 10:50:38 +0800 Subject: [PATCH 03/12] Update the installation doc, saying it requires at least k2 v1.7 --- docs/source/installation/images/k2-v-1.7.svg | 1 + docs/source/installation/index.rst | 23 +++++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) create mode 100644 docs/source/installation/images/k2-v-1.7.svg diff --git a/docs/source/installation/images/k2-v-1.7.svg b/docs/source/installation/images/k2-v-1.7.svg new file mode 100644 index 0000000000..8a74d0b55e --- /dev/null +++ b/docs/source/installation/images/k2-v-1.7.svg @@ -0,0 +1 @@ +k2: >= v1.7k2>= v1.7 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index bcef669c84..c11cbd1be9 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -7,6 +7,7 @@ Installation - |device| - |python_versions| - |torch_versions| +- |k2_versions| .. |os| image:: ./images/os-Linux_macOS-ff69b4.svg :alt: Supported operating systems @@ -20,7 +21,10 @@ Installation .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg :alt: Supported PyTorch versions -icefall depends on `k2 `_ and +.. |k2_versions| image:: ./images/k2-v-1.7.svg + :alt: Supported k2 versions + +``icefall`` depends on `k2 `_ and `lhotse `_. We recommend you to install ``k2`` first, as ``k2`` is bound to @@ -32,12 +36,16 @@ installs its dependency PyTorch, which can be reused by ``lhotse``. -------------- Please refer to ``_ -to install `k2`. +to install ``k2``. + +.. CAUTION:: + + You need to install ``k2`` with a version at least **v1.7**. .. HINT:: If you have already installed PyTorch and don't want to replace it, - please install a version of k2 that is compiled against the version + please install a version of ``k2`` that is compiled against the version of PyTorch you are using. (2) Install lhotse @@ -50,10 +58,15 @@ to install ``lhotse``. Install ``lhotse`` also installs its dependency `torchaudio `_. +.. CAUTION:: + + If you have installed ``torchaudio``, please consider uninstalling it before + installing ``lhotse``. Otherwise, it may update your already installed PyTorch. + (3) Download icefall -------------------- -icefall is a collection of Python scripts, so you don't need to install it +``icefall`` is a collection of Python scripts, so you don't need to install it and we don't provide a ``setup.py`` to install it. What you need is to download it and set the environment variable ``PYTHONPATH`` @@ -367,7 +380,7 @@ Now let us run the training part: .. CAUTION:: - We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU + We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU even if there are GPUs available. The training log is given below: From 2cb438c3f081b9fa01e012998a8601cc78c56c7a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 8 Sep 2021 19:32:14 +0800 Subject: [PATCH 04/12] Extract framewise alignment information using CTC decoding. --- egs/librispeech/ASR/conformer_ctc/ali.py | 213 +++++++++++++++++++++++ icefall/decode.py | 2 +- icefall/utils.py | 22 +++ 3 files changed, 236 insertions(+), 1 deletion(-) create mode 100755 egs/librispeech/ASR/conformer_ctc/ali.py diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py new file mode 100755 index 0000000000..8d779e850b --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path + +import k2 +import torch +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import one_best_decoding +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_alignments, + setup_logger, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang_bpe"), + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + } + ) + return params + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, + token_table: k2.SymbolTable, +): + device = graph_compiler.device + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + lattice = k2.intersect_dense( + decoding_graph, dense_fsa_vec, params.output_beam + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + ali_ids = get_alignments(best_path) + ali_tokens = [[token_table[i] for i in ids] for ids in ali_ids] + + frame_shift = 0.01 # 10ms, i.e., 0.01 seconds + for i, ali in enumerate(ali_tokens[0]): + print(i * params.subsampling_factor * frame_shift, ali) + import sys + + sys.exit(0) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert args.return_cuts is True + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/ali") + logging.info("Computing alignment - started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + is_espnet_structure=params.is_espnet_structure, + mmi_loss=params.mmi_loss, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to(device) + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + test_dl = librispeech.test_dataloaders() # a list + + enabled_datasets = { + "test_clean": test_dl[0], + "test_other": test_dl[1], + } + + compute_alignments( + model=model, + dl=enabled_datasets["test_clean"], + params=params, + graph_compiler=graph_compiler, + token_table=lexicon.token_table, + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/decode.py b/icefall/decode.py index 3f6e5fc848..e3a875b168 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -878,7 +878,7 @@ def rescore_with_attention_decoder( best_path_indexes = k2.index_select(new2old, argmax_indexes) - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] + # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos] best_path, _ = path_2axes.index( indexes=best_path_indexes, axis=0, need_value_indexes=False ) diff --git a/icefall/utils.py b/icefall/utils.py index 1130d8947a..1016bcd35c 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -219,6 +219,28 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: return aux_labels.tolist() +def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: + """Extract the token IDs (from best_paths.labels) from the best-path FSAs. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + Returns: + Returns a list of lists of int, containing the token sequences we + decoded. For `ans[i]`, its length equals to the number of frames + after subsampling of the i-th utterance in the batch. + """ + # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here + label_shape = best_paths.arcs.shape().remove_axis(1) + # label_shape has axes [fsa][arc] + labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous()) + labels = labels.remove_values_eq(-1) + return labels.tolist() + + def store_transcripts( filename: Pathlike, texts: Iterable[Tuple[str, str]] ) -> None: From 8f64fb9921b51324795bae1e0987ebf8253f5efd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Sep 2021 12:06:29 +0800 Subject: [PATCH 05/12] Print environment information. Print information about k2, lhotse, PyTorch, and icefall. --- egs/librispeech/ASR/RESULTS.md | 6 +- egs/librispeech/ASR/conformer_ctc/decode.py | 10 ++- .../ASR/conformer_ctc/pretrained.py | 3 +- egs/librispeech/ASR/conformer_ctc/train.py | 15 ++-- egs/librispeech/ASR/prepare.sh | 4 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 2 + .../ASR/tdnn_lstm_ctc/pretrained.py | 3 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 2 + egs/yesno/ASR/tdnn/decode.py | 2 + egs/yesno/ASR/tdnn/pretrained.py | 3 +- egs/yesno/ASR/tdnn/train.py | 3 +- icefall/utils.py | 90 ++++++++++++++++--- test/test_utils.py | 12 ++- 13 files changed, 129 insertions(+), 26 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index d04e912bfd..f58ba64515 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \ --concatenate-cuts False \ --max-duration 200 \ --full-libri True \ - --world-size 4 + --world-size 4 \ + --lang-dir data/lang_bpe_5000 python conformer_ctc/decode.py --lattice-score-scale 0.5 \ --epoch 34 \ --avg 20 \ --method attention-decoder \ --max-duration 20 \ - --num-paths 100 + --num-paths 100 \ + --lang-dir data/lang_bpe_5000 ``` ### LibriSpeech training results (Tdnn-Lstm) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 85161f7373..c6a6dd85d3 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -42,6 +42,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -128,6 +129,13 @@ def get_parser(): """, ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="lang directory", + ) + return parser @@ -135,7 +143,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), "feature_dim": 80, "nhead": 8, @@ -151,6 +158,7 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "env_info": get_env_info(), } ) return params diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 95029fadb6..574fafcfec 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -34,7 +34,7 @@ rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -224,6 +224,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index b0dbe72adb..e3242c943a 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -43,6 +43,7 @@ from icefall.utils import ( AttributeDict, encode_supervisions, + get_env_info, setup_logger, str2bool, ) @@ -74,6 +75,13 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="lang directory", + ) + parser.add_argument( "--num-epochs", type=int, @@ -108,9 +116,6 @@ def get_params() -> AttributeDict: - exp_dir: It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved - - lang_dir: It contains language related input files such as - "lexicon.txt" - - lr: It specifies the initial learning rate - feature_dim: The model input dim. It has to match the one used @@ -151,7 +156,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "weight_decay": 1e-6, "subsampling_factor": 4, @@ -160,7 +164,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 10, + "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, "beam_size": 10, @@ -176,6 +180,7 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "lr_factor": 5.0, "warm_step": 80000, + "env_info": get_env_info(), } ) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f06e013f60..3a68e0f239 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -41,6 +41,8 @@ dl_dir=$PWD/download # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( 5000 + 2000 + 1000 ) # All files generated by this script are saved in "data". @@ -190,5 +192,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi - -cd data && ln -sfv lang_bpe_5000 lang_bpe diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 23b2e794cd..4dda7818d1 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -39,6 +39,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -103,6 +104,7 @@ def get_params() -> AttributeDict: # "method": "nbest", # num_paths is used when method is "nbest" and "nbest-rescoring" "num_paths": 100, + "env_info": get_env_info(), } ) return params diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 4f82a989c7..523f36e3e1 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -34,7 +34,7 @@ one_best_decoding, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -159,6 +159,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 4d45d197b1..6144f4a546 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -44,6 +44,7 @@ from icefall.utils import ( AttributeDict, encode_supervisions, + get_env_info, setup_logger, str2bool, ) @@ -168,6 +169,7 @@ def get_params() -> AttributeDict: "beam_size": 10, "reduction": "sum", "use_double_scores": True, + "env_info": get_env_info(), } ) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 54fdbb3cc3..62d8bb9d70 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -17,6 +17,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -256,6 +257,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() setup_logger(f"{params.exp_dir}/log/log-decode") logging.info("Decoding started") diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index fb92110e34..5b85008a63 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -29,7 +29,7 @@ from torch.nn.utils.rnn import pad_sequence from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -116,6 +116,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 39c5ef3efb..f2e9866885 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -24,7 +24,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, setup_logger, str2bool +from icefall.utils import AttributeDict, get_env_info, setup_logger, str2bool def get_parser(): @@ -483,6 +483,7 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() fix_random_seed(42) if world_size > 1: diff --git a/icefall/utils.py b/icefall/utils.py index 1016bcd35c..ad08e4d8fa 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -19,14 +19,17 @@ import logging import os import subprocess +import sys from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union +from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union import k2 +import k2.version import kaldialign +import lhotse import torch import torch.distributed as dist @@ -132,17 +135,82 @@ def setup_logger( logging.getLogger("").addHandler(console) -def get_env_info(): - """ - TODO: - """ +def get_git_sha1(): + git_commit = ( + subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + dirty_commit = ( + len( + subprocess.run( + ["git", "diff", "--shortstat"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + > 0 + ) + git_commit = ( + git_commit + "-dirty" if dirty_commit else git_commit + "-clean" + ) + return git_commit + + +def get_git_date(): + git_date = ( + subprocess.run( + ["git", "log", "-1", "--format=%ad", "--date=local"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_git_branch_name(): + git_date = ( + subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_env_info() -> Dict[str, Any]: + """Get the environment information.""" return { - "k2-git-sha1": None, - "k2-version": None, - "lhotse-version": None, - "torch-version": None, - "icefall-sha1": None, - "icefall-version": None, + "k2-version": k2.version.__version__, + "k2-build-type": k2.version.__build_type__, + "k2-with-cuda": k2.with_cuda, + "k2-git-sha1": k2.version.__git_sha1__, + "k2-git-date": k2.version.__git_date__, + "lhotse-version": lhotse.__version__, + "torch-cuda-available": torch.cuda.is_available(), + "torch-cuda-version": torch.version.cuda, + "python-version": sys.version[:3], + "icefall-git-branch": get_git_branch_name(), + "icefall-git-sha1": get_git_sha1(), + "icefall-git-date": get_git_date(), + "icefall-path": str(Path(__file__).resolve().parent.parent), + "k2-path": str(Path(k2.__file__).resolve()), + "lhotse-path": str(Path(lhotse.__file__).resolve()), } diff --git a/test/test_utils.py b/test/test_utils.py index b4c9358fd1..8b0c03e953 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,12 @@ import pytest import torch -from icefall.utils import AttributeDict, encode_supervisions, get_texts +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_env_info, + get_texts, +) @pytest.fixture @@ -108,3 +113,8 @@ def test_attribute_dict(): assert s["b"] == 20 s.c = 100 assert s["c"] == 100 + + +def test_get_env_info(): + s = get_env_info() + print(s) From 62b275962c2de2d50cda9fbcb8b66caf60a4ea46 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Sep 2021 13:49:16 +0800 Subject: [PATCH 06/12] Fix CI. --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c3025d7304..51c3fedcf2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,6 +50,7 @@ jobs: run: | python3 -m pip install --upgrade pip pytest pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + pip install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip install -r requirements.txt From d8bef0976141b8a1137c81a0a2439a12640b5284 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 16 Sep 2021 11:12:52 +0800 Subject: [PATCH 07/12] Fix CI. --- .github/workflows/test.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 003ca25c29..f2b505c868 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -46,6 +46,13 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install libnsdfile and libsox + if: startsWith(matrix.os, 'ubuntu') + run: | + sudo apt update + sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg + sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all + - name: Install Python dependencies run: | python3 -m pip install --upgrade pip pytest From 27a6d5e9cb1645cac1b2dc367be7af3db27828a3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 23 Sep 2021 20:12:39 +0800 Subject: [PATCH 08/12] Compute framewise alignment information of the LibriSpeech dataset. --- egs/librispeech/ASR/conformer_ctc/ali.py | 130 ++++++++++++++++++----- icefall/utils.py | 45 ++++++++ 2 files changed, 150 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 8d779e850b..07390f7e7e 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -18,6 +18,7 @@ import argparse import logging from pathlib import Path +from typing import List, Tuple import k2 import torch @@ -32,6 +33,7 @@ AttributeDict, encode_supervisions, get_alignments, + save_alignments, setup_logger, ) @@ -56,14 +58,33 @@ def get_parser(): "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="The lang dir", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--ali-dir", + type=str, + default="data/ali", + help="The experiment dir", + ) return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), "feature_dim": 80, "nhead": 8, @@ -71,8 +92,6 @@ def get_params() -> AttributeDict: "subsampling_factor": 4, "num_decoder_layers": 6, "vgg_frontend": False, - "is_espnet_structure": True, - "mmi_loss": False, "use_feat_batchnorm": True, "output_beam": 10, "use_double_scores": True, @@ -86,9 +105,31 @@ def compute_alignments( dl: torch.utils.data.DataLoader, params: AttributeDict, graph_compiler: BpeCtcTrainingGraphCompiler, - token_table: k2.SymbolTable, -): +) -> List[Tuple[str, List[int]]]: + """Compute the framewise alignments of a dataset. + + Args: + model: + The neural network model. + dl: + Dataloader containing the dataset. + params: + Parameters for computing alignments. + graph_compiler: + It converts token IDs to decoding graphs. + Returns: + Return a list of tuples. Each tuple contains two entries: + - Utterance ID + - Framewise alignments (token IDs) after subsampling + """ + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + device = graph_compiler.device + ans = [] for batch_idx, batch in enumerate(dl): feature = batch["inputs"] @@ -97,11 +138,23 @@ def compute_alignments( feature = feature.to(device) supervisions = batch["supervisions"] + + cut_ids = [] + for cut in supervisions["cut"]: + assert len(cut.supervisions) == 1 + cut_ids.append(cut.id) + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) # nnet_output is [N, T, C] supervision_segments, texts = encode_supervisions( supervisions, subsampling_factor=params.subsampling_factor ) + # we need also to sort cut_ids as encode_supervisions() + # reorders "texts". + # In general, new2old is an identity map since lhotse sorts the returned + # cuts by duration in descending order + new2old = supervision_segments[:, 0].tolist() + cut_ids = [cut_ids[i] for i in new2old] token_ids = graph_compiler.texts_to_ids(texts) decoding_graph = graph_compiler.compile(token_ids) @@ -113,22 +166,30 @@ def compute_alignments( ) lattice = k2.intersect_dense( - decoding_graph, dense_fsa_vec, params.output_beam + decoding_graph, + dense_fsa_vec, + params.output_beam, ) best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores + lattice=lattice, + use_double_scores=params.use_double_scores, ) ali_ids = get_alignments(best_path) - ali_tokens = [[token_table[i] for i in ids] for ids in ali_ids] + assert len(ali_ids) == len(cut_ids) + ans += list(zip(cut_ids, ali_ids)) + + num_cuts += len(ali_ids) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" - frame_shift = 0.01 # 10ms, i.e., 0.01 seconds - for i, ali in enumerate(ali_tokens[0]): - print(i * params.subsampling_factor * frame_shift, ali) - import sys + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) - sys.exit(0) + return ans @torch.no_grad() @@ -138,6 +199,7 @@ def main(): args = parser.parse_args() assert args.return_cuts is True + assert args.concatenate_cuts is False params = get_params() params.update(vars(args)) @@ -169,9 +231,7 @@ def main(): num_classes=num_classes, subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - is_espnet_structure=params.is_espnet_structure, - mmi_loss=params.mmi_loss, + vgg_frontend=params.vgg_frontend, use_feat_batchnorm=params.use_feat_batchnorm, ) @@ -190,20 +250,40 @@ def main(): model.eval() librispeech = LibriSpeechAsrDataModule(args) + + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() test_dl = librispeech.test_dataloaders() # a list + ali_dir = Path(params.ali_dir) + ali_dir.mkdir(exist_ok=True) + enabled_datasets = { "test_clean": test_dl[0], "test_other": test_dl[1], + "train-960": train_dl, + "valid": valid_dl, } - - compute_alignments( - model=model, - dl=enabled_datasets["test_clean"], - params=params, - graph_compiler=graph_compiler, - token_table=lexicon.token_table, - ) + for name, dl in enabled_datasets.items(): + logging.info(f"Processing {name}") + alignments = compute_alignments( + model=model, + dl=dl, + params=params, + graph_compiler=graph_compiler, + ) + num_utt = len(alignments) + alignments = dict(alignments) + assert num_utt == len(alignments) + filename = ali_dir / f"{name}.pt" + save_alignments( + alignments=alignments, + subsampling_factor=params.subsampling_factor, + filename=filename, + ) + logging.info( + f"For dataset {name}, its alignments are saved to {filename}" + ) torch.set_num_threads(1) diff --git a/icefall/utils.py b/icefall/utils.py index 36312f9afe..cb6cc17c5e 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -325,6 +325,51 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: return labels.tolist() +def save_alignments( + alignments: Dict[str, List[int]], + subsampling_factor: int, + filename: str, +) -> None: + """Save alignments to a file. + + Args: + alignments: + A dict containing alignments. Keys of the dict are utterances and + values are the corresponding framewise alignments after subsampling. + subsampling_factor: + The subsampling factor of the model. + filename: + Path to save the alignments. + Returns: + Return None. + """ + ali_dict = { + "subsampling_factor": subsampling_factor, + "alignments": alignments, + } + torch.save(ali_dict, filename) + + +def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: + """Load alignments from a file. + + Args: + filename: + Path to the file containing alignment information. + The file should be saved by :func:`save_alignments`. + Returns: + Return a tuple containing: + - subsampling_factor: The subsampling_factor used to compute + the alignments. + - alignments: A dict containing utterances and their corresponding + framewise alignment, after subsampling. + """ + ali_dict = torch.load(filename) + subsampling_factor = ali_dict["subsampling_factor"] + alignments = ali_dict["alignments"] + return subsampling_factor, alignments + + def store_transcripts( filename: Pathlike, texts: Iterable[Tuple[str, str]] ) -> None: From 0f3d9220d4d5d5ccca5bd75474de553fed3d528b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 25 Sep 2021 19:52:56 +0800 Subject: [PATCH 09/12] Update comments for the time to compute alignments of train-960. --- egs/librispeech/ASR/conformer_ctc/ali.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 07390f7e7e..aa5b6bc88c 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -76,7 +76,7 @@ def get_parser(): parser.add_argument( "--ali-dir", type=str, - default="data/ali", + default="data/ali_500", help="The experiment dir", ) return parser @@ -200,11 +200,15 @@ def main(): assert args.return_cuts is True assert args.concatenate_cuts is False + if args.full_libri is False: + print("Changing --full-libri to True") + args.full_libri = True params = get_params() params.update(vars(args)) setup_logger(f"{params.exp_dir}/log/ali") + logging.info("Computing alignment - started") logging.info(params) @@ -264,8 +268,19 @@ def main(): "train-960": train_dl, "valid": valid_dl, } + # For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to + # compute the alignments if you use --max-duration=500 + # + # There are 960 * 3 = 2880 hours data and it takes only + # 3 hours 40 minutes to get the alignment. + # The RTF is roughly: 3.67 / 2880 = 0.0012743 for name, dl in enabled_datasets.items(): logging.info(f"Processing {name}") + if name == "train-960": + logging.info( + "It will take about 3 hours 40 minutes for {name}, " + "which contains 960 * 3 = 2880 hours of data" + ) alignments = compute_alignments( model=model, dl=dl, From b27d67d2fbe2684ba8f19728ae089f038822192d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 Sep 2021 08:18:06 +0800 Subject: [PATCH 10/12] Preserve cut id in mix cut transformer. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8290e71d13..d3eab87a9c 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -162,7 +162,9 @@ def train_dataloaders(self) -> DataLoader: cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + transforms = [ + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ] if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " From e4c53881e68034f2da43c017b6a4ef3e56de5b88 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 Sep 2021 11:41:02 +0800 Subject: [PATCH 11/12] Minor fixes. --- egs/librispeech/ASR/conformer_ctc/ali.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index aa5b6bc88c..c79c4e277f 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -274,11 +274,15 @@ def main(): # There are 960 * 3 = 2880 hours data and it takes only # 3 hours 40 minutes to get the alignment. # The RTF is roughly: 3.67 / 2880 = 0.0012743 + # + # At the end, you would see + # 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa + # 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa for name, dl in enabled_datasets.items(): logging.info(f"Processing {name}") if name == "train-960": logging.info( - "It will take about 3 hours 40 minutes for {name}, " + f"It will take about 3 hours 40 minutes for {name}, " "which contains 960 * 3 = 2880 hours of data" ) alignments = compute_alignments( From 07140e5d5c56c7defa10b2e5af4673ac289d1f84 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 14:21:21 +0800 Subject: [PATCH 12/12] Add doc about how to extract framewise alignments. --- egs/librispeech/ASR/conformer_ctc/README.md | 50 +++++++++++++++++++++ egs/librispeech/ASR/conformer_ctc/ali.py | 4 +- egs/librispeech/ASR/conformer_ctc/decode.py | 2 +- egs/librispeech/ASR/conformer_ctc/export.py | 2 +- egs/librispeech/ASR/conformer_ctc/train.py | 2 +- 5 files changed, 56 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md index 23b51167b6..164c3e53e1 100644 --- a/egs/librispeech/ASR/conformer_ctc/README.md +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -1,3 +1,53 @@ +## Introduction + Please visit for how to run this recipe. + +## How to compute framewise alignment information + +### Step 1: Train a model + +Please use `conformer_ctc/train.py` to train a model. +See +for how to do it. + +### Step 2: Compute framewise alignment + +Run + +``` +# Choose a checkpoint and determine the number of checkpoints to average +epoch=30 +avg=15 +./conformer_ctc/ali.py \ + --epoch $epoch \ + --avg $avg \ + --max-duration 500 \ + --bucketing-sampler 0 \ + --full-libri 1 \ + --exp-dir conformer_ctc/exp \ + --lang-dir data/lang_bpe_5000 \ + --ali-dir data/ali_5000 +``` +and you will get four files inside the folder `data/ali_5000`: + +``` +$ ls -lh data/ali_500 +total 546M +-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt +-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt +-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt +-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt +``` + +**Note**: It can take more than 3 hours to compute the alignment +for the training dataset, which contains 960 * 3 = 2880 hours of data. + +**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those +in `conformer_ctc/train.py`. + +**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`. +Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`. + +**TODO:** Add doc about how to use the extracted alignment in the other pull-request. diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index c79c4e277f..3d817a8f6a 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -33,6 +33,7 @@ AttributeDict, encode_supervisions, get_alignments, + get_env_info, save_alignments, setup_logger, ) @@ -62,7 +63,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_5000", help="The lang dir", ) @@ -95,6 +96,7 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "output_beam": 10, "use_double_scores": True, + "env_info": get_env_info(), } ) return params diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3fb5d262dc..bddb832b08 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -143,7 +143,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_5000", help="The lang dir", ) diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 8241c84c11..79e026daca 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -65,7 +65,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_5000", help="""It contains language related input files such as "lexicon.txt" """, ) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index b76be8641d..ae088620f3 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -115,7 +115,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_5000", help="""The lang dir It contains language related input files such as "lexicon.txt"