-
Notifications
You must be signed in to change notification settings - Fork 414
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Kotlin API for speaker diarization (#1415)
- Loading branch information
1 parent
eefc172
commit 2d412b1
Showing
7 changed files
with
412 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package com.k2fsa.sherpa.onnx | ||
|
||
fun main() { | ||
testOfflineSpeakerDiarization() | ||
} | ||
|
||
fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int { | ||
val progress = numProcessedChunks.toFloat() / numTotalChunks * 100 | ||
val s = "%.2f".format(progress) | ||
println("Progress: ${s}%"); | ||
|
||
return 0 | ||
} | ||
|
||
fun testOfflineSpeakerDiarization() { | ||
var config = OfflineSpeakerDiarizationConfig( | ||
segmentation=OfflineSpeakerSegmentationModelConfig( | ||
pyannote=OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"), | ||
), | ||
embedding=SpeakerEmbeddingExtractorConfig( | ||
model="./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx", | ||
), | ||
|
||
// The test wave file ./0-four-speakers-zh.wav contains four speakers, so | ||
// we use numClusters=4 here. If you don't know the number of speakers | ||
// in the test wave file, please set the threshold like below. | ||
// | ||
// clustering=FastClusteringConfig(threshold=0.5), | ||
// | ||
// WARNING: You need to tune threshold by yourself. | ||
// A larger threshold leads to fewer clusters, i.e., few speakers. | ||
// A smaller threshold leads to more clusters, i.e., more speakers. | ||
// | ||
clustering=FastClusteringConfig(numClusters=4), | ||
) | ||
|
||
val sd = OfflineSpeakerDiarization(config=config) | ||
|
||
val waveData = WaveReader.readWave( | ||
filename = "./0-four-speakers-zh.wav", | ||
) | ||
|
||
if (sd.sampleRate() != waveData.sampleRate) { | ||
println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}") | ||
return | ||
} | ||
|
||
// val segments = sd.process(waveData.samples) // this one is also ok | ||
val segments = sd.processWithCallback(waveData.samples, callback=::callback) | ||
for (segment in segments) { | ||
println("${segment.start} -- ${segment.end} speaker_${segment.speaker}") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
// sherpa-onnx/jni/offline-speaker-diarization.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h" | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/jni/common.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig( | ||
JNIEnv *env, jobject config) { | ||
OfflineSpeakerDiarizationConfig ans; | ||
|
||
jclass cls = env->GetObjectClass(config); | ||
jfieldID fid; | ||
|
||
//---------- segmentation ---------- | ||
fid = env->GetFieldID( | ||
cls, "segmentation", | ||
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;"); | ||
jobject segmentation_config = env->GetObjectField(config, fid); | ||
jclass segmentation_config_cls = env->GetObjectClass(segmentation_config); | ||
|
||
fid = env->GetFieldID( | ||
segmentation_config_cls, "pyannote", | ||
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;"); | ||
jobject pyannote_config = env->GetObjectField(segmentation_config, fid); | ||
jclass pyannote_config_cls = env->GetObjectClass(pyannote_config); | ||
|
||
fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;"); | ||
jstring s = (jstring)env->GetObjectField(pyannote_config, fid); | ||
const char *p = env->GetStringUTFChars(s, nullptr); | ||
ans.segmentation.pyannote.model = p; | ||
env->ReleaseStringUTFChars(s, p); | ||
|
||
fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I"); | ||
ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid); | ||
|
||
fid = env->GetFieldID(segmentation_config_cls, "debug", "Z"); | ||
ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid); | ||
|
||
fid = env->GetFieldID(segmentation_config_cls, "provider", | ||
"Ljava/lang/String;"); | ||
s = (jstring)env->GetObjectField(segmentation_config, fid); | ||
p = env->GetStringUTFChars(s, nullptr); | ||
ans.segmentation.provider = p; | ||
env->ReleaseStringUTFChars(s, p); | ||
|
||
//---------- embedding ---------- | ||
fid = env->GetFieldID( | ||
cls, "embedding", | ||
"Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;"); | ||
jobject embedding_config = env->GetObjectField(config, fid); | ||
jclass embedding_config_cls = env->GetObjectClass(embedding_config); | ||
|
||
fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;"); | ||
s = (jstring)env->GetObjectField(embedding_config, fid); | ||
p = env->GetStringUTFChars(s, nullptr); | ||
ans.embedding.model = p; | ||
env->ReleaseStringUTFChars(s, p); | ||
|
||
fid = env->GetFieldID(embedding_config_cls, "numThreads", "I"); | ||
ans.embedding.num_threads = env->GetIntField(embedding_config, fid); | ||
|
||
fid = env->GetFieldID(embedding_config_cls, "debug", "Z"); | ||
ans.embedding.debug = env->GetBooleanField(embedding_config, fid); | ||
|
||
fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;"); | ||
s = (jstring)env->GetObjectField(embedding_config, fid); | ||
p = env->GetStringUTFChars(s, nullptr); | ||
ans.embedding.provider = p; | ||
env->ReleaseStringUTFChars(s, p); | ||
|
||
//---------- clustering ---------- | ||
fid = env->GetFieldID(cls, "clustering", | ||
"Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;"); | ||
jobject clustering_config = env->GetObjectField(config, fid); | ||
jclass clustering_config_cls = env->GetObjectClass(clustering_config); | ||
|
||
fid = env->GetFieldID(clustering_config_cls, "numClusters", "I"); | ||
ans.clustering.num_clusters = env->GetIntField(clustering_config, fid); | ||
|
||
fid = env->GetFieldID(clustering_config_cls, "threshold", "F"); | ||
ans.clustering.threshold = env->GetFloatField(clustering_config, fid); | ||
|
||
// its own fields | ||
fid = env->GetFieldID(cls, "minDurationOn", "F"); | ||
ans.min_duration_on = env->GetFloatField(config, fid); | ||
|
||
fid = env->GetFieldID(cls, "minDurationOff", "F"); | ||
ans.min_duration_off = env->GetFloatField(config, fid); | ||
|
||
return ans; | ||
} | ||
|
||
} // namespace sherpa_onnx | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT jlong JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset( | ||
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { | ||
return 0; | ||
} | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT jlong JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile( | ||
JNIEnv *env, jobject /*obj*/, jobject _config) { | ||
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); | ||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
|
||
if (!config.Validate()) { | ||
SHERPA_ONNX_LOGE("Errors found in config!"); | ||
return 0; | ||
} | ||
|
||
auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config); | ||
|
||
return (jlong)sd; | ||
} | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT void JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig( | ||
JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) { | ||
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); | ||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
|
||
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr); | ||
sd->SetConfig(config); | ||
} | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT void JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/, | ||
jobject /*obj*/, | ||
jlong ptr) { | ||
delete reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr); | ||
} | ||
|
||
static jobjectArray ProcessImpl( | ||
JNIEnv *env, | ||
const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment> | ||
&segments) { | ||
jclass cls = | ||
env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment"); | ||
|
||
jobjectArray obj_arr = | ||
(jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr); | ||
|
||
jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V"); | ||
|
||
for (int32_t i = 0; i != segments.size(); ++i) { | ||
const auto &s = segments[i]; | ||
jobject segment = | ||
env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker()); | ||
env->SetObjectArrayElement(obj_arr, i, segment); | ||
} | ||
|
||
return obj_arr; | ||
} | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT jobjectArray JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process( | ||
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) { | ||
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr); | ||
|
||
jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
jsize n = env->GetArrayLength(samples); | ||
auto segments = sd->Process(p, n).SortByStartTime(); | ||
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
|
||
return ProcessImpl(env, segments); | ||
} | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT jobjectArray JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback( | ||
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, | ||
jobject callback, jlong arg) { | ||
std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper = | ||
[env, callback](int32_t num_processed_chunks, int32_t num_total_chunks, | ||
void *data) -> int { | ||
jclass cls = env->GetObjectClass(callback); | ||
|
||
jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;"); | ||
if (mid == nullptr) { | ||
SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it."); | ||
return 0; | ||
} | ||
|
||
jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks, | ||
num_total_chunks, (jlong)data); | ||
jclass jklass = env->GetObjectClass(ret); | ||
jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I"); | ||
return env->CallIntMethod(ret, int_value_mid); | ||
}; | ||
|
||
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr); | ||
|
||
jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
jsize n = env->GetArrayLength(samples); | ||
auto segments = | ||
sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime(); | ||
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
|
||
return ProcessImpl(env, segments); | ||
} | ||
|
||
SHERPA_ONNX_EXTERN_C | ||
JNIEXPORT jint JNICALL | ||
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate( | ||
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { | ||
return reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr) | ||
->SampleRate(); | ||
} |
Oops, something went wrong.