Skip to content

Commit

Permalink
update kotlin api for better release native object and add user-frien…
Browse files Browse the repository at this point in the history
…dly apis. (#1275)
  • Loading branch information
fbzhong authored Aug 22, 2024
1 parent 5a2aa11 commit d8001d6
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 15 deletions.
7 changes: 5 additions & 2 deletions sherpa-onnx/kotlin-api/KeywordSpotter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class KeywordSpotter(
assetManager: AssetManager? = null,
val config: KeywordSpotterConfig,
) {
private val ptr: Long
private var ptr: Long

init {
ptr = if (assetManager != null) {
Expand All @@ -35,7 +35,10 @@ class KeywordSpotter(
}

/**
 * Deletes the native object referenced by [ptr], if one is still held.
 *
 * The handle is passed to `delete` at most once: after the call [ptr] is
 * reset to 0, so any later invocation (for example [release] followed by
 * finalization) returns without touching the native side again.
 */
protected fun finalize() {
    if (ptr == 0L) return
    delete(ptr)
    ptr = 0
}

/** Frees the native object immediately by delegating to [finalize]. */
fun release() = finalize()
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/kotlin-api/OfflinePunctuation.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class OfflinePunctuation(
assetManager: AssetManager? = null,
config: OfflinePunctuationConfig,
) {
private val ptr: Long
private var ptr: Long

init {
ptr = if (assetManager != null) {
Expand All @@ -29,7 +29,10 @@ class OfflinePunctuation(
}

/**
 * Deletes the native object referenced by [ptr], if one is still held.
 *
 * The handle is passed to `delete` at most once: after the call [ptr] is
 * reset to 0, so any later invocation (for example [release] followed by
 * finalization) returns without touching the native side again.
 */
protected fun finalize() {
    if (ptr == 0L) return
    delete(ptr)
    ptr = 0
}

/** Frees the native object immediately by delegating to [finalize]. */
fun release() = finalize()
Expand Down
16 changes: 13 additions & 3 deletions sherpa-onnx/kotlin-api/OfflineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class OfflineRecognizer(
assetManager: AssetManager? = null,
config: OfflineRecognizerConfig,
) {
private val ptr: Long
private var ptr: Long

init {
ptr = if (assetManager != null) {
Expand All @@ -83,7 +83,10 @@ class OfflineRecognizer(
}

/**
 * Deletes the native object referenced by [ptr], if one is still held.
 *
 * The handle is passed to `delete` at most once: after the call [ptr] is
 * reset to 0, so any later invocation (for example [release] followed by
 * finalization) returns without touching the native side again.
 */
protected fun finalize() {
    if (ptr == 0L) return
    delete(ptr)
    ptr = 0
}

/** Frees the native object immediately by delegating to [finalize]. */
fun release() = finalize()
Expand All @@ -102,7 +105,14 @@ class OfflineRecognizer(
val lang = objArray[3] as String
val emotion = objArray[4] as String
val event = objArray[5] as String
return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps, lang = lang, emotion = emotion, event = event)
return OfflineRecognizerResult(
text = text,
tokens = tokens,
timestamps = timestamps,
lang = lang,
emotion = emotion,
event = event
)
}

fun decode(stream: OfflineStream) = decode(ptr, stream.ptr)
Expand Down
8 changes: 8 additions & 0 deletions sherpa-onnx/kotlin-api/OfflineStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ class OfflineStream(var ptr: Long) {

/** Releases this stream right away by delegating to [finalize]. */
fun release() = finalize()

/**
 * Runs [block] with this stream, then calls [release] whether [block]
 * returns normally or throws.
 *
 * @param block work to perform with this stream.
 * @return the value produced by [block]; callers passing a `Unit` lambda
 *   get `Unit`, exactly as before.
 */
fun <T> use(block: (OfflineStream) -> T): T {
    try {
        return block(this)
    } finally {
        // Runs on both the normal and the exceptional path.
        release()
    }
}

private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun delete(ptr: Long)

Expand Down
9 changes: 6 additions & 3 deletions sherpa-onnx/kotlin-api/OnlineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineModelConfig,
var lmConfig: OnlineLMConfig = OnlineLMConfig(),
var ctcFstDecoderConfig : OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(),
var ctcFstDecoderConfig: OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(),
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
Expand All @@ -85,7 +85,7 @@ class OnlineRecognizer(
assetManager: AssetManager? = null,
val config: OnlineRecognizerConfig,
) {
private val ptr: Long
private var ptr: Long

init {
ptr = if (assetManager != null) {
Expand All @@ -96,7 +96,10 @@ class OnlineRecognizer(
}

/**
 * Deletes the native object referenced by [ptr], if one is still held.
 *
 * The handle is passed to `delete` at most once: after the call [ptr] is
 * reset to 0, so any later invocation (for example [release] followed by
 * finalization) returns without touching the native side again.
 */
protected fun finalize() {
    if (ptr == 0L) return
    delete(ptr)
    ptr = 0
}

/** Frees the native object immediately by delegating to [finalize]. */
fun release() = finalize()
Expand Down
9 changes: 9 additions & 0 deletions sherpa-onnx/kotlin-api/OnlineStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@ class OnlineStream(var ptr: Long = 0) {

/** Releases this stream right away by delegating to [finalize]. */
fun release() = finalize()

/**
 * Runs [block] with this stream, then calls [release] whether [block]
 * returns normally or throws.
 *
 * @param block work to perform with this stream.
 * @return the value produced by [block]; callers passing a `Unit` lambda
 *   get `Unit`, exactly as before.
 */
fun <T> use(block: (OnlineStream) -> T): T {
    try {
        return block(this)
    } finally {
        // Runs on both the normal and the exceptional path.
        release()
    }
}

private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun delete(ptr: Long)


companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
Expand Down
18 changes: 13 additions & 5 deletions sherpa-onnx/kotlin-api/Vad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ data class VadModelConfig(
var debug: Boolean = false,
)

/**
 * Holder for one speech segment: a start value plus its float samples.
 *
 * NOTE(review): [start] is presumably a sample offset into the processed
 * audio — confirm against the native VAD implementation.
 *
 * @property start start value of the segment.
 * @property samples audio samples belonging to the segment.
 */
class SpeechSegment(
    val start: Int,
    val samples: FloatArray,
)

class Vad(
assetManager: AssetManager? = null,
var config: VadModelConfig,
) {
private val ptr: Long
private var ptr: Long

init {
if (assetManager != null) {
Expand All @@ -34,17 +36,23 @@ class Vad(
}

/**
 * Deletes the native object referenced by [ptr], if one is still held.
 *
 * The handle is passed to `delete` at most once: after the call [ptr] is
 * reset to 0, so any later invocation (for example [release] followed by
 * finalization) returns without touching the native side again.
 */
protected fun finalize() {
    if (ptr == 0L) return
    delete(ptr)
    ptr = 0
}

/** Frees the native object immediately by delegating to [finalize]. */
fun release() = finalize()

fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)

fun empty(): Boolean = empty(ptr)
fun pop() = pop(ptr)

/**
 * Returns the front speech segment as a typed [SpeechSegment].
 *
 * The native `front(ptr)` call yields a two-entry array,
 * `[start: Int, samples: FloatArray]`; the entries are cast and wrapped here
 * so callers do not have to index into an untyped array.
 *
 * NOTE(review): what the native side returns when nothing is available is
 * not visible here — callers should presumably check [empty] first; confirm.
 *
 * @throws ClassCastException if an entry does not have the expected type.
 */
fun front(): SpeechSegment {
    val entries = front(ptr)
    return SpeechSegment(
        start = entries[0] as Int,
        samples = entries[1] as FloatArray,
    )
}

fun clear() = clear(ptr)

Expand Down
41 changes: 41 additions & 0 deletions sherpa-onnx/kotlin-api/WaveReader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,49 @@ package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

/**
 * Audio samples together with their sample rate.
 *
 * `equals` and `hashCode` are overridden so that [samples] is compared and
 * hashed by content rather than by array identity.
 *
 * @property samples the audio samples.
 * @property sampleRate the sample rate associated with [samples].
 */
data class WaveData(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /** Two instances are equal when they have the same class, sample rate and sample contents. */
    override fun equals(other: Any?): Boolean = when {
        this === other -> true
        other == null || javaClass != other.javaClass -> false
        else -> {
            val that = other as WaveData
            samples.contentEquals(that.samples) && sampleRate == that.sampleRate
        }
    }

    /** Content-based hash consistent with [equals]. */
    override fun hashCode(): Int = 31 * samples.contentHashCode() + sampleRate
}

class WaveReader {
companion object {

/**
 * Reads a wave via `readWaveFromAsset` and wraps the result in [WaveData].
 *
 * The array returned by `readWaveFromAsset` is unpacked positionally:
 * entry 0 is cast to the float samples and entry 1 to the sample rate.
 *
 * @param assetManager asset manager forwarded to `readWaveFromAsset`.
 * @param filename name forwarded to `readWaveFromAsset`.
 * @throws ClassCastException if an entry does not have the expected type.
 */
fun readWave(assetManager: AssetManager, filename: String): WaveData {
    val raw = readWaveFromAsset(assetManager, filename)
    val samples = raw[0] as FloatArray
    val sampleRate = raw[1] as Int
    return WaveData(samples, sampleRate)
}

/**
 * Reads a wave via `readWaveFromFile` and wraps the result in [WaveData].
 *
 * The array returned by `readWaveFromFile` is unpacked positionally:
 * entry 0 is cast to the float samples and entry 1 to the sample rate.
 *
 * @param filename name forwarded to `readWaveFromFile`.
 * @throws ClassCastException if an entry does not have the expected type.
 */
fun readWave(filename: String): WaveData {
    val raw = readWaveFromFile(filename)
    val samples = raw[0] as FloatArray
    val sampleRate = raw[1] as Int
    return WaveData(samples, sampleRate)
}

// Read a mono wave file asset
// The returned array has two entries:
// - the first entry contains a 1-D float array
Expand Down

0 comments on commit d8001d6

Please sign in to comment.