diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 23fad3df1..a2c244525 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } decoder_ = std::make_unique( - model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, unk_id_, config_.blank_penalty); + model_.get(), + lm_.get(), + config_.max_active_paths, + config_.lm_config.scale, + unk_id_, + config_.blank_penalty, + config_.temperature_scale); + } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), unk_id_, config_.blank_penalty); + model_.get(), + unk_id_, + config_.blank_penalty, + config_.temperature_scale); + } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } decoder_ = std::make_unique( - model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, unk_id_, config_.blank_penalty); + model_.get(), + lm_.get(), + config_.max_active_paths, + config_.lm_config.scale, + unk_id_, + config_.blank_penalty, + config_.temperature_scale); + } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), unk_id_, config_.blank_penalty); + model_.get(), + unk_id_, + config_.blank_penalty, + config_.temperature_scale); + } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 8bd0c16ad..a7fdbdff3 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { po->Register("decoding-method", &decoding_method, "decoding method," "now support greedy_search and modified_beam_search."); + po->Register("temperature-scale", &temperature_scale, + "Temperature scale for confidence computation in decoding."); } bool OnlineRecognizerConfig::Validate() const { @@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "hotwords_score=" << hotwords_score << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; os << "decoding_method=\"" << decoding_method << "\", "; - os << "blank_penalty=" << blank_penalty << ")"; + os << "blank_penalty=" << blank_penalty << ", "; + os << "temperature_scale=" << temperature_scale << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 308cb08f7..d8503bd13 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -96,16 +96,23 @@ struct OnlineRecognizerConfig { float blank_penalty = 0.0; + float temperature_scale = 2.0; + OnlineRecognizerConfig() = default; OnlineRecognizerConfig( const FeatureExtractorConfig &feat_config, - const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, + const OnlineModelConfig &model_config, + const OnlineLMConfig &lm_config, const EndpointConfig &endpoint_config, const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, - bool enable_endpoint, const std::string &decoding_method, - int32_t max_active_paths, const std::string &hotwords_file, - float hotwords_score, float blank_penalty) + bool enable_endpoint, + const std::string &decoding_method, + int32_t max_active_paths, + const std::string &hotwords_file, + float hotwords_score, + float blank_penalty, + float temperature_scale) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -114,9 +121,10 @@ struct OnlineRecognizerConfig { enable_endpoint(enable_endpoint), decoding_method(decoding_method), max_active_paths(max_active_paths), - hotwords_score(hotwords_score), hotwords_file(hotwords_file), - blank_penalty(blank_penalty) {} + hotwords_score(hotwords_score), + blank_penalty(blank_penalty), + temperature_scale(temperature_scale) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 05523dbb3..03447fc18 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( // export the per-token log scores if (y != 0 && y != unk_id_) { + // apply temperature-scaling + for (int32_t n = 0; n < vocab_size; ++n) { + p_logit[n] /= temperature_scale_; + } LogSoftmax(p_logit, vocab_size); // renormalize probabilities, // save time by doing it only for // emitted symbols diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index c68c32dcf..716f88484 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -15,8 +15,13 @@ namespace sherpa_onnx { class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { public: OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, - int32_t unk_id, float blank_penalty) - : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} + int32_t unk_id, + float blank_penalty, + float temperature_scale) + : model_(model), + unk_id_(unk_id), + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { OnlineTransducerModel *model_; // Not owned int32_t unk_id_; float blank_penalty_; + float temperature_scale_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 84fb46059..ea3f78f4b 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); + + // copy raw logits, apply temperature-scaling (for confidences) + // Note: temperature scaling is used only for the confidences, + // the decoding algorithm uses the original logits + int32_t p_logit_items = vocab_size * num_hyps; + std::vector logit_with_temperature(p_logit_items); + { + std::copy(p_logit, + p_logit + p_logit_items, + logit_with_temperature.begin()); + for (float& elem : logit_with_temperature) { + elem /= temperature_scale_; + } + LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps); + } + if (blank_penalty_ > 0.0) { // assuming blank id is 0 SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); @@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( // score of the transducer // export the per-token log scores if (new_token != 0 && new_token != unk_id_) { - const Hypothesis &prev_i = prev[hyp_index]; - // subtract 'prev[i]' path scores, which were added before - // getting topk tokens - float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; + float y_prob = logit_with_temperature[start * vocab_size + k]; new_hyp.ys_probs.push_back(y_prob); if (lm_) { // export only when LM is used @@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( cur.push_back(std::move(hyps)); p_logprob += (end - start) * vocab_size; } // for (int32_t b = 0; b != batch_size; ++b) - } + } // for (int32_t t = 0; t != num_frames; ++t) for (int32_t b = 0; b != batch_size; ++b) { auto &hyps = cur[b]; diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index 92e9a69c9..839aa768a 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder OnlineLM *lm, int32_t max_active_paths, float lm_scale, int32_t unk_id, - float blank_penalty) + float blank_penalty, + float temperature_scale) : model_(model), lm_(lm), max_active_paths_(max_active_paths), lm_scale_(lm_scale), unk_id_(unk_id), - blank_penalty_(blank_penalty) {} + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder float lm_scale_; // used only when lm_ is not nullptr int32_t unk_id_; float blank_penalty_; + float temperature_scale_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index bd98c94e2..79f154699 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") .def( - py::init(), - py::arg("feat_config"), py::arg("model_config"), + py::init(), + py::arg("feat_config"), + py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config") = EndpointConfig(), py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), - py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) + py::arg("enable_endpoint"), + py::arg("decoding_method"), + py::arg("max_active_paths") = 4, + py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, + py::arg("blank_penalty") = 0.0, + py::arg("temperature_scale") = 2.0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) .def_readwrite("blank_penalty", &PyClass::blank_penalty) + .def_readwrite("temperature_scale", &PyClass::temperature_scale) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index a82ab1703..520000028 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -58,6 +58,7 @@ def from_transducer( model_type: str = "", lm: str = "", lm_scale: float = 0.1, + temperature_scale: float = 2.0, ): """ Please refer to @@ -123,6 +124,10 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + temperature_scale: + Temperature scaling for output symbol confidence estiamation. + It affects only confidence values, the decoding uses the original + logits without temperature. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. model_type: @@ -193,6 +198,7 @@ def from_transducer( hotwords_score=hotwords_score, hotwords_file=hotwords_file, blank_penalty=blank_penalty, + temperature_scale=temperature_scale, ) self.recognizer = _Recognizer(recognizer_config)