diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 6b6190a091..ed6a6ea82e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -19,6 +19,7 @@ from typing import Dict, List, Optional import k2 +import sentencepiece as spm import torch from model import Transducer @@ -34,6 +35,7 @@ def fast_beam_search_one_best( beam: float, max_states: int, max_contexts: int, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -56,6 +58,8 @@ def fast_beam_search_one_best( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -67,6 +71,7 @@ def fast_beam_search_one_best( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) best_path = one_best_decoding(lattice) @@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG( num_paths: int, nbest_scale: float = 0.5, use_double_scores: bool = True, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -120,6 +126,8 @@ def fast_beam_search_nbest_LG( use_double_scores: True to use double precision for computation. False to use single precision. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -131,6 +139,7 @@ def fast_beam_search_nbest_LG( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -201,6 +210,7 @@ def fast_beam_search_nbest( num_paths: int, nbest_scale: float = 0.5, use_double_scores: bool = True, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -236,6 +246,8 @@ def fast_beam_search_nbest( use_double_scores: True to use double precision for computation. False to use single precision. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -247,6 +259,7 @@ def fast_beam_search_nbest( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -282,6 +295,7 @@ def fast_beam_search_nbest_oracle( ref_texts: List[List[int]], use_double_scores: bool = True, nbest_scale: float = 0.5, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -321,7 +335,8 @@ def fast_beam_search_nbest_oracle( nbest_scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. - + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -333,6 +348,7 @@ def fast_beam_search_nbest_oracle( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -373,6 +389,7 @@ def fast_beam_search( beam: float, max_states: int, max_contexts: int, + temperature: float = 1.0, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -392,6 +409,8 @@ def fast_beam_search( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + temperature: + Softmax temperature. Returns: Return an FsaVec with axes [utt][state][arc] containing the decoded lattice. Note: When the input graph is a TrivialGraph, the returned @@ -440,7 +459,7 @@ def fast_beam_search( project_input=False, ) logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) + log_probs = (logits / temperature).log_softmax(dim=-1) decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) @@ -783,6 +802,7 @@ def modified_beam_search( encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, beam: int = 4, + temperature: float = 1.0, ) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -796,6 +816,8 @@ def modified_beam_search( encoder_out before padding. beam: Number of active paths during the beam search. + temperature: + Softmax temperature. Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. @@ -879,7 +901,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1043,6 +1067,7 @@ def beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + temperature: float = 1.0, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1056,6 +1081,8 @@ def beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -1132,7 +1159,7 @@ def beam_search( ) # TODO(fangjun): Scale the blank posterior - log_prob = logits.log_softmax(dim=-1) + log_prob = (logits / temperature).log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) log_prob = log_prob.squeeze() # Now log_prob is (vocab_size,) @@ -1171,3 +1198,155 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks return ys + + +def fast_beam_search_with_nbest_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, List[List[int]]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 44fc346405..8f55413e46 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -111,6 +111,7 @@ fast_beam_search_nbest_LG, fast_beam_search_nbest_oracle, fast_beam_search_one_best, + fast_beam_search_with_nbest_rescoring, greedy_search, greedy_search_batch, modified_beam_search, @@ -312,6 +313,35 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="""Softmax temperature. + The output of the model is (logits / temperature).log_softmax(). + """, + ) + + parser.add_argument( + "--lm-dir", + type=Path, + default=Path("./data/lm"), + help="""Used only when --decoding-method is + fast_beam_search_with_nbest_rescoring. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--words-txt", + type=Path, + default=Path("./data/lang_bpe_500/words.txt"), + help="""Used only when --decoding-method is + fast_beam_search_with_nbest_rescoring. + It is the word table. + """, + ) + add_model_arguments(parser) return parser @@ -324,6 +354,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + G: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -352,6 +383,11 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + G: + Optional. Used only when decoding method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, + or fast_beam_search_with_nbest_rescoring. + It an FsaVec containing an acceptor. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -397,6 +433,7 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -411,6 +448,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + temperature=params.temperature, ) for hyp in hyp_tokens: hyps.append([word_table[i] for i in hyp]) @@ -425,6 +463,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -440,6 +479,7 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -460,9 +500,32 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0] + ngram_lm_scale_list += [0.01, 0.02, 0.05] + ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8] + ngram_lm_scale_list += [1.0, 1.5, 2.5, 3] + hyp_tokens = fast_beam_search_with_nbest_rescoring( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ngram_lm_scale_list=ngram_lm_scale_list, + num_paths=params.num_paths, + G=G, + sp=sp, + word_table=word_table, + use_double_scores=True, + nbest_scale=params.nbest_scale, + temperature=params.temperature, + ) else: batch_size = encoder_out.size(0) @@ -496,6 +559,7 @@ def decode_one_batch( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" + f"temperature_{params.temperature}" ): hyps } elif params.decoding_method == "fast_beam_search": @@ -504,8 +568,23 @@ def decode_one_batch( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" + f"temperature_{params.temperature}" ): hyps } + elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + prefix = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}_" + f"temperature_{params.temperature}_" + ) + ans: Dict[str, List[List[str]]] = {} + for key, hyp in hyp_tokens.items(): + t: List[str] = sp.decode(hyp) + ans[prefix + key] = [s.split() for s in t] + return ans elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -515,10 +594,14 @@ def decode_one_batch( key += f"nbest_scale_{params.nbest_scale}" if "LG" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: hyps} else: - return {f"beam_size_{params.beam_size}": hyps} + return { + ( + f"beam_size_{params.beam_size}_" + f"temperature_{params.temperature}" + ): hyps + } def decode_dataset( @@ -528,6 +611,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + G: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -546,6 +630,11 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + G: + Optional. Used only when decoding method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, + or fast_beam_search_with_nbest_rescoring. + It's an FsaVec containing an acceptor. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -576,6 +665,7 @@ def decode_dataset( word_table=word_table, decoding_graph=decoding_graph, batch=batch, + G=G, ) for name, hyps in hyps_dict.items(): @@ -642,6 +732,71 @@ def save_results( logging.info(s) +def load_ngram_LM( + lm_dir: Path, word_table: k2.SymbolTable, device: torch.device +) -> k2.Fsa: + """Read a ngram model from the given directory. + Args: + lm_dir: + It should contain either G_4_gram.pt or G_4_gram.fst.txt + word_table: + The word table mapping words to IDs and vice versa. + device: + The resulting FSA will be moved to this device. + Returns: + Return an FsaVec containing a single acceptor. + """ + lm_dir = Path(lm_dir) + assert lm_dir.is_dir(), f"{lm_dir} does not exist" + + pt_file = lm_dir / "G_4_gram.pt" + + if pt_file.is_file(): + logging.info(f"Loading pre-compiled {pt_file}") + d = torch.load(pt_file, map_location=device) + G = k2.Fsa.from_dict(d) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + return G + + txt_file = lm_dir / "G_4_gram.fst.txt" + + assert txt_file.is_file(), f"{txt_file} does not exist" + logging.info(f"Loading {txt_file}") + logging.warning("It may take 8 minutes (Will be cached for later use).") + with open(txt_file) as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # Now G is an acceptor + + first_word_disambig_id = word_table["#0"] + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + + G = k2.Fsa.from_fsas([G]).to(device) + + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + logging.info(f"Saving to {pt_file} for later use") + torch.save(G.as_dict(), pt_file) + + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + return G + + @torch.no_grad() def main(): parser = get_parser() @@ -660,6 +815,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "fast_beam_search_with_nbest_rescoring", ) params.res_dir = params.exp_dir / params.decoding_method @@ -676,6 +832,7 @@ def main(): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-temperature-{params.temperature}" if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" @@ -685,9 +842,11 @@ def main(): params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" ) + params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-temperature-{params.temperature}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -760,6 +919,19 @@ def main(): torch.load(lg_filename, map_location=device) ) decoding_graph.scores *= params.ngram_lm_scale + elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + logging.info(f"Loading word symbol table from {params.words_txt}") + word_table = k2.SymbolTable.from_file(params.words_txt) + + G = load_ngram_LM( + lm_dir=params.lm_dir, + word_table=word_table, + device=device, + ) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + logging.info(f"G properties_str: {G.properties_str}") else: word_table = None decoding_graph = k2.trivial_graph( @@ -792,6 +964,7 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + G=G, ) save_results(