style fixes
SilverSulfide committed Aug 23, 2024
1 parent 132770c commit 5aa9f82
Showing 5 changed files with 18 additions and 11 deletions.
sherpa-onnx/csrc/online-recognizer-transducer-impl.h: 4 additions & 2 deletions

@@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
- config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, config_.blank_penalty,
+ config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
+ config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
@@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
- config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, config_.blank_penalty,
+ config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
+ config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
sherpa-onnx/csrc/online-rnn-lm.cc: 6 additions & 4 deletions

@@ -8,6 +8,7 @@
#include <string>
#include <utility>
#include <vector>
+ #include <algorithm>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
@@ -76,8 +77,8 @@ class OnlineRnnLM::Impl {
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
- std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
-           p_x);
+ std::copy(ys.begin() + context_size + h.cur_scored_pos,
+           ys.end() - 1, p_x);

// streaming forward by NN LM
auto out = ScoreToken(std::move(x),
@@ -176,7 +177,8 @@ class OnlineRnnLM::Impl {
states.push_back(std::move(c));
auto pair = ScoreToken(std::move(x), std::move(states));

- init_scores_.value = std::move(pair.first); // only used during shallow fusion
+ init_scores_.value = std::move(pair.first); // only used during
+                                             // shallow fusion
init_states_ = std::move(pair.second);
}

@@ -234,4 +236,4 @@ void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
}


- } // namespace sherpa_onnx
+ } // namespace sherpa_onnx
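The std::copy hunk above is only re-wrapped for line length, but it marks the incremental part of the streaming LM scoring: each call copies just the tokens that have not been scored yet (everything past the context prefix and past h.cur_scored_pos) into the LM input tensor. A minimal sketch of that bookkeeping, using invented names (LmHistory, UnscoredSuffix) and a plain std::vector in place of the Ort tensor:

    // Sketch only: mirrors the slice taken by the std::copy above; the struct
    // and function names are illustrative, not sherpa-onnx identifiers.
    #include <cstdint>
    #include <vector>

    struct LmHistory {
      std::vector<int64_t> ys;     // leading context tokens + decoded tokens
      int32_t cur_scored_pos = 0;  // number of decoded tokens already scored
    };

    std::vector<int64_t> UnscoredSuffix(const LmHistory &h, int32_t context_size) {
      // Skip the context prefix and the already-scored tokens; stop one short
      // of the end, as the copy in the hunk above does.
      return std::vector<int64_t>(h.ys.begin() + context_size + h.cur_scored_pos,
                                  h.ys.end() - 1);
    }

After each LM forward pass the scored position would advance by the number of tokens just consumed, so the next call feeds only new tokens.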
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc: 5 additions & 3 deletions

@@ -156,7 +156,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(

// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
-
float log_prob = prev[i].log_prob;
if (lm_ && shallow_fusion_) {
log_prob += prev[i].lm_log_prob;
@@ -208,15 +207,18 @@
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
} else {
- new_hyp.log_prob = p_logprob[k] + context_score; // for rescoring or no LM, previous token score is ignored
+ new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM
+                                                  // previous token
+                                                  // score is ignored
}

// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);

- if (lm_ && shallow_fusion_) { // export only when LM shallow fusion is used
+ if (lm_ && shallow_fusion_) { // export only if
+                               // LM shallow fusion is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;

if (lm_scale_ != 0.0) {
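The comments reformatted in this hunk describe the decoder's two LM modes: with shallow fusion the scaled LM log-probability is added to the transducer score while the beam search runs, whereas for rescoring (or when no LM is configured) the hypothesis keeps only the transducer score and any LM weighting happens after the search. A minimal sketch of that split, with invented names (Hyp, SearchScore) standing in for the actual sherpa-onnx types:

    // Sketch only: mirrors the log_prob / lm_log_prob split visible in this
    // file's diff; Hyp and SearchScore are illustrative, not repo identifiers.
    struct Hyp {
      float log_prob;     // transducer (+ context biasing) score only
      float lm_log_prob;  // scaled LM score, tracked separately
    };

    float SearchScore(const Hyp &h, bool have_lm, bool shallow_fusion) {
      if (have_lm && shallow_fusion) {
        // shallow fusion: the LM steers the beam search itself
        return h.log_prob + h.lm_log_prob;
      }
      // rescoring or no LM: rank by the transducer score alone
      return h.log_prob;
    }

This matches the earlier loop in this file's diff, where prev[i].lm_log_prob is added to the hypothesis score only when lm_ && shallow_fusion_ holds.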
sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ class OnlineTransducerModifiedBeamSearchDecoder

int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr
- bool shallow_fusion_; // used only when lm_ is not nullptr
+ bool shallow_fusion_; // used only when lm_ is not nullptr
int32_t unk_id_;
float blank_penalty_;
float temperature_scale_;
sherpa-onnx/python/csrc/online-lm-config.cc: 2 additions & 1 deletion

@@ -13,7 +13,8 @@ namespace sherpa_onnx {
void PybindOnlineLMConfig(py::module *m) {
using PyClass = OnlineLMConfig;
py::class_<PyClass>(*m, "OnlineLMConfig")
- .def(py::init<const std::string &, float, int32_t, const std::string &, bool>(),
+ .def(py::init<const std::string &, float, int32_t,
+               const std::string &, bool>(),
py::arg("model") = "", py::arg("scale") = 0.5f,
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
py::arg("shallow_fusion") = true)
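The constructor signature wrapped above is the Python-facing side of the same shallow_fusion switch that the recognizer reads as config_.lm_config.shallow_fusion in the first file of this commit. A small C++ sketch of filling in such a config, assuming the OnlineLMConfig field names and the header path mirror the py::arg list and the file layout shown here (MakeLmConfig is an invented helper):

    // Sketch only: header path and field names are assumptions based on the
    // py::arg list above; the model path is a placeholder.
    #include "sherpa-onnx/csrc/online-lm-config.h"

    sherpa_onnx::OnlineLMConfig MakeLmConfig() {
      sherpa_onnx::OnlineLMConfig config;
      config.model = "/path/to/rnn-lm.onnx";
      config.scale = 0.5f;             // same default as py::arg("scale")
      config.lm_num_threads = 1;
      config.lm_provider = "cpu";
      config.shallow_fusion = true;    // false would select the rescoring path
      return config;
    }

From Python, the binding exposes the same knobs as keyword arguments: model, scale, lm_num_threads, lm_provider, and shallow_fusion.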
