diff --git a/egs/librispeech/ASR/add_alignments.sh b/egs/librispeech/ASR/add_alignments.sh new file mode 100755 index 0000000000..5e4480bf6f --- /dev/null +++ b/egs/librispeech/ASR/add_alignments.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -eou pipefail + +alignments_dir=data/alignment +cuts_in_dir=data/fbank +cuts_out_dir=data/fbank_ali + +python3 ./local/add_alignment_librispeech.py \ + --alignments-dir $alignments_dir \ + --cuts-in-dir $cuts_in_dir \ + --cuts-out-dir $cuts_out_dir diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py new file mode 100755 index 0000000000..cd1bcea679 --- /dev/null +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file adds alignments from https://github.com/CorentinJ/librispeech-alignments # noqa +to the existing fbank features dir (e.g., data/fbank) +and save cuts to a new dir (e.g., data/fbank_ali). +""" + +import argparse +import logging +import zipfile +from pathlib import Path +from typing import List + +from lhotse import CutSet, load_manifest_lazy +from lhotse.recipes.librispeech import parse_alignments +from lhotse.utils import is_module_available + +LIBRISPEECH_ALIGNMENTS_URL = ( + "https://drive.google.com/uc?id=1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE" +) + +DATASET_PARTS = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--alignments-dir", + type=str, + default="data/alignment", + help="The dir to save alignments.", + ) + + parser.add_argument( + "--cuts-in-dir", + type=str, + default="data/fbank", + help="The dir of the existing cuts without alignments.", + ) + + parser.add_argument( + "--cuts-out-dir", + type=str, + default="data/fbank_ali", + help="The dir to save the new cuts with alignments", + ) + + return parser + + +def download_alignments( + target_dir: str, alignments_url: str = LIBRISPEECH_ALIGNMENTS_URL +): + """ + Download and extract the alignments. + + Note: If you can not access drive.google.com, you could download the file + `LibriSpeech-Alignments.zip` from huggingface: + https://huggingface.co/Zengwei/librispeech-alignments + and extract the zip file manually. + + Args: + target_dir: + The dir to save alignments. + alignments_url: + The URL of alignments. 
+ """ + """Modified from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/librispeech.py""" # noqa + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + completed_detector = target_dir / ".ali_completed" + if completed_detector.is_file(): + logging.info("The alignment files already exist.") + return + + ali_zip_path = target_dir / "LibriSpeech-Alignments.zip" + if not ali_zip_path.is_file(): + assert is_module_available( + "gdown" + ), 'To download LibriSpeech alignments, please install "pip install gdown"' # noqa + import gdown + + gdown.download(alignments_url, output=str(ali_zip_path)) + + with zipfile.ZipFile(str(ali_zip_path)) as f: + f.extractall(path=target_dir) + completed_detector.touch() + + +def add_alignment( + alignments_dir: str, + cuts_in_dir: str = "data/fbank", + cuts_out_dir: str = "data/fbank_ali", + dataset_parts: List[str] = DATASET_PARTS, +): + """ + Add alignment info to existing cuts. + + Args: + alignments_dir: + The dir of the alignments. + cuts_in_dir: + The dir of the existing cuts. + cuts_out_dir: + The dir to save the new cuts with alignments. + dataset_parts: + Librispeech parts to add alignments. + """ + alignments_dir = Path(alignments_dir) + cuts_in_dir = Path(cuts_in_dir) + cuts_out_dir = Path(cuts_out_dir) + cuts_out_dir.mkdir(parents=True, exist_ok=True) + + for part in dataset_parts: + logging.info(f"Processing {part}") + + cuts_in_path = cuts_in_dir / f"librispeech_cuts_{part}.jsonl.gz" + if not cuts_in_path.is_file(): + logging.info(f"{cuts_in_path} does not exist - skipping.") + continue + cuts_out_path = cuts_out_dir / f"librispeech_cuts_{part}.jsonl.gz" + if cuts_out_path.is_file(): + logging.info(f"{part} already exists - skipping.") + continue + + # parse alignments + alignments = {} + part_ali_dir = alignments_dir / "LibriSpeech" / part + for ali_path in part_ali_dir.rglob("*.alignment.txt"): + ali = parse_alignments(ali_path) + alignments.update(ali) + logging.info( + f"{part} has {len(alignments.keys())} cuts with alignments." + ) + + # add alignment attribute and write out + cuts_in = load_manifest_lazy(cuts_in_path) + with CutSet.open_writer(cuts_out_path) as writer: + for cut in cuts_in: + for idx, subcut in enumerate(cut.supervisions): + origin_id = subcut.id.split("_")[0] + if origin_id in alignments: + ali = alignments[origin_id] + else: + logging.info( + f"Warning: {origin_id} does not has alignment." 
+ ) + ali = [] + subcut.alignment = {"word": ali} + writer.write(cut, flush=True) + + +def main(): + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + download_alignments(args.alignments_dir) + add_alignment(args.alignments_dir, args.cuts_in_dir, args.cuts_out_dir) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 5be23c50c2..052d027e31 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -91,6 +91,22 @@ --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. +For example: +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search \ + --manifest-dir data/fbank_ali """ @@ -127,10 +143,12 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + DecodingResults, + parse_hyp_and_timestamp, setup_logger, - store_transcripts, + store_transcripts_and_timestamps, str2bool, - write_error_stats, + write_error_stats_with_timestamps, ) LOG_EPS = math.log(1e-10) @@ -314,7 +332,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: +) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -322,9 +340,11 @@ def decode_one_batch( if greedy_search is used, it would be "greedy_search" If beam search with a beam size of 7 is used, it would be "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. + - value: It is a tuple. `len(value[0])` and `len(value[1])` are both + equal to the batch size. `value[0][i]` and `value[1][i]` + are the decoding result and timestamps for the i-th utterance + in the given batch respectively. + Args: params: It's the return value of :func:`get_params`. @@ -343,8 +363,8 @@ def decode_one_batch( only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: - Return the decoding result. See above description for the format of - the returned dict. + Return the decoding result and timestamps. See above description for the + format of the returned dict. 
""" device = next(model.parameters()).device feature = batch["inputs"] @@ -370,10 +390,8 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - hyps = [] - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( + res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -381,11 +399,10 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( + res = fast_beam_search_nbest_LG( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -395,11 +412,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( + res = fast_beam_search_nbest( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -409,11 +425,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( + res = fast_beam_search_nbest_oracle( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -424,56 +439,67 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 ): - hyp_tokens = greedy_search_batch( + res = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( + res = modified_beam_search( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) - + tokens = [] + timestamps = [] for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search( + res = greedy_search( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + return_timestamps=True, ) elif params.decoding_method == "beam_search": - hyp = beam_search( + res = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + return_timestamps=True, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + tokens.extend(res.tokens) + timestamps.extend(res.timestamps) + res = DecodingResults(tokens=tokens, timestamps=timestamps) + + hyps, timestamps = parse_hyp_and_timestamp( + decoding_method=params.decoding_method, + res=res, + sp=sp, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) if 
params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": (hyps, timestamps)} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -484,9 +510,9 @@ def decode_one_batch( if "LG" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: hyps} + return {key: (hyps, timestamps)} else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": (hyps, timestamps)} def decode_dataset( @@ -496,7 +522,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -517,9 +545,12 @@ def decode_dataset( Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Its value is a list of tuples. Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamp of reference transcript + - timestamp of predicted result """ num_cuts = 0 @@ -538,6 +569,18 @@ def decode_dataset( texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) + hyps_dict = decode_one_batch( params=params, model=model, @@ -547,12 +590,18 @@ def decode_dataset( batch=batch, ) - for name, hyps in hyps_dict.items(): + for name, (hyps, timestamps_hyp) in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -570,15 +619,19 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], ): test_set_wers = dict() + test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -587,10 +640,11 @@ def save_results( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: - wer = write_error_stats( + wer, mean_delay, var_delay = 
write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -604,6 +658,19 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + delays_info = ( + params.res_dir + / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\tsymbol-delay", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + file=f, + ) + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: @@ -611,6 +678,15 @@ def save_results( note = "" logging.info(s) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + @torch.no_grad() def main(): diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index dc3697ae70..fa50576d8a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -377,6 +377,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index c70618ef76..0004a24eb1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import k2 import sentencepiece as spm @@ -25,7 +25,13 @@ from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding -from icefall.utils import add_eos, add_sos, get_texts +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) def fast_beam_search_one_best( @@ -37,7 +43,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -61,8 +68,12 @@ def fast_beam_search_one_best( Max contexts pre stream per frame. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" lattice = fast_beam_search( model=model, @@ -76,8 +87,11 @@ def fast_beam_search_one_best( ) best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_LG( @@ -92,7 +106,8 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -129,8 +144,12 @@ def fast_beam_search_nbest_LG( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -195,9 +214,10 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest( @@ -212,7 +232,8 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -249,8 +270,12 @@ def fast_beam_search_nbest( single precision. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ lattice = fast_beam_search( model=model, @@ -279,9 +304,10 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search_nbest_oracle( @@ -297,7 +323,8 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -338,8 +365,12 @@ def fast_beam_search_nbest_oracle( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" lattice = fast_beam_search( model=model, @@ -378,8 +409,10 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - return hyps + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) def fast_beam_search( @@ -469,8 +502,11 @@ def fast_beam_search( def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: + model: Transducer, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """Greedy search for a single utterance. Args: model: @@ -480,8 +516,12 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. + return_timestamps: + Whether to return timestamps. Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -507,6 +547,10 @@ def greedy_search( t = 0 hyp = [blank_id] * context_size + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -533,6 +577,7 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) + timestamp.append(t) decoder_input = torch.tensor( [hyp[-context_size:]], device=device ).reshape(1, context_size) @@ -547,14 +592,21 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - return hyp + if not return_timestamps: + return hyp + else: + return DecodingResults( + tokens=[hyp], + timestamps=[timestamp], + ) def greedy_search_batch( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -564,9 +616,12 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -591,6 +646,10 @@ def greedy_search_batch( hyps = [[blank_id] * context_size for _ in range(N)] + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + decoder_input = torch.tensor( hyps, device=device, @@ -604,7 +663,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -626,6 +685,7 @@ def greedy_search_batch( for i, v in enumerate(y): if v not in (blank_id, unk_id): hyps[i].append(v) + timestamps[i].append(t) emitted = True if emitted: # update decoder output @@ -640,11 +700,19 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) @dataclass @@ -657,6 +725,10 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] + state_cost: Optional[NgramLmStateCost] = None @property @@ -806,7 +878,8 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. Args: @@ -821,9 +894,12 @@ def modified_beam_search( Number of active paths during the beam search. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -851,6 +927,7 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) @@ -858,7 +935,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for batch_size in batch_size_list: + for (t, batch_size) in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -936,30 +1013,44 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B[i].add(new_hyp) B = B + finalized_B best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] + ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - return ans + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """It limits the maximum number of symbols per frame to 1. It decodes only one utterance at a time. We keep it only for reference. @@ -974,8 +1065,13 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
""" assert encoder_out.ndim == 3 @@ -995,6 +1091,7 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], ) ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -1053,17 +1150,24 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] if new_token not in (blank_id, unk_id): new_ys.append(new_token) + new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def beam_search( @@ -1071,7 +1175,8 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, -) -> List[int]: + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1086,8 +1191,13 @@ def beam_search( Beam size. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: - Return the decoded result. + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. """ assert encoder_out.ndim == 3 @@ -1114,7 +1224,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) max_sym_per_utt = 20000 @@ -1175,7 +1285,13 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -1184,7 +1300,14 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1200,7 +1323,11 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys + + if not return_timestamps: + return ys + else: + return DecodingResults(tokens=[ys], timestamps=[best_hyp.timestamp]) def fast_beam_search_with_nbest_rescoring( @@ -1220,7 +1347,8 @@ def fast_beam_search_with_nbest_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model. 
The shortest path within the @@ -1262,10 +1390,13 @@ def fast_beam_search_with_nbest_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1343,16 +1474,18 @@ def fast_beam_search_with_nbest_rescoring( log_semiring=False, ) - ans: Dict[str, List[List[int]]] = {} + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} for s in ngram_lm_scale_list: key = f"ngram_lm_scale_{s}" tot_scores = am_scores.values + s * ngram_lm_scores ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans @@ -1376,7 +1509,8 @@ def fast_beam_search_with_nbest_rnn_rescoring( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, -) -> Dict[str, List[List[int]]]: + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, num_path are selected and rescored using a given language model and a rnn-lm. @@ -1422,10 +1556,13 @@ def fast_beam_search_with_nbest_rnn_rescoring( yields more unique paths. temperature: Softmax temperature. + return_timestamps: + Whether to return timestamps. Returns: Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the - ngram LM scale value used during decoding, i.e., 0.1. + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. """ lattice = fast_beam_search( model=model, @@ -1537,9 +1674,11 @@ def fast_beam_search_with_nbest_rnn_rescoring( ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps = get_texts(best_path) - ans[key] = hyps + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 873892bb91..13697008fa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -106,6 +106,22 @@ --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. 
+For example: +./pruned_transducer_stateless4/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method greedy_search \ + --manifest-dir data/fbank_ali """ @@ -142,10 +158,12 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + DecodingResults, + parse_hyp_and_timestamp, setup_logger, - store_transcripts, + store_transcripts_and_timestamps, str2bool, - write_error_stats, + write_error_stats_with_timestamps, ) LOG_EPS = math.log(1e-10) @@ -318,7 +336,7 @@ def get_parser(): "--left-context", type=int, default=64, - help="left context can be seen during decoding (in frames after subsampling)", + help="left context can be seen during decoding (in frames after subsampling)", # noqa ) parser.add_argument( @@ -350,7 +368,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: +) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -358,9 +376,10 @@ def decode_one_batch( if greedy_search is used, it would be "greedy_search" If beam search with a beam size of 7 is used, it would be "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. + - value: It is a tuple. `len(value[0])` and `len(value[1])` are both + equal to the batch size. `value[0][i]` and `value[1][i]` + are the decoding result and timestamps for the i-th utterance + in the given batch respectively. Args: params: It's the return value of :func:`get_params`. @@ -379,8 +398,8 @@ def decode_one_batch( only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: - Return the decoding result. See above description for the format of - the returned dict. + Return the decoding result and timestamps. See above description for the + format of the returned dict. 
""" device = next(model.parameters()).device feature = batch["inputs"] @@ -412,10 +431,8 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) - hyps = [] - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( + res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -423,11 +440,10 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( + res = fast_beam_search_nbest_LG( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -437,11 +453,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( + res = fast_beam_search_nbest( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -451,11 +466,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( + res = fast_beam_search_nbest_oracle( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -466,56 +480,67 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 ): - hyp_tokens = greedy_search_batch( + res = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( + res = modified_beam_search( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) - + tokens = [] + timestamps = [] for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search( + res = greedy_search( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + return_timestamps=True, ) elif params.decoding_method == "beam_search": - hyp = beam_search( + res = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + return_timestamps=True, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + tokens.extend(res.tokens) + timestamps.extend(res.timestamps) + res = DecodingResults(tokens=tokens, timestamps=timestamps) + + hyps, timestamps = parse_hyp_and_timestamp( + decoding_method=params.decoding_method, + res=res, + sp=sp, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) if 
params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": (hyps, timestamps)} elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -526,9 +551,9 @@ def decode_one_batch( if "LG" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: hyps} + return {key: (hyps, timestamps)} else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": (hyps, timestamps)} def decode_dataset( @@ -538,7 +563,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -559,9 +586,12 @@ def decode_dataset( Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Its value is a list of tuples. Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamp of reference transcript + - timestamp of predicted result """ num_cuts = 0 @@ -580,6 +610,18 @@ def decode_dataset( texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) + hyps_dict = decode_one_batch( params=params, model=model, @@ -589,12 +631,18 @@ def decode_dataset( batch=batch, ) - for name, hyps in hyps_dict.items(): + for name, (hyps, timestamps_hyp) in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -612,15 +660,19 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], ): test_set_wers = dict() + test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -629,10 +681,11 @@ def save_results( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: - wer = write_error_stats( + wer, mean_delay, var_delay = 
write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -646,6 +699,19 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + delays_info = ( + params.res_dir + / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\tsymbol-delay", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + file=f, + ) + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: @@ -653,6 +719,15 @@ def save_results( note = "" logging.info(s) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + @torch.no_grad() def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 13a5b1a515..4c55fd6096 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -386,6 +386,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, diff --git a/icefall/utils.py b/icefall/utils.py index 6c115ed169..45a49fb5c1 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -24,9 +24,10 @@ import subprocess from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union import k2 import k2.version @@ -248,6 +249,86 @@ def get_texts( return aux_labels.tolist() +@dataclass +class DecodingResults: + # Decoded token IDs for each utterance in the batch + tokens: List[List[int]] + + # timestamps[i][k] contains the frame number on which tokens[i][k] + # is decoded + timestamps: List[List[int]] + + # hyps[i] is the recognition results, i.e., word IDs + # for the i-th utterance with fast_beam_search_nbest_LG. + hyps: Union[List[List[int]], k2.RaggedTensor] = None + + +def get_tokens_and_timestamps(labels: List[int]) -> Tuple[List[int], List[int]]: + tokens = [] + timestamps = [] + for i, v in enumerate(labels): + if v != 0: + tokens.append(v) + timestamps.append(i) + + return tokens, timestamps + + +def get_texts_with_timestamp( + best_paths: k2.Fsa, return_ragged: bool = False +) -> DecodingResults: + """Extract the texts (as word IDs) and timestamps from the best-path FSAs. + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + return_ragged: + True to return a ragged tensor with two axes [utt][word_id]. + False to return a list-of-list word IDs. 
+    Returns:
+      Return a DecodingResults object containing the decoded token IDs,
+      the frame indices (after subsampling) on which they are decoded,
+      and the word IDs (as a ragged tensor if return_ragged is True,
+      otherwise as a list-of-list of int).
+    """
+    if isinstance(best_paths.aux_labels, k2.RaggedTensor):
+        # remove 0's and -1's.
+        aux_labels = best_paths.aux_labels.remove_values_leq(0)
+        # TODO: change arcs.shape() to arcs.shape
+        aux_shape = best_paths.arcs.shape().compose(aux_labels.shape)
+
+        # remove the states and arcs axes.
+        aux_shape = aux_shape.remove_axis(1)
+        aux_shape = aux_shape.remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
+    else:
+        # remove axis corresponding to states.
+        aux_shape = best_paths.arcs.shape().remove_axis(1)
+        aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels)
+        # remove 0's and -1's.
+        aux_labels = aux_labels.remove_values_leq(0)
+
+    assert aux_labels.num_axes == 2
+
+    labels_shape = best_paths.arcs.shape().remove_axis(1)
+    labels_list = k2.RaggedTensor(
+        labels_shape, best_paths.labels.contiguous()
+    ).tolist()
+
+    tokens = []
+    timestamps = []
+    for labels in labels_list:
+        token, time = get_tokens_and_timestamps(labels[:-1])
+        tokens.append(token)
+        timestamps.append(time)
+
+    return DecodingResults(
+        tokens=tokens,
+        timestamps=timestamps,
+        hyps=aux_labels if return_ragged else aux_labels.tolist(),
+    )
+
+
 def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     """Extract labels or aux_labels from the best-path FSAs.
@@ -352,6 +433,33 @@ def store_transcripts(
         print(f"{cut_id}:\thyp={hyp}", file=f)
 
 
+def store_transcripts_and_timestamps(
+    filename: Pathlike,
+    texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
+) -> None:
+    """Save predicted results and reference transcripts as well as their
+    timestamps to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. Each tuple contains the cut id, the reference
+        transcript, the predicted result, the timestamps of the reference
+        transcript and the timestamps of the predicted result.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w") as f:
+        for cut_id, ref, hyp, time_ref, time_hyp in texts:
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+            if len(time_ref) > 0:
+                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
+                print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
+            s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
+            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
+
+
 def write_error_stats(
     f: TextIO,
     test_set_name: str,
@@ -519,6 +627,211 @@ def write_error_stats(
     return float(tot_err_rate)
 
 
+def write_error_stats_with_timestamps(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    enable_log: bool = True,
+) -> Tuple[float, float, float]:
+    """Write statistics based on predicted results and reference transcripts
+    as well as their timestamps.
+
+    It will write the following to the given file:
+
+      - WER
+      - number of insertions, deletions, substitutions, corrects and total
+        reference words. For example::
+
+          Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+          reference words (2337 correct)
+
+      - The difference between the reference transcript and predicted result.
+        An instance is given below::
+
+          THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+        The above example shows that the reference word is `EDISON`,
+        but it is predicted as `ADDISON` (a substitution error).
+
+        Another example is::
+
+          FOR THE FIRST DAY (SIR->*) I THINK
+
+        The reference word `SIR` is missing in the predicted
+        results (a deletion error).
+
+      results:
+        An iterable of tuples. Each tuple contains the cut id, the reference
+        transcript, the predicted result, the timestamps of the reference
+        transcript and the timestamps of the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+
+    Returns:
+      Return the total word error rate, the mean symbol delay (in seconds)
+      and the variance of the symbol delay.
+    """
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+    # Compute mean alignment delay on the correct words
+    all_delay = []
+    for cut_id, ref, hyp, time_ref, time_hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        has_time_ref = len(time_ref) > 0
+        if has_time_ref:
+            # pointer to timestamp_hyp
+            p_hyp = 0
+            # pointer to timestamp_ref
+            p_ref = 0
+        for ref_word, hyp_word in ali:
+            if ref_word == ERR:
+                ins[hyp_word] += 1
+                words[hyp_word][3] += 1
+                if has_time_ref:
+                    p_hyp += 1
+            elif hyp_word == ERR:
+                dels[ref_word] += 1
+                words[ref_word][4] += 1
+                if has_time_ref:
+                    p_ref += 1
+            elif hyp_word != ref_word:
+                subs[(ref_word, hyp_word)] += 1
+                words[ref_word][1] += 1
+                words[hyp_word][2] += 1
+                if has_time_ref:
+                    p_hyp += 1
+                    p_ref += 1
+            else:
+                words[ref_word][0] += 1
+                num_corr += 1
+                if has_time_ref:
+                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
+                    p_hyp += 1
+                    p_ref += 1
+        if has_time_ref:
+            assert p_hyp == len(hyp), (p_hyp, len(hyp))
+            assert p_ref == len(ref), (p_ref, len(ref))
+
+    ref_len = sum([len(r) for _, r, _, _, _ in results])
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+    mean_delay = "inf"
+    var_delay = "inf"
+    num_delay = len(all_delay)
+    if num_delay > 0:
+        mean_delay = sum(all_delay) / num_delay
+        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
+        mean_delay = "%.3f" % mean_delay
+        var_delay = "%.3f" % var_delay
+
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+        logging.info(
+            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
+            f"computed on {num_delay} correct words"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
+    for cut_id, ref, hyp, _, _ in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        combine_successive_errors = True
+        if combine_successive_errors:
+            ali = [[[x], [y]] for x, y in ali]
+            for i in range(len(ali) - 1):
+                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                    ali[i] = [[], []]
+            ali = [
+                [
+                    list(filter(lambda a: a != ERR, x)),
+                    list(filter(lambda a: a != ERR, y)),
+                ]
+                for x, y in ali
+            ]
+            ali = list(filter(lambda x: x != [[], []], ali))
+            ali = [
+                [
+                    ERR if x == [] else " ".join(x),
+                    ERR if y == [] else " 
".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted( + [(v, k) for k, v in subs.items()], reverse=True + ): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print( + "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f + ) + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate), float(mean_delay), float(var_delay) + + class MetricsTracker(collections.defaultdict): def __init__(self): # Passing the type 'int' to the base-class constructor @@ -978,6 +1291,137 @@ def display_and_save_batch( logging.info(f"num tokens: {num_tokens}") +def convert_timestamp( + frames: List[int], + subsampling_factor: int, + frame_shift_ms: float = 10, +) -> List[float]: + """Convert frame numbers to time (in seconds) given subsampling factor + and frame shift (in milliseconds). + + Args: + frames: + A list of frame numbers after subsampling. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + Return: + Return the time in seconds corresponding to each given frame. + """ + frame_shift = frame_shift_ms / 1000.0 + time = [] + for f in frames: + time.append(f * subsampling_factor * frame_shift) + + return time + + +def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: + """ + Parse timestamp of each word. + + Args: + tokens: + List of tokens. + timestamp: + List of timestamp of each token. + + Returns: + List of timestamp of each word. + """ + start_token = b"\xe2\x96\x81".decode() # '_' + assert len(tokens) == len(timestamp) + ans = [] + for i in range(len(tokens)): + flag = False + if i == 0 or tokens[i].startswith(start_token): + flag = True + if len(tokens[i]) == 1 and tokens[i].startswith(start_token): + # tokens[i] == start_token + if i == len(tokens) - 1: + # it is the last token + flag = False + elif tokens[i + 1].startswith(start_token): + # the next token also starts with start_token + flag = False + if flag: + ans.append(timestamp[i]) + return ans + + +def parse_hyp_and_timestamp( + res: DecodingResults, + decoding_method: str, + sp: spm.SentencePieceProcessor, + subsampling_factor: int, + frame_shift_ms: float = 10, + word_table: Optional[k2.SymbolTable] = None, +) -> Tuple[List[List[str]], List[List[float]]]: + """Parse hypothesis and timestamp. + + Args: + res: + A DecodingResults object. + decoding_method: + Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + sp: + The BPE model. 
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        The frame shift (in milliseconds) used for feature extraction.
+      word_table:
+        The word symbol table.
+
+    Returns:
+      Return two lists: the decoded words for each utterance and the
+      corresponding word-level timestamps (in seconds).
+    """
+    assert decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+
+    hyps = []
+    timestamps = []
+
+    N = len(res.tokens)
+    assert len(res.timestamps) == N
+    use_word_table = False
+    if decoding_method == "fast_beam_search_nbest_LG":
+        assert word_table is not None
+        use_word_table = True
+
+    for i in range(N):
+        tokens = sp.id_to_piece(res.tokens[i])
+        if use_word_table:
+            words = [word_table[w] for w in res.hyps[i]]
+        else:
+            words = sp.decode_pieces(tokens).split()
+        time = convert_timestamp(
+            res.timestamps[i], subsampling_factor, frame_shift_ms
+        )
+        time = parse_timestamp(tokens, time)
+        assert len(time) == len(words), (tokens, words)
+
+        hyps.append(words)
+        timestamps.append(time)
+
+    return hyps, timestamps
+
+
 # `is_module_available` is copied from
 # https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9
 def is_module_available(*modules: str) -> bool:
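
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of how the timestamp post-processing added to icefall/utils.py is intended to behave. It re-implements convert_timestamp and a simplified version of the word-boundary filtering done by parse_timestamp (the helper name word_start_times, the token pieces and the frame indices are made-up illustrations; in the real code path they come from e.g. greedy_search_batch(..., return_timestamps=True)).

from typing import List


def convert_timestamp(
    frames: List[int], subsampling_factor: int, frame_shift_ms: float = 10.0
) -> List[float]:
    """Map frame indices (after subsampling) to seconds."""
    frame_shift = frame_shift_ms / 1000.0
    return [f * subsampling_factor * frame_shift for f in frames]


def word_start_times(tokens: List[str], times: List[float]) -> List[float]:
    """Keep the timestamp of the first BPE piece of each word, i.e. pieces
    starting with the SentencePiece word marker '▁' (a simplified stand-in
    for parse_timestamp, which additionally handles bare '▁' tokens)."""
    marker = "\u2581"
    return [
        t
        for i, (tok, t) in enumerate(zip(tokens, times))
        if i == 0 or tok.startswith(marker)
    ]


if __name__ == "__main__":
    # Hypothetical decoding of "HELLO WORLD": BPE pieces and the frame
    # indices on which they were emitted, with subsampling factor 4 and
    # a 10 ms frame shift.
    tokens = ["▁HE", "LLO", "▁WORLD"]
    frames = [3, 5, 12]
    times = convert_timestamp(frames, subsampling_factor=4, frame_shift_ms=10.0)
    print(word_start_times(tokens, times))  # [0.12, 0.48]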