diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 96cc822bb..38ffadb04 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -128,15 +128,60 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( const auto &text = result.text; auto r = new SherpaOnnxOnlineRecognizerResult; + // copy text r->text = new char[text.size() + 1]; std::copy(text.begin(), text.end(), const_cast(r->text)); const_cast(r->text)[text.size()] = 0; + // copy json + const auto &json = result.AsJsonString(); + r->json = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), const_cast(r->json)); + const_cast(r->json)[json.size()] = 0; + + // copy tokens + auto count = result.tokens.size(); + if (count > 0) { + size_t total_length = 0; + for (const auto& token : result.tokens) { + // +1 for the null character at the end of each token + total_length += token.size() + 1; + } + + r->count = count; + // Each token is terminated by a null character + r->tokens = new char[total_length]; + memset(reinterpret_cast(const_cast(r->tokens)), 0, + total_length); + r->timestamps = new float[r->count]; + char **tokens_temp = new char*[r->count]; + int32_t pos = 0; + for (int32_t i = 0; i < r->count; ++i) { + tokens_temp[i] = const_cast(r->tokens) + pos; + memcpy(reinterpret_cast(const_cast(r->tokens + pos)), + result.tokens[i].c_str(), result.tokens[i].size()); + // +1 to move past the null character + pos += result.tokens[i].size() + 1; + r->timestamps[i] = result.timestamps[i]; + } + + r->tokens_arr = tokens_temp; + } else { + r->count = 0; + r->timestamps = nullptr; + r->tokens = nullptr; + r->tokens_arr = nullptr; + } + return r; } void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) { delete[] r->text; + delete[] r->json; + delete[] r->tokens; + delete[] r->tokens_arr; + delete[] r->timestamps; delete r; } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index ae30fbe4f..cb1fa7e8c 100644 --- a/sherpa-onnx/c-api/c-api.h 
+++ b/sherpa-onnx/c-api/c-api.h @@ -101,8 +101,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { } SherpaOnnxOnlineRecognizerConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { + // Recognized text const char *text; - // TODO(fangjun): Add more fields + + // Pointer to contiguous memory which holds string-based tokens + // which are separated by \0 + const char *tokens; + + // A pointer array containing the address of the first character of each token in tokens + const char *const *tokens_arr; + + // Pointer to contiguous memory which holds timestamps + float *timestamps; + + // The number of tokens/timestamps in the above pointers + int32_t count; + + /** A JSON string. + * + * The string contains: + * { + * "text": "The recognition result", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "segment": x, + * "start_time": x, + * "is_final": true|false + * } + */ + const char *json; } SherpaOnnxOnlineRecognizerResult; /// Note: OnlineRecognizer here means StreamingRecognizer. 
diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 95cc52942..5f9a7734e 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -58,6 +58,11 @@ class SherpaOnnx { return result.text; } + const std::vector GetTokens() const { + auto result = recognizer_.GetResult(stream_.get()); + return result.tokens; + } + bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } bool IsReady() const { return recognizer_.IsReady(stream_.get()); } @@ -312,6 +317,29 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( return env->NewStringUTF(text.c_str()); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto tokens = reinterpret_cast(ptr)->GetTokens(); + int size = tokens.size(); + jclass stringClass = env->FindClass("java/lang/String"); + + // convert C++ list into jni string array + jobjectArray result = env->NewObjectArray(size, stringClass, NULL); + for (int i = 0; i < size; i++) { + // Convert the C++ string to a C string + const char* cstr = tokens[i].c_str(); + + // Convert the C string to a jstring + jstring jstr = env->NewStringUTF(cstr); + + // Set the array element + env->SetObjectArrayElement(result, i, jstr); + } + + return result; +} + // see // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables static jobject NewInteger(JNIEnv *env, int32_t value) { diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index c22c938a4..70565abd9 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -99,6 +99,26 @@ class SherpaOnnxOnlineRecongitionResult { return String(cString: result.pointee.text) } + var count: Int32 { + return result.pointee.count + } + + var tokens: [String] { + if let tokensPointer = result.pointee.tokens_arr { + var tokens: [String] = [] + for index in 0..!) { self.result = result }