From 148798bfa3fa4c2a649de84fcb7cc30bfdd62678 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 17 Aug 2023 10:39:21 +0800 Subject: [PATCH] Fix a bug for multilingual ASR --- CMakeLists.txt | 2 +- sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 93c00af40..568dcc8f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.8") +set(SHERPA_ONNX_VERSION "1.7.9") # Disable warning about # diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc index 036fab5be..396e76ec6 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -136,8 +136,10 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); int32_t vocab_size = logits_shape[2]; - int32_t max_token_id = static_cast(std::distance( - p_logits, std::max_element(p_logits, p_logits + vocab_size))); + const float *p_start = p_logits + (logits_shape[1] - 1) * vocab_size; + + int32_t max_token_id = static_cast( + std::distance(p_start, std::max_element(p_start, p_start + vocab_size))); int32_t n_text_ctx = model_->TextCtx();