Refactor online recognizer (#250)
* Refactor online recognizer.

Make it easier to support other streaming models.

Note that this is a breaking change for the Python API:
`sherpa_onnx.OnlineRecognizer()` used previously should be
replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
csukuangfj authored Aug 9, 2023
1 parent 6061318 commit 79c2ce5
Showing 40 changed files with 670 additions and 480 deletions.
2 changes: 1 addition & 1 deletion python-api-examples/online-decode-files.py
@@ -205,7 +205,7 @@ def main():
     assert_file_exists(args.joiner)
     assert_file_exists(args.tokens)
 
-    recognizer = sherpa_onnx.OnlineRecognizer(
+    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
         tokens=args.tokens,
         encoder=args.encoder,
         decoder=args.decoder,

@@ -91,7 +91,7 @@ def create_recognizer():
     # Please replace the model files if needed.
     # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
     # for download links.
-    recognizer = sherpa_onnx.OnlineRecognizer(
+    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
         tokens=args.tokens,
         encoder=args.encoder,
         decoder=args.decoder,

2 changes: 1 addition & 1 deletion python-api-examples/speech-recognition-from-microphone.py
@@ -145,7 +145,7 @@ def create_recognizer():
     # Please replace the model files if needed.
     # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
     # for download links.
-    recognizer = sherpa_onnx.OnlineRecognizer(
+    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
         tokens=args.tokens,
         encoder=args.encoder,
         decoder=args.decoder,

2 changes: 1 addition & 1 deletion python-api-examples/speech-recognition-from-url.py
@@ -94,7 +94,7 @@ def create_recognizer(args):
     # Please replace the model files if needed.
     # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
     # for download links.
-    recognizer = sherpa_onnx.OnlineRecognizer(
+    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
         tokens=args.tokens,
         encoder=args.encoder,
         decoder=args.decoder,

2 changes: 1 addition & 1 deletion python-api-examples/streaming_server.py
@@ -294,7 +294,7 @@ def get_args():
 
 
 def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
-    recognizer = sherpa_onnx.OnlineRecognizer(
+    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
         tokens=args.tokens,
         encoder=args.encoder_model,
         decoder=args.decoder_model,

12 changes: 6 additions & 6 deletions sherpa-onnx/c-api/c-api.cc
@@ -38,11 +38,11 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
   recognizer_config.feat_config.feature_dim =
       SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
 
-  recognizer_config.model_config.encoder_filename =
+  recognizer_config.model_config.transducer.encoder =
       SHERPA_ONNX_OR(config->model_config.encoder, "");
-  recognizer_config.model_config.decoder_filename =
+  recognizer_config.model_config.transducer.decoder =
       SHERPA_ONNX_OR(config->model_config.decoder, "");
-  recognizer_config.model_config.joiner_filename =
+  recognizer_config.model_config.transducer.joiner =
       SHERPA_ONNX_OR(config->model_config.joiner, "");
   recognizer_config.model_config.tokens =
       SHERPA_ONNX_OR(config->model_config.tokens, "");
@@ -143,7 +143,7 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
   auto count = result.tokens.size();
   if (count > 0) {
     size_t total_length = 0;
-    for (const auto& token : result.tokens) {
+    for (const auto &token : result.tokens) {
       // +1 for the null character at the end of each token
      total_length += token.size() + 1;
     }
@@ -154,10 +154,10 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
     memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
            total_length);
     r->timestamps = new float[r->count];
-    char **tokens_temp = new char*[r->count];
+    char **tokens_temp = new char *[r->count];
     int32_t pos = 0;
     for (int32_t i = 0; i < r->count; ++i) {
-      tokens_temp[i] = const_cast<char*>(r->tokens) + pos;
+      tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
       memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
              result.tokens[i].c_str(), result.tokens[i].size());
       // +1 to move past the null character

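The last two hunks above pack every recognized token into one contiguous, NUL-separated buffer and expose a per-token pointer array. Below is a minimal standalone sketch of that layout; the names (`tokens`, `packed`, `pointers`) are illustrative stand-ins, not the actual C-API fields:

#include <cstddef>
#include <cstring>
#include <string>
#include <vector>

int main() {
  // Stand-in for result.tokens in the C API code above.
  std::vector<std::string> tokens = {"HE", "LL", "O"};

  // +1 per token for the NUL character that terminates it inside the buffer.
  size_t total_length = 0;
  for (const auto &token : tokens) {
    total_length += token.size() + 1;
  }

  // One contiguous buffer holding all tokens, zero-initialized so each copied
  // token is automatically NUL-terminated.
  char *packed = new char[total_length];
  std::memset(packed, 0, total_length);

  // pointers[i] points at token i inside the packed buffer.
  char **pointers = new char *[tokens.size()];

  size_t pos = 0;
  for (size_t i = 0; i != tokens.size(); ++i) {
    pointers[i] = packed + pos;
    std::memcpy(packed + pos, tokens[i].c_str(), tokens[i].size());
    pos += tokens[i].size() + 1;  // +1 to move past the NUL separator
  }

  delete[] pointers;
  delete[] packed;
  return 0;
}
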
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
@@ -43,6 +43,8 @@ set(sources
   online-lm-config.cc
   online-lm.cc
   online-lstm-transducer-model.cc
+  online-model-config.cc
+  online-recognizer-impl.cc
   online-recognizer.cc
   online-rnn-lm.cc
   online-stream.cc

16 changes: 8 additions & 8 deletions sherpa-onnx/csrc/online-conformer-transducer-model.cc
@@ -30,46 +30,46 @@
 namespace sherpa_onnx {
 
 OnlineConformerTransducerModel::OnlineConformerTransducerModel(
-    const OnlineTransducerModelConfig &config)
+    const OnlineModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
       sess_opts_(GetSessionOptions(config)),
       allocator_{} {
   {
-    auto buf = ReadFile(config.encoder_filename);
+    auto buf = ReadFile(config.transducer.encoder);
     InitEncoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(config.decoder_filename);
+    auto buf = ReadFile(config.transducer.decoder);
     InitDecoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(config.joiner_filename);
+    auto buf = ReadFile(config.transducer.joiner);
     InitJoiner(buf.data(), buf.size());
   }
 }
 
 #if __ANDROID_API__ >= 9
 OnlineConformerTransducerModel::OnlineConformerTransducerModel(
-    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
+    AAssetManager *mgr, const OnlineModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
       sess_opts_(GetSessionOptions(config)),
       allocator_{} {
   {
-    auto buf = ReadFile(mgr, config.encoder_filename);
+    auto buf = ReadFile(mgr, config.transducer.encoder);
     InitEncoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(mgr, config.decoder_filename);
+    auto buf = ReadFile(mgr, config.transducer.decoder);
     InitDecoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(mgr, config.joiner_filename);
+    auto buf = ReadFile(mgr, config.transducer.joiner);
     InitJoiner(buf.data(), buf.size());
   }
 }

9 changes: 4 additions & 5 deletions sherpa-onnx/csrc/online-conformer-transducer-model.h
@@ -16,19 +16,18 @@
 #endif
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
-#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
 
 namespace sherpa_onnx {
 
 class OnlineConformerTransducerModel : public OnlineTransducerModel {
  public:
-  explicit OnlineConformerTransducerModel(
-      const OnlineTransducerModelConfig &config);
+  explicit OnlineConformerTransducerModel(const OnlineModelConfig &config);
 
 #if __ANDROID_API__ >= 9
   OnlineConformerTransducerModel(AAssetManager *mgr,
-                                 const OnlineTransducerModelConfig &config);
+                                 const OnlineModelConfig &config);
 #endif
 
   std::vector<Ort::Value> StackStates(
@@ -88,7 +87,7 @@ class OnlineConformerTransducerModel : public OnlineTransducerModel {
   std::vector<std::string> joiner_output_names_;
   std::vector<const char *> joiner_output_names_ptr_;
 
-  OnlineTransducerModelConfig config_;
+  OnlineModelConfig config_;
 
   int32_t num_encoder_layers_ = 0;
   int32_t T_ = 0;

16 changes: 8 additions & 8 deletions sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -28,46 +28,46 @@
 namespace sherpa_onnx {
 
 OnlineLstmTransducerModel::OnlineLstmTransducerModel(
-    const OnlineTransducerModelConfig &config)
+    const OnlineModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
       sess_opts_(GetSessionOptions(config)),
       allocator_{} {
   {
-    auto buf = ReadFile(config.encoder_filename);
+    auto buf = ReadFile(config.transducer.encoder);
     InitEncoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(config.decoder_filename);
+    auto buf = ReadFile(config.transducer.decoder);
     InitDecoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(config.joiner_filename);
+    auto buf = ReadFile(config.transducer.joiner);
     InitJoiner(buf.data(), buf.size());
   }
 }
 
 #if __ANDROID_API__ >= 9
 OnlineLstmTransducerModel::OnlineLstmTransducerModel(
-    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
+    AAssetManager *mgr, const OnlineModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
       sess_opts_(GetSessionOptions(config)),
       allocator_{} {
   {
-    auto buf = ReadFile(mgr, config.encoder_filename);
+    auto buf = ReadFile(mgr, config.transducer.encoder);
     InitEncoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(mgr, config.decoder_filename);
+    auto buf = ReadFile(mgr, config.transducer.decoder);
     InitDecoder(buf.data(), buf.size());
   }
 
   {
-    auto buf = ReadFile(mgr, config.joiner_filename);
+    auto buf = ReadFile(mgr, config.transducer.joiner);
     InitJoiner(buf.data(), buf.size());
   }
 }

8 changes: 4 additions & 4 deletions sherpa-onnx/csrc/online-lstm-transducer-model.h
@@ -15,18 +15,18 @@
 #endif
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
-#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
 
 namespace sherpa_onnx {
 
 class OnlineLstmTransducerModel : public OnlineTransducerModel {
  public:
-  explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
+  explicit OnlineLstmTransducerModel(const OnlineModelConfig &config);
 
 #if __ANDROID_API__ >= 9
   OnlineLstmTransducerModel(AAssetManager *mgr,
-                            const OnlineTransducerModelConfig &config);
+                            const OnlineModelConfig &config);
 #endif
 
   std::vector<Ort::Value> StackStates(
@@ -86,7 +86,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
   std::vector<std::string> joiner_output_names_;
   std::vector<const char *> joiner_output_names_ptr_;
 
-  OnlineTransducerModelConfig config_;
+  OnlineModelConfig config_;
 
   int32_t num_encoder_layers_ = 0;
   int32_t T_ = 0;

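Both model classes above now take the same `OnlineModelConfig` (introduced below), which is what makes it easier to add further streaming models: a factory only needs to branch on the configured model type. The following is an illustrative sketch only; the function name is hypothetical and it is not part of this commit:

#include <memory>
#include <stdexcept>

#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"

namespace sherpa_onnx {

// Hypothetical helper, for illustration only.
std::unique_ptr<OnlineTransducerModel> CreateStreamingModelSketch(
    const OnlineModelConfig &config) {
  if (config.model_type == "conformer") {
    return std::make_unique<OnlineConformerTransducerModel>(config);
  }

  if (config.model_type == "lstm") {
    return std::make_unique<OnlineLstmTransducerModel>(config);
  }

  // zipformer/zipformer2, and the "load the model twice to detect the type"
  // fallback mentioned in the --model-type help text, are omitted here.
  throw std::runtime_error("model_type not covered by this sketch: " +
                           config.model_type);
}

}  // namespace sherpa_onnx
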
61 changes: 61 additions & 0 deletions sherpa-onnx/csrc/online-model-config.cc
@@ -0,0 +1,61 @@
+// sherpa-onnx/csrc/online-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineModelConfig::Register(ParseOptions *po) {
+  transducer.Register(po);
+
+  po->Register("tokens", &tokens, "Path to tokens.txt");
+
+  po->Register("num-threads", &num_threads,
+               "Number of threads to run the neural network");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
+
+  po->Register("model-type", &model_type,
+               "Specify it to reduce model initialization time. "
+               "Valid values are: conformer, lstm, zipformer, zipformer2."
+               "All other values lead to loading the model twice.");
+}
+
+bool OnlineModelConfig::Validate() const {
+  if (num_threads < 1) {
+    SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
+    return false;
+  }
+
+  if (!FileExists(tokens)) {
+    SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
+    return false;
+  }
+
+  return transducer.Validate();
+}
+
+std::string OnlineModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineModelConfig(";
+  os << "transducer=" << transducer.ToString() << ", ";
+  os << "tokens=\"" << tokens << "\", ";
+  os << "num_threads=" << num_threads << ", ";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "provider=\"" << provider << "\", ";
+  os << "model_type=\"" << model_type << "\")";
+
+  return os.str();
+}
+
+} // namespace sherpa_onnx

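For context, `Register()` above is what exposes these settings as command-line options (`--tokens`, `--num-threads`, `--debug`, `--provider`, `--model-type`, plus whatever `transducer.Register()` adds). A rough sketch of how a binary might drive it, assuming the Kaldi-style `ParseOptions` interface this codebase uses elsewhere; the header path and the `Read()` call are assumptions, not shown in this diff:

#include <iostream>

#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"  // assumed path

int main(int argc, char *argv[]) {
  sherpa_onnx::ParseOptions po("Illustrative usage of OnlineModelConfig");

  sherpa_onnx::OnlineModelConfig config;
  config.Register(&po);  // adds --tokens, --num-threads, --model-type, ...
  po.Read(argc, argv);   // assumed Kaldi-style parsing entry point

  if (!config.Validate()) {
    std::cerr << "Invalid config:\n" << config.ToString() << "\n";
    return 1;
  }

  std::cout << config.ToString() << "\n";
  return 0;
}
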
48 changes: 48 additions & 0 deletions sherpa-onnx/csrc/online-model-config.h
@@ -0,0 +1,48 @@
+// sherpa-onnx/csrc/online-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+
+namespace sherpa_onnx {
+
+struct OnlineModelConfig {
+  OnlineTransducerModelConfig transducer;
+  std::string tokens;
+  int32_t num_threads = 1;
+  bool debug = false;
+  std::string provider = "cpu";
+
+  // Valid values:
+  //  - conformer, conformer transducer from icefall
+  //  - lstm, lstm transducer from icefall
+  //  - zipformer, zipformer transducer from icefall
+  //  - zipformer2, zipformer2 transducer from icefall
+  //
+  // All other values are invalid and lead to loading the model twice.
+  std::string model_type;
+
+  OnlineModelConfig() = default;
+  OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
+                    const std::string &tokens, int32_t num_threads, bool debug,
+                    const std::string &provider, const std::string &model_type)
+      : transducer(transducer),
+        tokens(tokens),
+        num_threads(num_threads),
+        debug(debug),
+        provider(provider),
+        model_type(model_type) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+} // namespace sherpa_onnx
+
+#endif // SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_

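A minimal sketch of filling in the new config by hand. The model file paths are placeholders, and the `OnlineTransducerModelConfig` field names (`encoder`, `decoder`, `joiner`) are taken from how this diff reads them (`config.transducer.encoder`, etc.):

#include <iostream>

#include "sherpa-onnx/csrc/online-model-config.h"

int main() {
  sherpa_onnx::OnlineTransducerModelConfig transducer;
  transducer.encoder = "encoder.onnx";  // placeholder path
  transducer.decoder = "decoder.onnx";  // placeholder path
  transducer.joiner = "joiner.onnx";    // placeholder path

  // Uses the constructor declared in the header above.
  sherpa_onnx::OnlineModelConfig config(transducer, "tokens.txt",
                                        /*num_threads=*/2, /*debug=*/false,
                                        /*provider=*/"cpu",
                                        /*model_type=*/"zipformer");

  // Validate() checks num_threads and the tokens file, then delegates to
  // transducer.Validate() (see online-model-config.cc above).
  if (!config.Validate()) {
    std::cerr << "Config is not usable yet: " << config.ToString() << "\n";
    return 1;
  }

  std::cout << config.ToString() << "\n";
  return 0;
}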