Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding temperature scaling on Joiner logits: #789

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty);
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(), unk_id_, config_.blank_penalty);
model_.get(),
unk_id_,
config_.blank_penalty,
config_.temperature_scale);

} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
Expand Down Expand Up @@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty);
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(), unk_id_, config_.blank_penalty);
model_.get(),
unk_id_,
config_.blank_penalty,
config_.temperature_scale);

} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
Expand Down
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
po->Register("temperature-scale", &temperature_scale,
"Temperature scale for confidence computation in decoding.");
}

bool OnlineRecognizerConfig::Validate() const {
Expand Down Expand Up @@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "hotwords_score=" << hotwords_score << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "blank_penalty=" << blank_penalty << ")";
os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ")";

return os.str();
}
Expand Down
20 changes: 14 additions & 6 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {

float blank_penalty = 0.0;

float temperature_scale = 2.0;

OnlineRecognizerConfig() = default;

OnlineRecognizerConfig(
const FeatureExtractorConfig &feat_config,
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
const OnlineModelConfig &model_config,
const OnlineLMConfig &lm_config,
const EndpointConfig &endpoint_config,
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty)
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths,
const std::string &hotwords_file,
float hotwords_score,
float blank_penalty,
float temperature_scale)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
Expand All @@ -114,9 +121,10 @@ struct OnlineRecognizerConfig {
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
hotwords_score(hotwords_score),
hotwords_file(hotwords_file),
blank_penalty(blank_penalty) {}
hotwords_score(hotwords_score),
blank_penalty(blank_penalty),
temperature_scale(temperature_scale) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(

// export the per-token log scores
if (y != 0 && y != unk_id_) {
// apply temperature-scaling
for (int32_t n = 0; n < vocab_size; ++n) {
p_logit[n] /= temperature_scale_;
}
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
Expand Down
10 changes: 8 additions & 2 deletions sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ namespace sherpa_onnx {
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
public:
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
int32_t unk_id, float blank_penalty)
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
int32_t unk_id,
float blank_penalty,
float temperature_scale)
: model_(model),
unk_id_(unk_id),
blank_penalty_(blank_penalty),
temperature_scale_(temperature_scale) {}

OnlineTransducerDecoderResult GetEmptyResult() const override;

Expand All @@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
OnlineTransducerModel *model_; // Not owned
int32_t unk_id_;
float blank_penalty_;
float temperature_scale_;
};

} // namespace sherpa_onnx
Expand Down
23 changes: 18 additions & 5 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));

float *p_logit = logit.GetTensorMutableData<float>();

// copy raw logits, apply temperature-scaling (for confidences)
// Note: temperature scaling is used only for the confidences,
// the decoding algorithm uses the original logits
int32_t p_logit_items = vocab_size * num_hyps;
std::vector<float> logit_with_temperature(p_logit_items);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change p_logit in-place?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it cannot be done in-place

the idea is to apply temperature only for the computation of confidences;
the decoding continues to use the original values

this is why the logit values are copied to a new buffer

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation.

{
std::copy(p_logit,
p_logit + p_logit_items,
logit_with_temperature.begin());
for (float& elem : logit_with_temperature) {
elem /= temperature_scale_;
}
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
}

if (blank_penalty_ > 0.0) {
// assuming blank id is 0
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
Expand Down Expand Up @@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
// score of the transducer
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis &prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);

if (lm_) { // export only when LM is used
Expand All @@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(hyps));
p_logprob += (end - start) * vocab_size;
} // for (int32_t b = 0; b != batch_size; ++b)
}
} // for (int32_t t = 0; t != num_frames; ++t)

for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
OnlineLM *lm,
int32_t max_active_paths,
float lm_scale, int32_t unk_id,
float blank_penalty)
float blank_penalty,
float temperature_scale)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale),
unk_id_(unk_id),
blank_penalty_(blank_penalty) {}
blank_penalty_(blank_penalty),
temperature_scale_(temperature_scale) {}

OnlineTransducerDecoderResult GetEmptyResult() const override;

Expand All @@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
float lm_scale_; // used only when lm_ is not nullptr
int32_t unk_id_;
float blank_penalty_;
float temperature_scale_;
};

} // namespace sherpa_onnx
Expand Down
30 changes: 22 additions & 8 deletions sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
int32_t, const std::string &, float, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::init<const FeatureExtractorConfig &,
const OnlineModelConfig &,
const OnlineLMConfig &,
const EndpointConfig &,
const OnlineCtcFstDecoderConfig &,
bool,
const std::string &,
int32_t,
const std::string &,
float,
float,
float>(),
py::arg("feat_config"),
py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
py::arg("enable_endpoint"),
py::arg("decoding_method"),
py::arg("max_active_paths") = 4,
py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0,
py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
Expand All @@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def("__str__", &PyClass::ToString);
}

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def from_transducer(
model_type: str = "",
lm: str = "",
lm_scale: float = 0.1,
temperature_scale: float = 2.0,
):
"""
Please refer to
Expand Down Expand Up @@ -123,6 +124,10 @@ def from_transducer(
hotwords_score:
The hotword score of each token for biasing word/phrase. Used only if
hotwords_file is given with modified_beam_search as decoding method.
temperature_scale:
Temperature scaling for output symbol confidence estimation.
It affects only the confidence values; the decoding uses the original
logits without temperature.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
model_type:
Expand Down Expand Up @@ -193,6 +198,7 @@ def from_transducer(
hotwords_score=hotwords_score,
hotwords_file=hotwords_file,
blank_penalty=blank_penalty,
temperature_scale=temperature_scale,
)

self.recognizer = _Recognizer(recognizer_config)
Expand Down
Loading