Skip to content

Commit

Permalink
Implement Tokens in Swift and Kotlin (#227)
Browse files Browse the repository at this point in the history
Co-authored-by: duc <[email protected]>
  • Loading branch information
w11wo and ductranminh authored Aug 5, 2023
1 parent c575673 commit 64efbd8
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 1 deletion.
45 changes: 45 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,60 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
const auto &text = result.text;

auto r = new SherpaOnnxOnlineRecognizerResult;
// copy text
r->text = new char[text.size() + 1];
std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
const_cast<char *>(r->text)[text.size()] = 0;

// copy json
const auto &json = result.AsJsonString();
r->json = new char[json.size() + 1];
std::copy(json.begin(), json.end(), const_cast<char *>(r->json));
const_cast<char *>(r->json)[json.size()] = 0;

// copy tokens
auto count = result.tokens.size();
if (count > 0) {
size_t total_length = 0;
for (const auto& token : result.tokens) {
// +1 for the null character at the end of each token
total_length += token.size() + 1;
}

r->count = count;
// Each token is terminated by a null character ('\0')
r->tokens = new char[total_length];
memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
total_length);
r->timestamps = new float[r->count];
char **tokens_temp = new char*[r->count];
int32_t pos = 0;
for (int32_t i = 0; i < r->count; ++i) {
tokens_temp[i] = const_cast<char*>(r->tokens) + pos;
memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
result.tokens[i].c_str(), result.tokens[i].size());
// +1 to move past the null character
pos += result.tokens[i].size() + 1;
r->timestamps[i] = result.timestamps[i];
}

r->tokens_arr = tokens_temp;
} else {
r->count = 0;
r->timestamps = nullptr;
r->tokens = nullptr;
r->tokens_arr = nullptr;
}

return r;
}

// Frees a result and every buffer it owns.
//
// Releases the text, json, tokens, tokens_arr and timestamps arrays with
// delete[] and then the result object itself with delete.
//
// @param r Result to destroy. Passing nullptr is a no-op.
void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) {
  // Guard against a null result so callers can destroy unconditionally
  // instead of crashing on the member accesses below.
  if (r == nullptr) {
    return;
  }

  // delete[] on a null pointer does nothing, so members that were never
  // allocated (e.g. tokens of an empty result) need no special handling.
  delete[] r->text;
  delete[] r->json;
  delete[] r->tokens;
  delete[] r->tokens_arr;
  delete[] r->timestamps;
  delete r;
}

Expand Down
29 changes: 28 additions & 1 deletion sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
} SherpaOnnxOnlineRecognizerConfig;

// Recognition result of a streaming recognizer, as exposed through the C API.
//
// DestroyOnlineRecognizerResult() releases every pointer member below with
// delete[] and then the struct itself, so callers must not free the members
// individually.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
// Recognized text as a null-terminated string
const char *text;
// TODO(fangjun): Add more fields

// Pointer to one contiguous buffer holding all tokens back to back; each
// token is a string terminated by '\0'. It is NULL when count is 0.
const char *tokens;

// Array of `count` pointers; tokens_arr[i] is the address of the first
// character of the i-th token inside `tokens`. It is NULL when count is 0.
const char *const *tokens_arr;

// Pointer to a contiguous array of `count` timestamps, one per token.
// It is NULL when count is 0.
float *timestamps;

// The number of entries in tokens_arr and in timestamps
int32_t count;

/** The result as a JSON string.
*
* The string contains:
* {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
* }
*/
const char *json;
} SherpaOnnxOnlineRecognizerResult;

/// Note: OnlineRecognizer here means StreamingRecognizer.
Expand Down
28 changes: 28 additions & 0 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class SherpaOnnx {
return result.text;
}

// Returns the tokens of the recognizer's current result for this stream.
const std::vector<std::string> GetTokens() const {
  // Only the token list of the freshly fetched result is handed back.
  return recognizer_.GetResult(stream_.get()).tokens;
}

// Forwards to the recognizer's IsEndpoint() check for this stream.
bool IsEndpoint() const {
  return recognizer_.IsEndpoint(stream_.get());
}

// Forwards to the recognizer's IsReady() check for this stream.
bool IsReady() const {
  return recognizer_.IsReady(stream_.get());
}
Expand Down Expand Up @@ -312,6 +317,29 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
return env->NewStringUTF(text.c_str());
}

// JNI entry point for com.k2fsa.sherpa.onnx.SherpaOnnx.getTokens().
//
// Fetches the tokens from the native SherpaOnnx object and returns them as a
// Java String[].
//
// @param env JNI environment of the calling thread.
// @param ptr Address of the native sherpa_onnx::SherpaOnnx object.
// @return A new String[] with one element per token, or nullptr (with a Java
//         exception pending) if the String class, the array or one of the
//         strings could not be created.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto tokens = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetTokens();
  jsize size = static_cast<jsize>(tokens.size());

  jclass string_class = env->FindClass("java/lang/String");
  if (string_class == nullptr) {
    // FindClass failed and left an exception pending for the Java caller.
    return nullptr;
  }

  // Convert the C++ vector into a JNI string array.
  jobjectArray result = env->NewObjectArray(size, string_class, nullptr);
  if (result == nullptr) {
    return nullptr;
  }

  for (jsize i = 0; i < size; ++i) {
    jstring jstr = env->NewStringUTF(tokens[i].c_str());
    if (jstr == nullptr) {
      // Allocation failed; an OutOfMemoryError is pending.
      return nullptr;
    }

    env->SetObjectArrayElement(result, i, jstr);

    // The array now holds its own reference. Drop the local one right away so
    // a long token list cannot exhaust the JNI local reference table.
    env->DeleteLocalRef(jstr);
  }

  return result;
}

// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
Expand Down
20 changes: 20 additions & 0 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ class SherpaOnnxOnlineRecongitionResult {
return String(cString: result.pointee.text)
}

/// The `count` field of the underlying C result.
var count: Int32 {
  let recognizerResult = result.pointee
  return recognizerResult.count
}

/// The tokens of the underlying C result as Swift strings.
///
/// Reads `count` entries from `tokens_arr`; entries that are nil are skipped.
/// Returns an empty array when `tokens_arr` itself is nil.
var tokens: [String] {
  guard let tokenPointers = result.pointee.tokens_arr else {
    return []
  }

  var collected: [String] = []
  for offset in 0..<Int(count) {
    guard let cString = tokenPointers[offset] else {
      continue
    }
    collected.append(String(cString: cString))
  }
  return collected
}

/// Wraps a pointer to a C `SherpaOnnxOnlineRecognizerResult`.
///
/// Only the pointer is stored; nothing is copied here.
/// NOTE(review): who frees the pointed-to result is not visible in this
/// hunk — confirm against the rest of the class.
///
/// - Parameter result: Pointer to the C result that the properties read from.
init(result: UnsafePointer<SherpaOnnxOnlineRecognizerResult>!) {
self.result = result
}
Expand Down

0 comments on commit 64efbd8

Please sign in to comment.