From d513d456b8b51e6a52a2005efee71394d767ab48 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 30 Oct 2024 10:14:34 +0800 Subject: [PATCH] Add prefix beam search and corresponding decoding methods (#1786) * Add prefix beam search / shallow fussion / hotwords in librispeech ctc decode * Add librispeech cr-ctc prefix beam search results --- egs/librispeech/ASR/RESULTS.md | 9 +- egs/librispeech/ASR/zipformer/ctc_decode.py | 238 ++++++- icefall/decode.py | 674 +++++++++++++++++++- icefall/utils.py | 11 + 4 files changed, 908 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 6a669f072d..e5a82dfda2 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -153,6 +153,7 @@ You can use to deploy it. | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 | +| ctc-prefix-beam-search | 2.52 | 5.85 | --epoch 50 --avg 25 | The training command using 2 32G-V100 GPUs is: ```bash @@ -184,7 +185,7 @@ export CUDA_VISIBLE_DEVICES="0,1" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 50 \ --avg 25 \ @@ -212,6 +213,7 @@ You can use to deploy it. | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 | +| ctc-prefix-beam-search | 2.1 | 4.61 | --epoch 50 --avg 24 | The training command using 4 32G-V100 GPUs is: ```bash @@ -238,7 +240,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 50 \ --avg 24 \ @@ -262,6 +264,7 @@ You can use to deploy it. | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 | +| ctc-prefix-beam-search | 2.02 | 4.35 | --epoch 50 --avg 26 | The training command using 2 80G-A100 GPUs is: ```bash @@ -292,7 +295,7 @@ export CUDA_VISIBLE_DEVICES="0,1" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 50 \ --avg 26 \ diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 9db4299592..156989b780 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -111,6 +111,7 @@ import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -129,8 +130,14 @@ find_checkpoints, load_checkpoint, ) + +from icefall.context_graph import ContextGraph, ContextState + from icefall.decode import ( ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, get_lattice, nbest_decoding, nbest_oracle, @@ -140,7 +147,11 @@ rescore_with_n_best_list, rescore_with_whole_lattice, ) + +from icefall.ngram_lm import NgramLm, NgramLmStateCost from icefall.lexicon import Lexicon +from icefall.lm_wrapper import LmScorer + from icefall.utils import ( AttributeDict, get_texts, @@ -255,6 +266,12 @@ def get_parser(): lattice, rescore them with the attention decoder. - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM rescored lattice, rescore them with the attention decoder. + - (10) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. + - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with + the given beam, rescore them with the attention decoder. + - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fussion during + beam search, LODR and hotwords are also supported in this decoding method. """, ) @@ -280,6 +297,23 @@ def get_parser(): """, ) + parser.add_argument( + "--nnlm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--nnlm-scale", + type=float, + default=0, + help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion. + Used only when `--use-shallow-fusion` is set to True. + """, + ) + parser.add_argument( "--hlg-scale", type=float, @@ -297,11 +331,52 @@ def get_parser(): """, ) + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--lodr-ngram", + type=str, + help="The path to the lodr ngram", + ) + + parser.add_argument( + "--lodr-lm-scale", + type=float, + default=0, + help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.", + ) + + parser.add_argument( + "--context-score", + type=float, + default=0, + help=""" + The bonus score of each token for the context biasing words/phrases. + 0 means don't use contextual biasing. + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + parser.add_argument( "--skip-scoring", type=str2bool, default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""" + help="""Skip scoring, but still save the ASR output (for eval sets).""", ) add_model_arguments(parser) @@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "beam": 4, # for prefix-beam-search } ) return params @@ -333,6 +409,9 @@ def decode_one_batch( batch: dict, word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -377,10 +456,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = params.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -411,6 +487,51 @@ def decode_one_batch( key = "ctc-greedy-search" return {key: hyps} + if params.decoding_method == "ctc-prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, token_ids in best_path_dict.items(): + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + token_ids = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + NNLM=NNLM, + LODR_lm=LODR_lm, + LODR_lm_scale=params.lodr_lm_scale, + context_graph=context_graph, + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -584,6 +705,9 @@ def decode_dataset( bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -634,6 +758,9 @@ def decode_dataset( batch=batch, word_table=word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) for name, hyps in hyps_dict.items(): @@ -664,9 +791,7 @@ def save_asr_output( """ for key, results in results_dict.items(): - recogs_filename = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recogs_filename, texts=results) @@ -680,7 +805,8 @@ def save_wer_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.decoding_method in ( - "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring" + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", ): # Set it to False since there are too many logs. enable_log = False @@ -721,6 +847,7 @@ def save_wer_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -735,8 +862,11 @@ def main(): set_caching_enabled(True) # lhotse assert params.decoding_method in ( - "ctc-greedy-search", "ctc-decoding", + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", "1best", "nbest", "nbest-rescoring", @@ -762,6 +892,16 @@ def main(): params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_left-context-{params.left_context_frames}" + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + if params.nnlm_scale != 0: + params.suffix += f"_nnlm-scale-{params.nnlm_scale}" + if params.lodr_lm_scale != 0: + params.suffix += f"_lodr-scale-{params.lodr_lm_scale}" + if params.context_score != 0: + params.suffix += f"_context_score-{params.context_score}" + if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -771,6 +911,7 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + params.device = device logging.info(f"Device: {device}") logging.info(params) @@ -786,14 +927,24 @@ def main(): params.sos_id = 1 if params.decoding_method in [ - "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram" + "ctc-decoding", + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", + "attention-decoder-rescoring-no-ngram", ]: HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + H = None + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) bpe_model = spm.SentencePieceProcessor() bpe_model.load(str(params.lang_dir / "bpe.model")) else: @@ -844,7 +995,8 @@ def main(): G = k2.Fsa.from_dict(d) if params.decoding_method in [ - "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram" + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later @@ -858,6 +1010,51 @@ def main(): else: G = None + # only load the neural network LM if required + NNLM = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.nnlm_scale != 0 + ): + NNLM = LmScorer( + lm_type=params.nnlm_type, + params=params, + device=device, + lm_scale=params.nnlm_scale, + ) + NNLM.to(device) + NNLM.eval() + + LODR_lm = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.lodr_lm_scale != 0 + ): + assert os.path.exists( + params.lodr_ngram + ), f"LODR ngram does not exists, given path : {params.lodr_ngram}" + logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}") + LODR_lm = NgramLm( + params.lodr_ngram, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {LODR_lm.lm.num_states}") + + context_graph = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.context_score != 0 + ): + assert os.path.exists( + params.context_file + ), f"context_file does not exists, given path : {params.context_file}" + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(bpe_model.encode(line.strip())) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + logging.info("About to create model") model = get_model(params) @@ -967,6 +1164,9 @@ def main(): bpe_model=bpe_model, word_table=lexicon.word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) save_asr_output( diff --git a/icefall/decode.py b/icefall/decode.py index dd3af1e99b..5f90ee168c 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -15,11 +16,16 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Union +from dataclasses import dataclass, field +from multiprocessing.pool import Pool +from typing import Dict, List, Optional, Tuple, Union import k2 import torch +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer from icefall.utils import add_eos, add_sos, get_texts DEFAULT_LM_SCALE = [ @@ -1497,3 +1503,667 @@ def ctc_greedy_search( hyps = [h[h != blank_id].tolist() for h in hyps] return hyps + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] = field(default_factory=list) + + # The log prob of ys that ends with blank token. + # It contains only one entry. + log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32) + + # The log prob of ys that ends with non blank token. + # It contains only one entry. + log_prob_non_blank: torch.Tensor = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # The lm score of ys + # May contain external LM score (including LODR score) and contextual biasing score + # It contains only one entry + lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32) + + # the lm log_probs for next token given the history ys + # The number of elements should be equal to vocabulary size. + lm_log_probs: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # LODR (N-gram LM) state + LODR_state: Optional[NgramLmStateCost] = None + + # N-gram LM state + Ngram_state: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + # This is the total score of current path, acoustic plus external LM score. + @property + def tot_score(self) -> torch.Tensor: + return self.log_prob + self.lm_score + + # This is only the probability from model output (i.e External LM score not included). + @property + def log_prob(self) -> torch.Tensor: + return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank) + + @property + def key(self) -> tuple: + """Return a tuple representation of self.ys""" + return tuple(self.ys) + + def clone(self) -> "Hypothesis": + return Hypothesis( + ys=self.ys, + log_prob_blank=self.log_prob_blank, + log_prob_non_blank=self.log_prob_non_blank, + timestamp=self.timestamp, + lm_log_probs=self.lm_log_probs, + lm_score=self.lm_score, + state=self.state, + LODR_state=self.LODR_state, + Ngram_state=self.Ngram_state, + context_state=self.context_state, + ) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[tuple, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank + ) + torch.logaddexp( + old_hyp.log_prob_non_blank, + hyp.log_prob_non_blank, + out=old_hyp.log_prob_non_blank, + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `tot_score`. + Args: + length_norm: + If True, the `tot_score` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `tot_score`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.tot_score) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + Caution: + `self` is modified **in-place**. + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose tot_score is less than threshold. + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `tot_score` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.tot_score > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + Args: + length_norm: + If True, the `tot_score` of a hypothesis is normalized by the + number of tokens in it. + """ + hyps = list(self._data.items()) + + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: tuple): + return key in self._data + + def __getitem__(self, key: tuple): + return self._data[key] + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(str(s)) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def _step_worker( + log_probs: torch.Tensor, + indexes: torch.Tensor, + B: HypothesisList, + beam: int = 4, + blank_id: int = 0, + nnlm_scale: float = 0, + LODR_lm_scale: float = 0, + context_graph: Optional[ContextGraph] = None, +) -> HypothesisList: + """The worker to decode one step. + Args: + log_probs: + topk log_probs of current step (i.e. the kept tokens of first pass pruning), + the shape is (beam,) + topk_indexes: + The indexes of the topk_values above, the shape is (beam,) + B: + An instance of HypothesisList containing the kept hypothesis. + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + lm_scale: + The scale of nn lm. + LODR_lm_scale: + The scale of the LODR_lm + context_graph: + A ContextGraph instance containing contextual phrases. + Return: + Returns the updated HypothesisList. + """ + A = list(B) + B = HypothesisList() + for h in range(len(A)): + hyp = A[h] + for k in range(log_probs.size(0)): + log_prob, index = log_probs[k], indexes[k] + new_token = index.item() + update_prefix = False + new_hyp = hyp.clone() + if new_token == blank_id: + # Case 0: *a + ε => *a + # *aε + ε => *a + # Prefix does not change, update log_prob of blank + new_hyp.log_prob_non_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + new_hyp.log_prob_blank = hyp.log_prob + log_prob + B.add(new_hyp) + elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token: + # Case 1: *a + a => *a + # Prefix does not change, update log_prob of non_blank + new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + B.add(new_hyp) + + # Case 2: *aε + a => *aa + # Prefix changes, update log_prob of blank + new_hyp = hyp.clone() + # Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys = hyp.ys + [new_token] + new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + update_prefix = True + else: + # Case 3: *a + b => *ab, *aε + b => *ab + # Prefix changes, update log_prob of non_blank + # Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys = hyp.ys + [new_token] + new_hyp.log_prob_non_blank = hyp.log_prob + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + update_prefix = True + + if update_prefix: + lm_score = hyp.lm_score + if hyp.lm_log_probs is not None: + lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale + new_hyp.lm_log_probs = None + + if context_graph is not None and hyp.context_state is not None: + ( + context_score, + new_context_state, + matched_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + lm_score = lm_score + context_score + new_hyp.context_state = new_context_state + + if hyp.LODR_state is not None: + state_cost = hyp.LODR_state.forward_one_step(new_token) + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.LODR_state.lm_score, + ) + lm_score = lm_score + LODR_lm_scale * current_ngram_score + new_hyp.LODR_state = state_cost + + new_hyp.lm_score = lm_score + B.add(new_hyp) + B = B.topk(beam) + return B + + +def _sequence_worker( + topk_values: torch.Tensor, + topk_indexes: torch.Tensor, + B: HypothesisList, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, +) -> HypothesisList: + """The worker to decode one sequence. + Args: + topk_values: + topk log_probs of model output (i.e. the kept tokens of first pass pruning), + the shape is (T, beam) + topk_indexes: + The indexes of the topk_values above, the shape is (T, beam) + B: + An instance of HypothesisList containing the kept hypothesis. + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + Return: + Returns the updated HypothesisList. + """ + B.add(Hypothesis()) + for j in range(encoder_out_lens): + log_probs, indexes = topk_values[j], topk_indexes[j] + B = _step_worker(log_probs, indexes, B, beam, blank_id) + return B + + +def ctc_prefix_beam_search( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + process_pool: Optional[Pool] = None, + return_nbest: Optional[bool] = False, +) -> Union[List[List[int]], List[HypothesisList]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks". + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + process_pool: + The process pool for parallel decoding, if not provided, it will use all + you cpu cores by default. + return_nbest: + If true, return a list of HypothesisList, return a list of list of decoded token ids otherwise. + """ + batch_size, num_frames, vocab_size = ctc_output.shape + + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + + B = [HypothesisList() for _ in range(batch_size)] + + pool = Pool() if process_pool is None else process_pool + arguments = [] + for i in range(batch_size): + arguments.append( + ( + topk_values[i], + topk_indexes[i], + B[i], + encoder_out_lens[i].item(), + beam, + blank_id, + ) + ) + async_results = pool.starmap_async(_sequence_worker, arguments) + B = list(async_results.get()) + if process_pool is None: + pool.close() + pool.join() + if return_nbest: + return B + else: + best_hyps = [b.get_most_probable() for b in B] + return [hyp.ys for hyp in best_hyps] + + +def ctc_prefix_beam_search_shallow_fussion( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + LODR_lm: Optional[NgramLm] = None, + LODR_lm_scale: Optional[float] = 0, + NNLM: Optional[LmScorer] = None, + context_graph: Optional[ContextGraph] = None, +) -> List[List[int]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add + nervous language model shallow fussion, it also supports contextual + biasing with a given grammar. + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + context_graph: + A ContextGraph instance containing contextual phrases. + Return: + Returns a list of list of decoded token ids. + """ + batch_size, num_frames, vocab_size = ctc_output.shape + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + encoder_out_lens = encoder_out_lens.tolist() + device = ctc_output.device + + nnlm_scale = 0 + init_scores = None + init_states = None + if NNLM is not None: + nnlm_scale = NNLM.lm_scale + sos_id = getattr(NNLM, "sos_id", 1) + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_scores, init_states = NNLM.score_token(sos_token, lens) + init_scores, init_states = init_scores.cpu(), ( + init_states[0].cpu(), + init_states[1].cpu(), + ) + + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[], + log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32), + log_prob_blank=torch.zeros(1, dtype=torch.float32), + lm_score=torch.zeros(1, dtype=torch.float32), + state=init_states, + lm_log_probs=None if init_scores is None else init_scores.reshape(-1), + LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm), + context_state=None if context_graph is None else context_graph.root, + ) + ) + for j in range(num_frames): + for i in range(batch_size): + if j < encoder_out_lens[i]: + log_probs, indexes = topk_values[i][j], topk_indexes[i][j] + B[i] = _step_worker( + log_probs=log_probs, + indexes=indexes, + B=B[i], + beam=beam, + blank_id=blank_id, + nnlm_scale=nnlm_scale, + LODR_lm_scale=LODR_lm_scale, + context_graph=context_graph, + ) + if NNLM is None: + continue + # update lm_log_probs + token_list = [] # a list of list + hs = [] + cs = [] + indexes = [] # (batch_idx, key) + for batch_idx, hyps in enumerate(B): + for hyp in hyps: + if hyp.lm_log_probs is None: # those hyps that prefix changes + if NNLM.lm_type == "rnn": + token_list.append([hyp.ys[-1]]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append([sos_id] + hyp.ys[:]) + indexes.append((batch_idx, hyp.key)) + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if NNLM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + state = None + + scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state) + scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu()) + assert scores.size(0) == len(indexes), (scores.size(0), len(indexes)) + for i in range(scores.size(0)): + batch_idx, key = indexes[i] + B[batch_idx][key].lm_log_probs = scores[i] + if NNLM.lm_type == "rnn": + state = ( + lm_states[0][:, i, :].unsqueeze(1), + lm_states[1][:, i, :].unsqueeze(1), + ) + B[batch_idx][key].state = state + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + for hyps in B: + for hyp in hyps: + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + hyp.lm_score += context_score + hyp.context_state = new_context_state + + best_hyps = [b.get_most_probable() for b in B] + return [hyp.ys for hyp in best_hyps] + + +def ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output: torch.Tensor, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 8, + blank_id: int = 0, + attention_scale: Optional[float] = None, + process_pool: Optional[Pool] = None, +): + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add + attention decoder rescoring. + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + attention_decoder: + The attention decoder. + encoder_out: + The output of encoder, the shape is (B, T, D) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + attention_scale: + The scale of attention decoder score, if not provided it will search in + a default list (see the code below). + process_pool: + The process pool for parallel decoding, if not provided, it will use all + you cpu cores by default. + """ + # List[HypothesisList] + nbest = ctc_prefix_beam_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + beam=beam, + blank_id=blank_id, + return_nbest=True, + ) + + device = ctc_output.device + + hyp_shape = get_hyps_shape(nbest).to(device) + hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long) + # the shape of encoder_out is (N, T, C), so we use axis=0 here + expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map) + + nbest = [list(x) for x in nbest] + token_ids = [] + scores = [] + for hyps in nbest: + for hyp in hyps: + token_ids.append(hyp.ys) + scores.append(hyp.log_prob.reshape(1)) + scores = torch.cat(scores).to(device) + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + + start_indexes = hyp_shape.row_splits(1)[0:-1] + for a_scale in attention_scale_list: + tot_scores = scores + a_scale * attention_scores + ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + max_indexes = max_indexes - start_indexes + max_indexes = max_indexes.cpu() + best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))] + key = f"attention_scale_{a_scale}" + ans[key] = best_path + return ans diff --git a/icefall/utils.py b/icefall/utils.py index 0682252f95..41eebadd46 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -19,8 +19,10 @@ import argparse import collections +import json import logging import os +import pathlib import random import re import subprocess @@ -180,6 +182,15 @@ def __delattr__(self, key): return raise AttributeError(f"No such attribute '{key}'") + def __str__(self, indent: int = 2): + tmp = {} + for k, v in self.items(): + # PosixPath is ont JSON serializable + if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + v = str(v) + tmp[k] = v + return json.dumps(tmp, indent=indent, sort_keys=True) + def encode_supervisions( supervisions: dict,