diff --git a/sherpa/bin/lstm_transducer_stateless/beam_search.py b/sherpa/bin/lstm_transducer_stateless/beam_search.py
index b71338ade..d9882a3e7 100644
--- a/sherpa/bin/lstm_transducer_stateless/beam_search.py
+++ b/sherpa/bin/lstm_transducer_stateless/beam_search.py
@@ -132,7 +132,7 @@ def process(
         processed_lens = (num_processed_frames >> 2) + encoder_out_lens
 
         if self.decoding_method == "fast_beam_search_nbest":
-            next_hyp_list, next_trailing_blank_frames = fast_beam_search_nbest(
+            res = fast_beam_search_nbest(
                 model=model,
                 encoder_out=encoder_out,
                 processed_lens=processed_lens,
@@ -144,10 +144,7 @@ def process(
                 temperature=self.beam_search_params["temperature"],
             )
         elif self.decoding_method == "fast_beam_search_nbest_LG":
-            (
-                next_hyp_list,
-                next_trailing_blank_frames,
-            ) = fast_beam_search_nbest_LG(
+            res = fast_beam_search_nbest_LG(
                 model=model,
                 encoder_out=encoder_out,
                 processed_lens=processed_lens,
@@ -159,10 +156,7 @@ def process(
                 temperature=self.beam_search_params["temperature"],
             )
         elif self.decoding_method == "fast_beam_search":
-            (
-                next_hyp_list,
-                next_trailing_blank_frames,
-            ) = fast_beam_search_one_best(
+            res = fast_beam_search_one_best(
                 model=model,
                 encoder_out=encoder_out,
                 processed_lens=processed_lens,
@@ -177,8 +171,12 @@ def process(
         next_state_list = unstack_states(next_states)
         for i, s in enumerate(stream_list):
             s.states = next_state_list[i]
-            s.hyp = next_hyp_list[i]
-            s.num_trailing_blank_frames = next_trailing_blank_frames[i]
+            s.hyp = res.hyps[i]
+            s.num_trailing_blank_frames = res.num_trailing_blanks[i]
+            s.frame_offset += encoder_out.size(1)
+            s.segment_frame_offset += encoder_out.size(1)
+            s.timestamps = res.timestamps[i]
+            s.tokens = res.tokens[i]
 
     def get_texts(self, stream: Stream) -> str:
         """
@@ -200,6 +198,22 @@ def get_texts(self, stream: Stream) -> str:
 
         return result
 
+    def get_tokens(self, stream: Stream) -> List[str]:
+        """
+        Return tokens after decoding.
+        Args:
+          stream:
+            Stream to be processed.
+        """
+        tokens = stream.tokens
+
+        if hasattr(self, "sp"):
+            result = [self.sp.id_to_piece(i) for i in tokens]
+        else:
+            result = [self.token_table[i] for i in tokens]
+
+        return result
+
 
 class GreedySearch:
     def __init__(
@@ -244,6 +258,10 @@ def init_stream(self, stream: Stream):
             self.beam_search_params["blank_id"]
         ] * self.beam_search_params["context_size"]
 
+        # timestamps[i] is the frame number on which stream.hyp[i+context_size]
+        # is decoded
+        stream.timestamps = []  # containing frame numbers after subsampling
+
     @torch.no_grad()
     def process(
         self,
@@ -266,9 +284,13 @@ def process(
         chunk_length = server.chunk_length
         batch_size = len(stream_list)
         chunk_length_pad = server.chunk_length_pad
-        state_list, feature_list = [], []
-        decoder_out_list, hyp_list = [], []
+        state_list = []
+        feature_list = []
+        decoder_out_list = []
+        hyp_list = []
         num_trailing_blank_frames_list = []
+        frame_offset_list = []
+        timestamps_list = []
 
         for s in stream_list:
             decoder_out_list.append(s.decoder_out)
@@ -282,6 +304,8 @@ def process(
             feature_list.append(b)
 
             num_trailing_blank_frames_list.append(s.num_trailing_blank_frames)
+            frame_offset_list.append(s.segment_frame_offset)
+            timestamps_list.append(s.timestamps)
 
         features = torch.stack(feature_list, dim=0).to(device)
         states = stack_states(state_list)
@@ -311,12 +335,15 @@ def process(
             next_decoder_out,
             next_hyp_list,
             next_trailing_blank_frames,
+            next_timestamps,
         ) = streaming_greedy_search(
             model=model,
             encoder_out=encoder_out,
             decoder_out=decoder_out,
             hyps=hyp_list,
             num_trailing_blank_frames=num_trailing_blank_frames_list,
+            frame_offset=frame_offset_list,
+            timestamps=timestamps_list,
         )
 
         next_decoder_out_list = next_decoder_out.split(1)
@@ -327,6 +354,9 @@ def process(
             s.decoder_out = next_decoder_out_list[i]
             s.hyp = next_hyp_list[i]
             s.num_trailing_blank_frames = next_trailing_blank_frames[i]
+            s.timestamps = next_timestamps[i]
+            s.frame_offset += encoder_out.size(1)
+            s.segment_frame_offset += encoder_out.size(1)
 
     def get_texts(self, stream: Stream) -> str:
         """
@@ -344,6 +374,21 @@ def get_texts(self, stream: Stream) -> str:
 
         return result
 
+    def get_tokens(self, stream: Stream) -> List[str]:
+        """
+        Return tokens after decoding.
+        Args:
+          stream:
+            Stream to be processed.
+        """
+        hyp = stream.hyp[self.beam_search_params["context_size"] :]
+        if hasattr(self, "sp"):
+            result = [self.sp.id_to_piece(i) for i in hyp]
+        else:
+            result = [self.token_table[i] for i in hyp]
+
+        return result
+
 
 class ModifiedBeamSearch:
     def __init__(self, beam_search_params: dict):
@@ -383,6 +428,7 @@ def process(
         state_list = []
         hyps_list = []
         feature_list = []
+        frame_offset_list = []
         for s in stream_list:
             state_list.append(s.states)
             hyps_list.append(s.hyps)
@@ -393,6 +439,7 @@ def process(
 
             b = torch.cat(f, dim=0)
             feature_list.append(b)
+            frame_offset_list.append(s.segment_frame_offset)
 
         features = torch.stack(feature_list, dim=0).to(device)
         states = stack_states(state_list)
@@ -415,6 +462,7 @@ def process(
             model=model,
             encoder_out=encoder_out,
             hyps=hyps_list,
+            frame_offset=frame_offset_list,
             num_active_paths=self.beam_search_params["num_active_paths"],
         )
 
@@ -422,8 +470,14 @@ def process(
         for i, s in enumerate(stream_list):
             s.states = next_state_list[i]
             s.hyps = next_hyps_list[i]
-            trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks
+
+            best_hyp = s.hyps.get_most_probable(True)
+
+            trailing_blanks = best_hyp.num_trailing_blanks
+            s.timestamps = best_hyp.timestamps
             s.num_trailing_blank_frames = trailing_blanks
+            s.frame_offset += encoder_out.size(1)
+            s.segment_frame_offset += encoder_out.size(1)
 
     def get_texts(self, stream: Stream) -> str:
         hyp = stream.hyps.get_most_probable(True).ys[
@@ -436,3 +490,20 @@ def get_texts(self, stream: Stream) -> str:
         result = "".join(result)
 
         return result
+
+    def get_tokens(self, stream: Stream) -> List[str]:
+        """
+        Return tokens after decoding.
+        Args:
+          stream:
+            Stream to be processed.
+        """
+        hyp = stream.hyps.get_most_probable(True).ys[
+            self.beam_search_params["context_size"] :
+        ]
+        if hasattr(self, "sp"):
+            result = [self.sp.id_to_piece(i) for i in hyp]
+        else:
+            result = [self.token_table[i] for i in hyp]
+
+        return result
diff --git a/sherpa/bin/lstm_transducer_stateless/stream.py b/sherpa/bin/lstm_transducer_stateless/stream.py
index d73bc1e0f..fada40cec 100644
--- a/sherpa/bin/lstm_transducer_stateless/stream.py
+++ b/sherpa/bin/lstm_transducer_stateless/stream.py
@@ -132,9 +132,15 @@ def __init__(
         self.subsampling_factor = subsampling_factor
         self.log_eps = math.log(1e-10)
 
-        # whenever an endpoint is detected, it is incremented
+        # incremented on endpointing
         self.segment = 0
 
+        # Number of frames decoded so far (after subsampling)
+        self.frame_offset = 0  # never reset
+
+        # Frame offset within the current segment (after subsampling)
+        self.segment_frame_offset = 0  # reset on endpointing
+
     def accept_waveform(
         self,
         sampling_rate: float,
@@ -225,5 +231,6 @@ def endpoint_detected(
             self.num_trailing_blank_frames = 0
             self.processed_frames = 0
             self.segment += 1
+            self.segment_frame_offset = 0
 
         return detected
diff --git a/sherpa/bin/lstm_transducer_stateless/streaming_server.py b/sherpa/bin/lstm_transducer_stateless/streaming_server.py
index bb7a04ce9..5adf6e4ca 100755
--- a/sherpa/bin/lstm_transducer_stateless/streaming_server.py
+++ b/sherpa/bin/lstm_transducer_stateless/streaming_server.py
@@ -54,6 +54,7 @@
     RnntLstmModel,
     add_beam_search_arguments,
     add_online_endpoint_arguments,
+    convert_timestamp,
 )
@@ -298,6 +299,7 @@ def __init__(
             raise ValueError(
                 f"Decoding method {decoding_method} is not supported."
             )
+        self.decoding_method = decoding_method
 
         if hasattr(self, "sp"):
             self.beam_search.sp = self.sp
@@ -484,15 +486,27 @@ async def handle_connection_impl(
             while len(stream.features) > self.chunk_length_pad:
                 await self.compute_and_decode(stream)
                 hyp = self.beam_search.get_texts(stream)
+                tokens = self.beam_search.get_tokens(stream)
                 segment = stream.segment
+                timestamps = convert_timestamp(
+                    frames=stream.timestamps,
+                    subsampling_factor=stream.subsampling_factor,
+                )
+
+                frame_offset = stream.frame_offset * stream.subsampling_factor
+
                 is_final = stream.endpoint_detected(self.online_endpoint_config)
                 if is_final:
                     self.beam_search.init_stream(stream)
 
                 message = {
+                    "method": self.decoding_method,
                     "segment": segment,
+                    "frame_offset": frame_offset,
                     "text": hyp,
+                    "tokens": tokens,
+                    "timestamps": timestamps,
                     "final": is_final,
                 }
@@ -509,10 +523,20 @@ async def handle_connection_impl(
             stream.features = []
 
         hyp = self.beam_search.get_texts(stream)
+        tokens = self.beam_search.get_tokens(stream)
+        frame_offset = stream.frame_offset * stream.subsampling_factor
+        timestamps = convert_timestamp(
+            frames=stream.timestamps,
+            subsampling_factor=stream.subsampling_factor,
+        )
 
         message = {
+            "method": self.decoding_method,
             "segment": stream.segment,
+            "frame_offset": frame_offset,
             "text": hyp,
+            "tokens": tokens,
+            "timestamps": timestamps,
             "final": True,  # end of connection, always set final to True
         }
diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py
index 1ebc449ff..91b00f6d7 100755
--- a/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py
+++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py
@@ -75,12 +75,23 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
     async for message in socket:
         result = json.loads(message)
+        method = result["method"]
         segment = result["segment"]
         is_final = result["final"]
         text = result["text"]
+        tokens = result["tokens"]
+        timestamps = result["timestamps"]
 
         if is_final:
-            ans.append(dict(segment=segment, text=text))
+            ans.append(
+                dict(
+                    method=method,
+                    segment=segment,
+                    text=text,
+                    tokens=tokens,
+                    timestamps=timestamps,
+                )
+            )
             logging.info(f"Final result of segment {segment}: {text}")
             continue
@@ -121,8 +132,16 @@ async def run(server_addr: str, server_port: int, test_wav: str):
     decoding_results = await receive_task
     s = ""
     for r in decoding_results:
+        s += f"method: {r['method']}\n"
         s += f"segment: {r['segment']}\n"
         s += f"text: {r['text']}\n"
+
+        token_time = []
+        for token, time in zip(r["tokens"], r["timestamps"]):
+            token_time.append((token, time))
+
+        s += f"timestamps: {r['timestamps']}\n"
+        s += f"(token, time): {token_time}\n"
 
     logging.info(f"{test_wav}\n{s}")
diff --git a/sherpa/csrc/hypothesis.h b/sherpa/csrc/hypothesis.h
index 3132267eb..de88088d7 100644
--- a/sherpa/csrc/hypothesis.h
+++ b/sherpa/csrc/hypothesis.h
@@ -32,6 +32,10 @@ struct Hypothesis {
   // The predicted tokens so far. Newly predicted tokens are appended.
   std::vector<int32_t> ys;
 
+  // timestamps[i] contains the frame number after subsampling
+  // on which ys[i] is decoded.
+  std::vector<int32_t> timestamps;
+
   // The total score of ys in log space.
   double log_prob = 0;
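
One invariant implied by the struct change: a timestamp is appended only when a real (non-blank, non-unk) token is appended, while the Python streams seed their hypotheses with `context_size` blanks (see `init_stream` above). A minimal sketch of that invariant, with hypothetical values, assuming that seeding behaviour:

    # Minimal sketch: `hyp` is seeded with `context_size` blanks; one timestamp
    # exists per real decoded token. All values are hypothetical.
    context_size = 2
    hyp = [0, 0, 25, 31]  # two seed blanks + two decoded token IDs
    timestamps = [3, 7]   # subsampled frame numbers of the two tokens

    assert len(hyp) - context_size == len(timestamps)
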
diff --git a/sherpa/csrc/rnnt_beam_search.cc b/sherpa/csrc/rnnt_beam_search.cc
index 561154b81..a9126a035 100644
--- a/sherpa/csrc/rnnt_beam_search.cc
+++ b/sherpa/csrc/rnnt_beam_search.cc
@@ -205,8 +205,10 @@ std::vector<std::vector<int32_t>> GreedySearch(
 torch::Tensor StreamingGreedySearch(
     RnntModel &model,  // NOLINT
     torch::Tensor encoder_out, torch::Tensor decoder_out,
+    const std::vector<int32_t> &frame_offset,
     std::vector<std::vector<int32_t>> *hyps,
-    std::vector<int32_t> *num_trailing_blank_frames) {
+    std::vector<int32_t> *num_trailing_blank_frames,
+    std::vector<std::vector<int32_t>> *timestamps) {
   TORCH_CHECK(encoder_out.dim() == 3, encoder_out.dim(), " vs ", 3);
   TORCH_CHECK(decoder_out.dim() == 2, decoder_out.dim(), " vs ", 2);
@@ -246,6 +248,7 @@ torch::Tensor StreamingGreedySearch(
       if (index != blank_id && index != unk_id) {
         emitted = true;
         (*hyps)[n].push_back(index);
+        (*timestamps)[n].push_back(t + frame_offset[n]);
         (*num_trailing_blank_frames)[n] = 0;
       } else {
         (*num_trailing_blank_frames)[n] += 1;
@@ -430,6 +433,7 @@ std::vector<std::vector<int32_t>> ModifiedBeamSearch(
 std::vector<Hypothesis> StreamingModifiedBeamSearch(
     RnntModel &model,  // NOLINT
     torch::Tensor encoder_out, std::vector<Hypothesis> in_hyps,
+    const std::vector<int32_t> &frame_offset,
     int32_t num_active_paths /*= 4*/) {
   TORCH_CHECK(encoder_out.dim() == 3, encoder_out.dim(), " vs ", 3);
@@ -522,6 +526,7 @@ std::vector<Hypothesis> StreamingModifiedBeamSearch(
         int32_t new_token = topk_token_indexes_acc[j];
         if (new_token != blank_id && new_token != unk_id) {
           new_hyp.ys.push_back(new_token);
+          new_hyp.timestamps.push_back(t + frame_offset[k]);
           new_hyp.num_trailing_blanks = 0;
         } else {
           new_hyp.num_trailing_blanks += 1;
diff --git a/sherpa/csrc/rnnt_beam_search.h b/sherpa/csrc/rnnt_beam_search.h
index 160a16dfc..be8586cdf 100644
--- a/sherpa/csrc/rnnt_beam_search.h
+++ b/sherpa/csrc/rnnt_beam_search.h
@@ -55,17 +55,27 @@ std::vector<std::vector<int32_t>> GreedySearch(
  *                    device as `model`.
  * @param decoder_out A 2-D tensor of shape (N, C). It should be on the same
  *                    device as `model`.
+ * @param frame_offset Its shape is (N,). The i-th element contains the number
+ *                     of frames after subsampling we have decoded so far for
+ *                     the i-th utterance.
  * @param hyps The decoded tokens. Note: It is modified in-place.
- * @param num_trailing_blank_frames Number of trailing blank frames. It is
- *                                  updated in-place.
+ * @param num_trailing_blank_frames Its shape is (N,). The i-th element
+ *                                  contains the number of trailing blank
+ *                                  frames after subsampling for the i-th
+ *                                  utterance. It is updated in-place.
+ * @param timestamps Its shape is (N,). timestamps[i].size() == hyps[i].size().
+ *                   timestamps[i][k] is the frame number after subsampling on
+ *                   which hyps[i][k] is decoded. It is modified in-place.
 *
 * @return Return the decoder output for the next chunk.
 */
 torch::Tensor StreamingGreedySearch(
     RnntModel &model,  // NOLINT
     torch::Tensor encoder_out, torch::Tensor decoder_out,
+    const std::vector<int32_t> &frame_offset,
     std::vector<std::vector<int32_t>> *hyps,
-    std::vector<int32_t> *num_trailing_blank_frames);
+    std::vector<int32_t> *num_trailing_blank_frames,
+    std::vector<std::vector<int32_t>> *timestamps);
 
 /** RNN-T modified beam search for offline recognition.
  *
@@ -101,6 +113,9 @@ std::vector<std::vector<int32_t>> ModifiedBeamSearch(
  * @param encoder_out A 3-D tensor of shape (N, T, C). It should be on the same
  *                    device as `model`.
 * @param hyps The decoded results from the previous chunk.
+ * @param frame_offset Its shape is (N,). The i-th element contains the number
+ *                     of frames after subsampling we have decoded so far for
+ *                     the i-th utterance.
 * @param num_active_paths Number of active paths for each utterance.
 *                         Note: Due to merging paths with identical token
 *                         sequences, the actual number of active paths for
@@ -111,7 +126,7 @@ std::vector<std::vector<int32_t>> ModifiedBeamSearch(
 std::vector<Hypothesis> StreamingModifiedBeamSearch(
     RnntModel &model,  // NOLINT
     torch::Tensor encoder_out, std::vector<Hypothesis> hyps,
-    int32_t num_active_paths = 4);
+    const std::vector<int32_t> &frame_offset, int32_t num_active_paths = 4);
 
 }  // namespace sherpa
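
The heart of the C++ change is the line `(*timestamps)[n].push_back(t + frame_offset[n])` in StreamingGreedySearch: `t` is the frame index within the current chunk, and `frame_offset[n]` shifts it so the stored frame number is relative to the start of the current segment. A pure-Python mirror of that per-frame update, for illustration only (names and values are hypothetical):

    def update_one_frame(n, t, index, hyps, timestamps, trailing_blanks,
                         frame_offset, blank_id=0, unk_id=0):
        # Mirrors the emission branch of StreamingGreedySearch.
        if index != blank_id and index != unk_id:
            hyps[n].append(index)
            timestamps[n].append(t + frame_offset[n])
            trailing_blanks[n] = 0
        else:
            trailing_blanks[n] += 1

    hyps, timestamps, trailing_blanks = [[]], [[]], [0]
    update_one_frame(0, 3, 25, hyps, timestamps, trailing_blanks, frame_offset=[16])
    assert timestamps == [[19]]  # frame 3 of this chunk, 16 frames into the segment
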
diff --git a/sherpa/python/csrc/hypothesis.cc b/sherpa/python/csrc/hypothesis.cc
index 75974defe..e8522bc27 100644
--- a/sherpa/python/csrc/hypothesis.cc
+++ b/sherpa/python/csrc/hypothesis.cc
@@ -40,6 +40,10 @@ void PybindHypothesis(py::module &m) {  // NOLINT
       .def_property_readonly(
           "ys",
           [](const PyClass &self) -> std::vector<int32_t> { return self.ys; })
+      .def_property_readonly("timestamps",
+                             [](const PyClass &self) -> std::vector<int32_t> {
+                               return self.timestamps;
+                             })
       .def_property_readonly("num_trailing_blanks",
                              [](const PyClass &self) -> int32_t {
                                return self.num_trailing_blanks;
diff --git a/sherpa/python/csrc/rnnt_beam_search.cc b/sherpa/python/csrc/rnnt_beam_search.cc
index a873a4dc3..05d1e808e 100644
--- a/sherpa/python/csrc/rnnt_beam_search.cc
+++ b/sherpa/python/csrc/rnnt_beam_search.cc
@@ -66,15 +66,25 @@ RNN-T greedy search for streaming recognition.
     Output from the decoder network. Its shape is
     ``(batch_size, decoder_out_dim)`` and its dtype is ``torch::kFloat``.
     It should be on the same device as ``model``.
+  frame_offset:
+    Its shape is (N,). The i-th element contains the number of frames after
+    subsampling we have decoded so far for the i-th utterance.
   hyps:
     The decoded tokens from the previous chunk.
   num_trailing_blank_frames:
-    Number of trailing blank frames decoded so far.
+    Its shape is (N,). The i-th element contains the number of trailing blank
+    frames after subsampling for the i-th utterance. It is updated in-place.
+  timestamps:
+    Its shape is (N,). timestamps[i].size() == hyps[i].size(). timestamps[i][k]
+    is the frame number after subsampling on which hyps[i][k] is decoded. It
+    is modified in-place.
 
 Returns:
   Return a tuple containing:
-    - The decoder output for the current chunk.
-    - The decoded tokens for the current chunk.
+    - The decoder output
+    - The decoded tokens
+    - Number of trailing blank frames
+    - Timestamps
 )doc";
 
 static constexpr const char *kModifiedBeamSearchDoc = R"doc(
@@ -117,6 +127,9 @@ RNN-T modified beam search for streaming recognition.
     It should be on the same device as ``model``.
   hyps:
     Decoded results from the previous chunk.
+  frame_offset:
+    Its shape is (N,). The i-th element contains the number of frames after
+    subsampling we have decoded so far for the i-th utterance.
   num_active_paths
     Number of active paths for each utterance.
     Note: Due to merging paths with identical token
     sequences, the actual number of active paths for each
@@ -135,15 +148,20 @@ void PybindRnntBeamSearch(py::module &m) {  // NOLINT
       "streaming_greedy_search",
       [](RnntModel &model, torch::Tensor encoder_out,
         torch::Tensor decoder_out, std::vector<std::vector<int32_t>> &hyps,
-         std::vector<int32_t> &num_trailing_blank_frames)
+         std::vector<int32_t> &num_trailing_blank_frames,
+         const std::vector<int32_t> &frame_offset,
+         std::vector<std::vector<int32_t>> &timestamps)
          -> std::tuple<torch::Tensor, std::vector<std::vector<int32_t>>,
-                       std::vector<int32_t>> {
-        decoder_out = StreamingGreedySearch(model, encoder_out, decoder_out,
-                                            &hyps, &num_trailing_blank_frames);
-        return {decoder_out, hyps, num_trailing_blank_frames};
+                       std::vector<int32_t>,
+                       std::vector<std::vector<int32_t>>> {
+        decoder_out = StreamingGreedySearch(
+            model, encoder_out, decoder_out, frame_offset, &hyps,
+            &num_trailing_blank_frames, &timestamps);
+        return {decoder_out, hyps, num_trailing_blank_frames, timestamps};
       },
       py::arg("model"), py::arg("encoder_out"), py::arg("decoder_out"),
       py::arg("hyps"), py::arg("num_trailing_blank_frames"),
+      py::arg("frame_offset"), py::arg("timestamps"),
       py::call_guard<py::gil_scoped_release>(),
       kStreamingGreedySearchDoc);
 
   m.def("modified_beam_search", &ModifiedBeamSearch, py::arg("model"),
@@ -153,7 +171,7 @@ void PybindRnntBeamSearch(py::module &m) {  // NOLINT
 
   m.def("streaming_modified_beam_search", &StreamingModifiedBeamSearch,
         py::arg("model"), py::arg("encoder_out"), py::arg("hyps"),
-        py::arg("num_active_paths") = 4,
+        py::arg("frame_offset"), py::arg("num_active_paths") = 4,
         py::call_guard<py::gil_scoped_release>(),
         kStreamingModifiedBeamSearchDoc);
 }
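
The docstrings above pin down the shape contract of the binding: timestamps[i] is parallel to hyps[i], holding one subsampled frame number per decoded token, so the two can always be zipped. A tiny runnable illustration with made-up values:

    hyps = [[5, 9], [7]]        # hypothetical token IDs for two utterances
    timestamps = [[2, 6], [4]]  # frame numbers, parallel to hyps

    for hyp, times in zip(hyps, timestamps):
        assert len(hyp) == len(times)
        print(list(zip(hyp, times)))  # (token_id, frame) pairs
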
diff --git a/sherpa/python/sherpa/__init__.py b/sherpa/python/sherpa/__init__.py
index b8f0aabef..8cef2bcec 100644
--- a/sherpa/python/sherpa/__init__.py
+++ b/sherpa/python/sherpa/__init__.py
@@ -35,8 +35,9 @@
     add_online_endpoint_arguments,
     endpoint_detected,
 )
+from .timestamp import convert_timestamp
 from .utils import (
     add_beam_search_arguments,
     count_num_trailing_zeros,
-    get_texts_and_num_trailing_blanks,
+    get_fast_beam_search_results,
 )
diff --git a/sherpa/python/sherpa/decode.py b/sherpa/python/sherpa/decode.py
index 782609c5f..20eb00bac 100644
--- a/sherpa/python/sherpa/decode.py
+++ b/sherpa/python/sherpa/decode.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass
 from typing import List, Tuple
 
 import k2
@@ -20,7 +21,7 @@
 from _sherpa import RnntModel
 
 from .nbest import Nbest
-from .utils import get_texts_and_num_trailing_blanks
+from .utils import FastBeamSearchResults, get_fast_beam_search_results
 
 VALID_FAST_BEAM_SEARCH_METHOD = [
     "fast_beam_search_nbest_LG",
@@ -39,7 +40,7 @@ def fast_beam_search_nbest_LG(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> Tuple[List[List[int]], List[int]]:
+) -> FastBeamSearchResults:
     """It limits the maximum number of symbols per frame to 1.
 
     The process to get the results is:
@@ -80,9 +81,7 @@ def fast_beam_search_nbest_LG(
       temperature:
         Softmax temperature.
     Returns:
-      Return a tuple containing:
-        - the decoded result
-        - number of trailing blanks
+      Return the decoded result.
     """
 
     lattice = fast_beam_search(
@@ -146,8 +145,7 @@ def fast_beam_search_nbest_LG(
     best_hyp_indexes = ragged_tot_scores.argmax()
     best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
 
-    hyps, num_trailing_blanks = get_texts_and_num_trailing_blanks(best_path)
-    return hyps, num_trailing_blanks
+    return get_fast_beam_search_results(best_path)
 
 
 def fast_beam_search_nbest(
@@ -160,7 +158,7 @@ def fast_beam_search_nbest(
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
     temperature: float = 1.0,
-) -> Tuple[List[List[int]], List[int]]:
+) -> FastBeamSearchResults:
     """It limits the maximum number of symbols per frame to 1.
 
     The process to get the results is:
@@ -201,9 +199,7 @@ def fast_beam_search_nbest(
       temperature:
         Softmax temperature.
     Returns:
-      Return a tuple containing:
-        - the decoded result
-        - number of trailing blanks
+      Return the decoded result.
     """
 
     lattice = fast_beam_search(
@@ -230,8 +226,7 @@ def fast_beam_search_nbest(
 
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
 
-    hyps, num_trailing_blanks = get_texts_and_num_trailing_blanks(best_path)
-    return hyps, num_trailing_blanks
+    return get_fast_beam_search_results(best_path)
 
 
 def fast_beam_search_one_best(
@@ -241,7 +236,7 @@ def fast_beam_search_one_best(
     rnnt_decoding_config: k2.RnntDecodingConfig,
     rnnt_decoding_streams_list: List[k2.RnntDecodingStream],
     temperature: float = 1.0,
-) -> Tuple[List[List[int]], List[int]]:
+) -> FastBeamSearchResults:
     """It limits the maximum number of symbols per frame to 1.
 
     A lattice is first obtained using fast beam search, and then
@@ -269,9 +264,7 @@ def fast_beam_search_one_best(
       temperature:
         Softmax temperature.
     Returns:
-      Return a tuple containing:
-        - the decoded result
-        - number of trailing blanks
+      Return the decoded result.
     """
     lattice = fast_beam_search(
         model=model,
@@ -284,8 +277,7 @@ def fast_beam_search_one_best(
 
     best_path = one_best_decoding(lattice)
 
-    hyps, num_trailing_blanks = get_texts_and_num_trailing_blanks(best_path)
-    return hyps, num_trailing_blanks
+    return get_fast_beam_search_results(best_path)
 
 
 def fast_beam_search(
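
Callers of the fast_beam_search_* functions now receive a single FastBeamSearchResults object instead of unpacking a tuple. A sketch of the new access pattern, constructing the dataclass directly with made-up values (the dataclass lives in sherpa/python/sherpa/utils.py, shown below; the import path is assumed):

    from sherpa.utils import FastBeamSearchResults  # import path assumed

    res = FastBeamSearchResults(
        hyps=[[12, 7]],           # word IDs (LG decoding) or token IDs
        num_trailing_blanks=[2],  # used for endpointing
        tokens=[[12, 7]],
        timestamps=[[3, 9]],      # subsampled frame numbers
    )
    token_time = list(zip(res.tokens[0], res.timestamps[0]))
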
""" lattice = fast_beam_search( @@ -146,8 +145,7 @@ def fast_beam_search_nbest_LG( best_hyp_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - hyps, num_trailing_blanks = get_texts_and_num_trailing_blanks(best_path) - return hyps, num_trailing_blanks + return get_fast_beam_search_results(best_path) def fast_beam_search_nbest( @@ -160,7 +158,7 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, -) -> Tuple[List[List[int]], List[int]]: +) -> FastBeamSearchResults: """It limits the maximum number of symbols per frame to 1. The process to get the results is: @@ -201,9 +199,7 @@ def fast_beam_search_nbest( temperature: Softmax temperature. Returns: - Return a tuple containing: - - the decoded result - - number of trailing blanks + Return the decoded result. """ lattice = fast_beam_search( @@ -230,8 +226,7 @@ def fast_beam_search_nbest( best_path = k2.index_fsa(nbest.fsa, max_indexes) - hyps, num_trailing_blanks = get_texts_and_num_trailing_blanks(best_path) - return hyps, num_trailing_blanks + return get_fast_beam_search_results(best_path) def fast_beam_search_one_best( @@ -241,7 +236,7 @@ def fast_beam_search_one_best( rnnt_decoding_config: k2.RnntDecodingConfig, rnnt_decoding_streams_list: List[k2.RnntDecodingStream], temperature: float = 1.0, -) -> Tuple[List[List[int]], List[int]]: +) -> FastBeamSearchResults: """It limits the maximum number of symbols per frame to 1. A lattice is first obtained using fast beam search, and then @@ -269,9 +264,7 @@ def fast_beam_search_one_best( temperature: Softmax temperature. Returns: - Return a tuple containing: - - the decoded result - - number of trailing blanks + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -284,8 +277,7 @@ def fast_beam_search_one_best( best_path = one_best_decoding(lattice) - hyps, num_trailing_blanks = get_texts_and_num_trailing_blanks(best_path) - return hyps, num_trailing_blanks + return get_fast_beam_search_results(best_path) def fast_beam_search( diff --git a/sherpa/python/sherpa/timestamp.py b/sherpa/python/sherpa/timestamp.py new file mode 100644 index 000000000..3385a9d04 --- /dev/null +++ b/sherpa/python/sherpa/timestamp.py @@ -0,0 +1,27 @@ +from typing import List + + +def convert_timestamp( + frames: List[int], + subsampling_factor: int, + frame_shift_ms: float = 10, +) -> List[float]: + """Convert frame numbers to time (in seconds) given subsampling factor + and frame shift (in milliseconds). + + Args: + frames: + A list of frame numbers after subsampling. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + Return: + Return the time in seconds corresponding to each given frame. + """ + frame_shift = frame_shift_ms / 1000.0 + ans = [] + for f in frames: + ans.append(f * subsampling_factor * frame_shift) + + return ans diff --git a/sherpa/python/sherpa/utils.py b/sherpa/python/sherpa/utils.py index d3a82ea54..15008a69d 100644 --- a/sherpa/python/sherpa/utils.py +++ b/sherpa/python/sherpa/utils.py @@ -1,10 +1,29 @@ import argparse +from dataclasses import dataclass from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Tuple import k2 +@dataclass +class FastBeamSearchResults: + # hyps[i] is the recognition results for the i-th utterance. + # It may contain either token IDs or word IDs depending on the actual + # decoding method. 
diff --git a/sherpa/python/test/CMakeLists.txt b/sherpa/python/test/CMakeLists.txt
index fd93fc69c..dee938c28 100644
--- a/sherpa/python/test/CMakeLists.txt
+++ b/sherpa/python/test/CMakeLists.txt
@@ -19,6 +19,7 @@ endfunction()
 set(py_test_files
   test_hypothesis.py
   test_online_endpoint.py
+  test_timestamp.py
   test_utils.py
 )
diff --git a/sherpa/python/test/test_timestamp.py b/sherpa/python/test/test_timestamp.py
new file mode 100755
index 000000000..3f8ac1475
--- /dev/null
+++ b/sherpa/python/test/test_timestamp.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# To run this single test, use
+#
+#  ctest --verbose -R test_timestamp_py
+
+import unittest
+
+import sherpa
+
+
+class TestTimeStamp(unittest.TestCase):
+    def test_convert_timestamp(self):
+        subsampling_factor = 4
+        frame_shift_ms = 10
+
+        frames = [0, 1, 2, 3, 5, 8, 10]
+        timestamps = sherpa.convert_timestamp(
+            frames,
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        for i in range(len(frames)):
+            assert timestamps[i] == (
+                frames[i] * subsampling_factor * frame_shift_ms / 1000
+            ), (frames[i], timestamps[i])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/sherpa/python/test/test_utils.py b/sherpa/python/test/test_utils.py
index 16832605d..f2ed3c16f 100755
--- a/sherpa/python/test/test_utils.py
+++ b/sherpa/python/test/test_utils.py
@@ -35,7 +35,7 @@ def test_count_number_trailing_zeros(self):
         assert sherpa.count_num_trailing_zeros([1, 0, 0]) == 2
         assert sherpa.count_num_trailing_zeros([0, 0, 0]) == 3
 
-    def test_get_texts_and_num_trailing_blanks_case1(self):
+    def test_fast_beam_search_results_case1(self):
         s1 = """
         0 1 0 0 0.0
         1 2 1 1 0.2
@@ -63,13 +63,10 @@ def test_fast_beam_search_results_case1(self):
         fsa3 = k2.Fsa.from_str(s3, acceptor=False)
 
         fsa = k2.Fsa.from_fsas([fsa1, fsa2, fsa3])
-        (
-            aux_labels,
-            num_trailing_blanks,
-        ) = sherpa.get_texts_and_num_trailing_blanks(fsa)
+        res = sherpa.get_fast_beam_search_results(fsa)
 
-        assert aux_labels == [[1, 5], [1], [1]]
-        assert num_trailing_blanks == [0, 2, 1]
+        assert res.hyps == [[1, 5], [1], [1]]
+        assert res.num_trailing_blanks == [0, 2, 1]
 
 
 if __name__ == "__main__":
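
As the updated client shows, tokens and timestamps can be zipped into (token, time) pairs. A consumer can go one step further and collapse sentencepiece-style pieces into word-level start times. A hedged sketch (it assumes "▁" marks a word start, which holds for typical BPE models but is not guaranteed for every model):

    def pieces_to_word_times(tokens, timestamps):
        """Merge BPE pieces into (word, start_time) pairs."""
        words = []
        for piece, time in zip(tokens, timestamps):
            if piece.startswith("▁") or not words:
                words.append([piece.lstrip("▁"), time])
            else:
                words[-1][0] += piece
        return [tuple(w) for w in words]

    print(pieces_to_word_times(["▁HE", "LLO", "▁WORLD"], [0.64, 0.8, 1.16]))
    # [('HELLO', 0.64), ('WORLD', 1.16)]
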