Skip to content

Commit

Permalink
Kotlin API for speaker diarization (#1415)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 11, 2024
1 parent eefc172 commit 2d412b1
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 1 deletion.
1 change: 1 addition & 0 deletions kotlin-api-examples/OfflineSpeakerDiarization.kt
31 changes: 31 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,37 @@ function testPunctuation() {
java -Djava.library.path=../build/lib -jar $out_filename
}

function testOfflineSpeakerDiarization() {
if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
fi

if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi

if [ ! -f ./0-four-speakers-zh.wav ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
fi

out_filename=test_offline_speaker_diarization.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_offline_speaker_diarization.kt \
OfflineSpeakerDiarization.kt \
Speaker.kt \
OnlineStream.kt \
WaveReader.kt \
faked-asset-manager.kt \
faked-log.kt

ls -lh $out_filename

java -Djava.library.path=../build/lib -jar $out_filename
}

testOfflineSpeakerDiarization
testSpeakerEmbeddingExtractor
testOnlineAsr
testTts
Expand Down
53 changes: 53 additions & 0 deletions kotlin-api-examples/test_offline_speaker_diarization.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.k2fsa.sherpa.onnx

fun main() {
testOfflineSpeakerDiarization()
}

fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int {
val progress = numProcessedChunks.toFloat() / numTotalChunks * 100
val s = "%.2f".format(progress)
println("Progress: ${s}%");

return 0
}

fun testOfflineSpeakerDiarization() {
var config = OfflineSpeakerDiarizationConfig(
segmentation=OfflineSpeakerSegmentationModelConfig(
pyannote=OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"),
),
embedding=SpeakerEmbeddingExtractorConfig(
model="./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
),

// The test wave file ./0-four-speakers-zh.wav contains four speakers, so
// we use numClusters=4 here. If you don't know the number of speakers
// in the test wave file, please set the threshold like below.
//
// clustering=FastClusteringConfig(threshold=0.5),
//
// WARNING: You need to tune threshold by yourself.
// A larger threshold leads to fewer clusters, i.e., few speakers.
// A smaller threshold leads to more clusters, i.e., more speakers.
//
clustering=FastClusteringConfig(numClusters=4),
)

val sd = OfflineSpeakerDiarization(config=config)

val waveData = WaveReader.readWave(
filename = "./0-four-speakers-zh.wav",
)

if (sd.sampleRate() != waveData.sampleRate) {
println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}")
return
}

// val segments = sd.process(waveData.samples) // this one is also ok
val segments = sd.processWithCallback(waveData.samples, callback=::callback)
for (segment in segments) {
println("${segment.start} -- ${segment.end} speaker_${segment.speaker}")
}
}
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-speaker-diarization-result.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult {
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
const;

public:
private:
std::vector<OfflineSpeakerDiarizationSegment> segments_;
};

Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS)
)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
offline-speaker-diarization.cc
)
endif()

add_library(sherpa-onnx-jni ${sources})

target_compile_definitions(sherpa-onnx-jni PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1)
Expand Down
219 changes: 219 additions & 0 deletions sherpa-onnx/jni/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// sherpa-onnx/jni/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"

namespace sherpa_onnx {

static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig(
JNIEnv *env, jobject config) {
OfflineSpeakerDiarizationConfig ans;

jclass cls = env->GetObjectClass(config);
jfieldID fid;

//---------- segmentation ----------
fid = env->GetFieldID(
cls, "segmentation",
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;");
jobject segmentation_config = env->GetObjectField(config, fid);
jclass segmentation_config_cls = env->GetObjectClass(segmentation_config);

fid = env->GetFieldID(
segmentation_config_cls, "pyannote",
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;");
jobject pyannote_config = env->GetObjectField(segmentation_config, fid);
jclass pyannote_config_cls = env->GetObjectClass(pyannote_config);

fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(pyannote_config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.segmentation.pyannote.model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I");
ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid);

fid = env->GetFieldID(segmentation_config_cls, "debug", "Z");
ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid);

fid = env->GetFieldID(segmentation_config_cls, "provider",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(segmentation_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.segmentation.provider = p;
env->ReleaseStringUTFChars(s, p);

//---------- embedding ----------
fid = env->GetFieldID(
cls, "embedding",
"Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;");
jobject embedding_config = env->GetObjectField(config, fid);
jclass embedding_config_cls = env->GetObjectClass(embedding_config);

fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(embedding_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.embedding.model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(embedding_config_cls, "numThreads", "I");
ans.embedding.num_threads = env->GetIntField(embedding_config, fid);

fid = env->GetFieldID(embedding_config_cls, "debug", "Z");
ans.embedding.debug = env->GetBooleanField(embedding_config, fid);

fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(embedding_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.embedding.provider = p;
env->ReleaseStringUTFChars(s, p);

//---------- clustering ----------
fid = env->GetFieldID(cls, "clustering",
"Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;");
jobject clustering_config = env->GetObjectField(config, fid);
jclass clustering_config_cls = env->GetObjectClass(clustering_config);

fid = env->GetFieldID(clustering_config_cls, "numClusters", "I");
ans.clustering.num_clusters = env->GetIntField(clustering_config, fid);

fid = env->GetFieldID(clustering_config_cls, "threshold", "F");
ans.clustering.threshold = env->GetFloatField(clustering_config, fid);

// its own fields
fid = env->GetFieldID(cls, "minDurationOn", "F");
ans.min_duration_on = env->GetFloatField(config, fid);

fid = env->GetFieldID(cls, "minDurationOff", "F");
ans.min_duration_off = env->GetFloatField(config, fid);

return ans;
}

} // namespace sherpa_onnx

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
return 0;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}

auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config);

return (jlong)sd;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig(
JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
sd->SetConfig(config);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
}

static jobjectArray ProcessImpl(
JNIEnv *env,
const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment>
&segments) {
jclass cls =
env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment");

jobjectArray obj_arr =
(jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr);

jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V");

for (int32_t i = 0; i != segments.size(); ++i) {
const auto &s = segments[i];
jobject segment =
env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker());
env->SetObjectArrayElement(obj_arr, i, segment);
}

return obj_arr;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto segments = sd->Process(p, n).SortByStartTime();
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

return ProcessImpl(env, segments);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jobject callback, jlong arg) {
std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
[env, callback](int32_t num_processed_chunks, int32_t num_total_chunks,
void *data) -> int {
jclass cls = env->GetObjectClass(callback);

jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;");
if (mid == nullptr) {
SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
return 0;
}

jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks,
num_total_chunks, (jlong)data);
jclass jklass = env->GetObjectClass(ret);
jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I");
return env->CallIntMethod(ret, int_value_mid);
};

auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
auto segments =
sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime();
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

return ProcessImpl(env, segments);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate(
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
return reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr)
->SampleRate();
}
Loading

0 comments on commit 2d412b1

Please sign in to comment.