diff --git a/.github/workflows/apk-asr.yaml b/.github/workflows/apk-asr.yaml
new file mode 100644
index 000000000..3fdb2baac
--- /dev/null
+++ b/.github/workflows/apk-asr.yaml
@@ -0,0 +1,174 @@
+name: apk-asr
+
+on:
+  push:
+    tags:
+      - '*'
+
+  workflow_dispatch:
+
+concurrency:
+  group: apk-asr-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: write
+
+jobs:
+  apk_asr:
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+    runs-on: ${{ matrix.os }}
+    name: apk for asr ${{ matrix.index }}/${{ matrix.total }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        total: ["1"]
+        index: ["0"]
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      # https://github.com/actions/setup-java
+      - uses: actions/setup-java@v4
+        with:
+          distribution: 'temurin' # See 'Supported distributions' for available options
+          java-version: '21'
+
+      - name: ccache
+        uses: hendrikmuhs/ccache-action@v1.2
+        with:
+          key: ${{ matrix.os }}-android
+
+      - name: Display NDK HOME
+        shell: bash
+        run: |
+          echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
+          ls -lh ${ANDROID_NDK_LATEST_HOME}
+
+      - name: Install Python dependencies
+        shell: bash
+        run: |
+          python3 -m pip install --upgrade pip jinja2
+
+      - name: Setup build tool version variable
+        shell: bash
+        run: |
+          echo "---"
+          ls -lh /usr/local/lib/android/
+          echo "---"
+
+          ls -lh /usr/local/lib/android/sdk
+          echo "---"
+
+          ls -lh /usr/local/lib/android/sdk/build-tools
+          echo "---"
+
+          BUILD_TOOL_VERSION=$(ls /usr/local/lib/android/sdk/build-tools/ | tail -n 1)
+          echo "BUILD_TOOL_VERSION=$BUILD_TOOL_VERSION" >> $GITHUB_ENV
+          echo "Last build tool version is: $BUILD_TOOL_VERSION"
+
+      - name: Generate build script
+        shell: bash
+        run: |
+          cd scripts/apk
+
+          total=${{ matrix.total }}
+          index=${{ matrix.index }}
+
+          ./generate-asr-apk-script.py --total $total --index $index
+
+          chmod +x build-apk-asr.sh
+          mv -v ./build-apk-asr.sh ../..
+
+      - name: build APK
+        shell: bash
+        run: |
+          export CMAKE_CXX_COMPILER_LAUNCHER=ccache
+          export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
+          cmake --version
+
+          export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
+          ./build-apk-asr.sh
+
+      - name: Display APK
+        shell: bash
+        run: |
+          ls -lh ./apks/
+          du -h -d1 .
+
+      # https://github.com/marketplace/actions/sign-android-release
+      - uses: r0adkll/sign-android-release@v1
+        name: Sign app APK
+        with:
+          releaseDirectory: ./apks
+          signingKeyBase64: ${{ secrets.ANDROID_SIGNING_KEY }}
+          alias: ${{ secrets.ANDROID_SIGNING_KEY_ALIAS }}
+          keyStorePassword: ${{ secrets.ANDROID_SIGNING_KEY_STORE_PASSWORD }}
+        env:
+          BUILD_TOOLS_VERSION: ${{ env.BUILD_TOOL_VERSION }}
+
+      - name: Display APK after signing
+        shell: bash
+        run: |
+          ls -lh ./apks/
+          du -h -d1 .
+
+      - name: Rename APK after signing
+        shell: bash
+        run: |
+          cd apks
+          rm -fv signingKey.jks
+          rm -fv *.apk.idsig
+          rm -fv *-aligned.apk
+
+          all_apks=$(ls -1 *-signed.apk)
+          echo "----"
+          echo $all_apks
+          echo "----"
+          for apk in ${all_apks[@]}; do
+            n=$(echo $apk | sed -e s/-signed//)
+            mv -v $apk $n
+          done
+
+          cd ..
+
+          ls -lh ./apks/
+          du -h -d1 .
+
+      - name: Display APK after rename
+        shell: bash
+        run: |
+          ls -lh ./apks/
+          du -h -d1 .
+
+      - name: Publish to huggingface
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            rm -rf huggingface
+            export GIT_LFS_SKIP_SMUDGE=1
+
+            git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
+            cd huggingface
+            git fetch
+            git pull
+            git merge -m "merge remote" --ff origin main
+
+            mkdir -p asr
+            cp -v ../apks/*.apk ./asr/
+            git status
+            git lfs track "*.apk"
+            git add .
+            git commit -m "add more apks"
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-apk main
diff --git a/.gitignore b/.gitignore
index 48114405e..83ca941d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -95,3 +95,4 @@ sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
 spoken-language-identification-test-wavs
 my-release-key*
 vits-zh-hf-fanchen-C
+sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
diff --git a/android/SherpaOnnx/app/src/main/AndroidManifest.xml b/android/SherpaOnnx/app/src/main/AndroidManifest.xml
index 935fb0e95..c0c79ddd3 100644
--- a/android/SherpaOnnx/app/src/main/AndroidManifest.xml
+++ b/android/SherpaOnnx/app/src/main/AndroidManifest.xml
@@ -16,6 +16,7 @@
         tools:targetApi="31">
diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
new file mode 120000
index 000000000..952fae878
--- /dev/null
+++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
index 1619f3b27..e4eb5e276 100644
--- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
+++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
@@ -12,16 +12,19 @@ import android.widget.Button
 import android.widget.TextView
 import androidx.appcompat.app.AppCompatActivity
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.*
 import kotlin.concurrent.thread
 
 private const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 
+// To enable the microphone in the Android emulator, use
+//
+//   adb emu avd hostmicon
+
 class MainActivity : AppCompatActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
 
-    private lateinit var model: SherpaOnnx
+    private lateinit var recognizer: OnlineRecognizer
     private var audioRecord: AudioRecord? = null
     private lateinit var recordButton: Button
     private lateinit var textView: TextView
@@ -87,7 +90,6 @@ class MainActivity : AppCompatActivity() {
         audioRecord!!.startRecording()
         recordButton.setText(R.string.stop)
         isRecording = true
-        model.reset(true)
         textView.text = ""
         lastText = ""
         idx = 0
@@ -108,6 +110,7 @@ class MainActivity : AppCompatActivity() {
 
     private fun processSamples() {
         Log.i(TAG, "processing samples")
+        val stream = recognizer.createStream()
 
         val interval = 0.1 // i.e., 100 ms
         val bufferSize = (interval * sampleRateInHz).toInt() // in samples
@@ -117,29 +120,41 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.acceptWaveform(samples, sampleRate=sampleRateInHz)
-                while (model.isReady()) {
-                    model.decode()
+                stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (recognizer.isReady(stream)) {
+                    recognizer.decode(stream)
                 }
-                val isEndpoint = model.isEndpoint()
-                val text = model.text
+                val isEndpoint = recognizer.isEndpoint(stream)
+                var text = recognizer.getResult(stream).text
+
+                // For streaming paraformer, we need to manually add some
+                // padding so that it has enough right context to
+                // recognize the last word of this segment
+                if (isEndpoint && recognizer.config.modelConfig.paraformer.encoder.isNotBlank()) {
+                    val tailPaddings = FloatArray((0.8 * sampleRateInHz).toInt())
+                    stream.acceptWaveform(tailPaddings, sampleRate = sampleRateInHz)
+                    while (recognizer.isReady(stream)) {
+                        recognizer.decode(stream)
+                    }
+                    text = recognizer.getResult(stream).text
+                }
 
-                var textToDisplay = lastText;
+                var textToDisplay = lastText
 
-                if(text.isNotBlank()) {
-                    if (lastText.isBlank()) {
-                        textToDisplay = "${idx}: ${text}"
+                if (text.isNotBlank()) {
+                    textToDisplay = if (lastText.isBlank()) {
+                        "${idx}: $text"
                     } else {
-                        textToDisplay = "${lastText}\n${idx}: ${text}"
+                        "${lastText}\n${idx}: $text"
                     }
                 }
 
                 if (isEndpoint) {
-                    model.reset()
+                    recognizer.reset(stream)
                     if (text.isNotBlank()) {
-                        lastText = "${lastText}\n${idx}: ${text}"
-                        textToDisplay = lastText;
+                        lastText = "${lastText}\n${idx}: $text"
+                        textToDisplay = lastText
                         idx += 1
                     }
                 }
@@ -149,6 +164,7 @@ class MainActivity : AppCompatActivity() {
             }
         }
+        stream.release()
     }
 
     private fun initMicrophone(): Boolean {
@@ -180,7 +196,7 @@ class MainActivity : AppCompatActivity() {
         // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
         // for a list of available models
        val type = 0
-        println("Select model type ${type}")
+        Log.i(TAG, "Select model type $type")
        val config = OnlineRecognizerConfig(
            featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
            modelConfig = getModelConfig(type = type)!!,
@@ -189,7 +205,7 @@ class MainActivity : AppCompatActivity() {
            enableEndpoint = true,
        )
 
-        model = SherpaOnnx(
+        recognizer = OnlineRecognizer(
            assetManager = application.assets,
            config = config,
        )
diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt
new file mode 120000
index 000000000..5bb19ee10
--- /dev/null
+++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
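A note on the API change above: the app-local `SherpaOnnx` wrapper is replaced by the shared `OnlineRecognizer`, and all per-utterance state moves into an explicit `OnlineStream`. Below is a minimal sketch of the new call sequence, using only the methods that appear in this diff; `chunks` stands in for the buffers the app reads from `AudioRecord`.

```kotlin
// Minimal sketch of the stream-based API introduced above; `recognizer`
// is built from an OnlineRecognizerConfig as in initModel().
fun decodeChunks(recognizer: OnlineRecognizer, chunks: List<FloatArray>, sampleRateInHz: Int) {
    val stream = recognizer.createStream()   // per-utterance decoding state
    for (samples in chunks) {
        stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
        while (recognizer.isReady(stream)) { // drain buffered feature frames
            recognizer.decode(stream)
        }
        if (recognizer.isEndpoint(stream)) { // end of a segment
            println(recognizer.getResult(stream).text)
            recognizer.reset(stream)         // reuse the stream for the next segment
        }
    }
    stream.release()                         // frees native memory; not handled by GC
}
```

Note also the paraformer special case in the hunk above: at an endpoint the app feeds 0.8 s of zero padding so the model has enough right context to finish the last word of the segment.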
diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt
new file mode 120000
index 000000000..d4518b89b
--- /dev/null
+++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
deleted file mode 100644
index dca399840..000000000
--- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright (c) 2023 Xiaomi Corporation
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-
-class WaveReader {
-    companion object {
-        // Read a mono wave file asset
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromAsset(
-            assetManager: AssetManager,
-            filename: String,
-        ): Array<Any>
-
-        // Read a mono wave file from disk
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromFile(
-            filename: String,
-        ): Array<Any>
-
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
new file mode 120000
index 000000000..05c8fb246
--- /dev/null
+++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/WaveReader.kt
\ No newline at end of file
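`WaveReader` likewise becomes a symlink into `sherpa-onnx/kotlin-api`. Assuming the shared copy keeps the signatures of the deleted file, decoding a bundled wav looks roughly like the sketch below; `test.wav` is a placeholder asset name, not a file shipped with the app.

```kotlin
// Sketch only: WaveReader.readWaveFromAsset returns a two-entry Array<Any>,
// entry 0 being the FloatArray of samples and entry 1 the sample rate.
fun decodeWavAsset(recognizer: OnlineRecognizer, assets: AssetManager): String {
    val data = WaveReader.readWaveFromAsset(assets, "test.wav") // placeholder name
    val samples = data[0] as FloatArray
    val sampleRate = data[1] as Int

    val stream = recognizer.createStream()
    stream.acceptWaveform(samples, sampleRate = sampleRate)
    while (recognizer.isReady(stream)) {
        recognizer.decode(stream)
    }
    val text = recognizer.getResult(stream).text
    stream.release()
    return text
}
```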
diff --git a/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml b/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml
index 2a440df14..0cbbfafe8 100644
--- a/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml
+++ b/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml
@@ -16,6 +16,7 @@
         tools:targetApi="31">
@@ -29,4 +30,4 @@
-
\ No newline at end of file
+
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
new file mode 120000
index 000000000..952fae878
--- /dev/null
+++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
index 012c0db5e..596d03e09 100644
--- a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
+++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
@@ -17,11 +17,13 @@ import kotlin.concurrent.thread
 
 private const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 
+// adb emu avd hostmicon
+// to enable the microphone inside the emulator
 class MainActivity : AppCompatActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
 
-    private lateinit var onlineRecognizer: SherpaOnnx
-    private lateinit var offlineRecognizer: SherpaOnnxOffline
+    private lateinit var onlineRecognizer: OnlineRecognizer
+    private lateinit var offlineRecognizer: OfflineRecognizer
     private var audioRecord: AudioRecord? = null
     private lateinit var recordButton: Button
     private lateinit var textView: TextView
@@ -93,7 +95,6 @@ class MainActivity : AppCompatActivity() {
         audioRecord!!.startRecording()
         recordButton.setText(R.string.stop)
         isRecording = true
-        onlineRecognizer.reset(true)
         samplesBuffer.clear()
         textView.text = ""
         lastText = ""
@@ -115,6 +116,7 @@ class MainActivity : AppCompatActivity() {
 
     private fun processSamples() {
         Log.i(TAG, "processing samples")
+        val stream = onlineRecognizer.createStream()
 
         val interval = 0.1 // i.e., 100 ms
         val bufferSize = (interval * sampleRateInHz).toInt() // in samples
@@ -126,29 +128,29 @@ class MainActivity : AppCompatActivity() {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
                 samplesBuffer.add(samples)
 
-                onlineRecognizer.acceptWaveform(samples, sampleRate = sampleRateInHz)
-                while (onlineRecognizer.isReady()) {
-                    onlineRecognizer.decode()
+                stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (onlineRecognizer.isReady(stream)) {
+                    onlineRecognizer.decode(stream)
                 }
 
-                val isEndpoint = onlineRecognizer.isEndpoint()
+                val isEndpoint = onlineRecognizer.isEndpoint(stream)
                 var textToDisplay = lastText
 
-                var text = onlineRecognizer.text
+                var text = onlineRecognizer.getResult(stream).text
                 if (text.isNotBlank()) {
-                    if (lastText.isBlank()) {
+                    textToDisplay = if (lastText.isBlank()) {
                         // textView.text = "${idx}: ${text}"
-                        textToDisplay = "${idx}: ${text}"
+                        "${idx}: $text"
                     } else {
-                        textToDisplay = "${lastText}\n${idx}: ${text}"
+                        "${lastText}\n${idx}: $text"
                     }
                 }
 
                 if (isEndpoint) {
-                    onlineRecognizer.reset()
+                    onlineRecognizer.reset(stream)
                     if (text.isNotBlank()) {
                         text = runSecondPass()
-                        lastText = "${lastText}\n${idx}: ${text}"
+                        lastText = "${lastText}\n${idx}: $text"
                         idx += 1
                     } else {
                         samplesBuffer.clear()
@@ -160,6 +162,7 @@ class MainActivity : AppCompatActivity() {
             }
         }
+        stream.release()
     }
 
     private fun initMicrophone(): Boolean {
@@ -190,8 +193,8 @@ class MainActivity : AppCompatActivity() {
         // Please change getModelConfig() to add new models
         // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
         // for a list of available models
-        val firstType = 1
-        println("Select model type ${firstType} for the first pass")
+        val firstType = 9
+        Log.i(TAG, "Select model type $firstType for the first pass")
         val config = OnlineRecognizerConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
             modelConfig = getModelConfig(type = firstType)!!,
@@ -199,7 +202,7 @@ class MainActivity : AppCompatActivity() {
             enableEndpoint = true,
         )
 
-        onlineRecognizer = SherpaOnnx(
+        onlineRecognizer = OnlineRecognizer(
            assetManager = application.assets,
            config = config,
         )
@@ -209,15 +212,15 @@ class MainActivity : AppCompatActivity() {
         // Please change getOfflineModelConfig() to add new models
         // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
         // for a list of available models
-        val secondType = 1
-        println("Select model type ${secondType} for the second pass")
+        val secondType = 0
+        Log.i(TAG, "Select model type $secondType for the second pass")
 
         val config = OfflineRecognizerConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
             modelConfig = getOfflineModelConfig(type = secondType)!!,
         )
 
-        offlineRecognizer = SherpaOnnxOffline(
+        offlineRecognizer = OfflineRecognizer(
             assetManager = application.assets,
             config = config,
         )
@@ -244,8 +247,15 @@ class MainActivity : AppCompatActivity() {
         val n = maxOf(0, samples.size - 8000)
 
         samplesBuffer.clear()
-        samplesBuffer.add(samples.sliceArray(n..samples.size-1))
+        samplesBuffer.add(samples.sliceArray(n until samples.size))
 
-        return offlineRecognizer.decode(samples.sliceArray(0..n), sampleRateInHz)
+        val stream = offlineRecognizer.createStream()
+        stream.acceptWaveform(samples.sliceArray(0..n), sampleRateInHz)
+        offlineRecognizer.decode(stream)
+        val result = offlineRecognizer.getResult(stream)
+
+        stream.release()
+
+        return result.text
     }
 }
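The two-pass flow above is worth spelling out: the streaming recognizer only segments the audio and gives a rough transcript, and at each endpoint the buffered samples are re-decoded with the more accurate offline model. Note that `runSecondPass()` keeps the last 8000 samples (0.5 s at 16 kHz) in `samplesBuffer` so the next segment starts with some left context. A sketch of the second pass in isolation:

```kotlin
// Sketch of the offline second pass used by runSecondPass() above.
fun secondPass(
    offlineRecognizer: OfflineRecognizer,
    segment: FloatArray,               // samples of one endpointed segment
    sampleRateInHz: Int,
): String {
    val stream = offlineRecognizer.createStream()
    stream.acceptWaveform(segment, sampleRateInHz)
    offlineRecognizer.decode(stream)   // offline decoding runs once over the whole segment
    val text = offlineRecognizer.getResult(stream).text
    stream.release()                   // OfflineStream also wraps native memory
    return text
}
```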
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt
new file mode 120000
index 000000000..faa3ab4ac
--- /dev/null
+++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt
new file mode 120000
index 000000000..2a3aff864
--- /dev/null
+++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt
new file mode 120000
index 000000000..5bb19ee10
--- /dev/null
+++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt
new file mode 120000
index 000000000..d4518b89b
--- /dev/null
+++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
deleted file mode 100644
index 601ecf83f..000000000
--- a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
+++ /dev/null
@@ -1,404 +0,0 @@
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-
-data class EndpointRule(
-    var mustContainNonSilence: Boolean,
-    var minTrailingSilence: Float,
-    var minUtteranceLength: Float,
-)
-
-data class EndpointConfig(
-    var rule1: EndpointRule = EndpointRule(false, 2.0f, 0.0f),
-    var rule2: EndpointRule = EndpointRule(true, 1.2f, 0.0f),
-    var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
-)
-
-data class OnlineTransducerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var joiner: String = "",
-)
-
-data class OnlineParaformerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-)
-
-data class OnlineZipformer2CtcModelConfig(
-    var model: String = "",
-)
-
-data class OnlineModelConfig(
-    var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
-    var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
-    var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
-    var tokens: String,
-    var numThreads: Int = 1,
-    var debug: Boolean = false,
-    var provider: String = "cpu",
-    var modelType: String = "",
-)
-
-data class OnlineLMConfig(
-    var model: String = "",
-    var scale: Float = 0.5f,
-)
-
-data class FeatureConfig(
-    var sampleRate: Int = 16000,
-    var featureDim: Int = 80,
-)
-
-data class OnlineRecognizerConfig(
-    var featConfig: FeatureConfig = FeatureConfig(),
-    var modelConfig: OnlineModelConfig,
-    var lmConfig: OnlineLMConfig = OnlineLMConfig(),
-    var endpointConfig: EndpointConfig = EndpointConfig(),
-    var enableEndpoint: Boolean = true,
-    var decodingMethod: String = "greedy_search",
-    var maxActivePaths: Int = 4,
-    var hotwordsFile: String = "",
-    var hotwordsScore: Float = 1.5f,
-)
-
-data class OfflineTransducerModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var joiner: String = "",
-)
-
-data class OfflineParaformerModelConfig(
-    var model: String = "",
-)
-
-data class OfflineWhisperModelConfig(
-    var encoder: String = "",
-    var decoder: String = "",
-    var language: String = "en", // Used with multilingual model
-    var task: String = "transcribe", // transcribe or translate
-    var tailPaddings: Int = 1000, // Padding added at the end of the samples
-)
-
-data class OfflineModelConfig(
-    var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
-    var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
-    var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
-    var numThreads: Int = 1,
-    var debug: Boolean = false,
-    var provider: String = "cpu",
-    var modelType: String = "",
-    var tokens: String,
-)
-
-data class OfflineRecognizerConfig(
-    var featConfig: FeatureConfig = FeatureConfig(),
-    var modelConfig: OfflineModelConfig,
-    // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
-    var decodingMethod: String = "greedy_search",
-    var maxActivePaths: Int = 4,
-    var hotwordsFile: String = "",
-    var hotwordsScore: Float = 1.5f,
-)
-
-class SherpaOnnx(
-    assetManager: AssetManager? = null,
-    var config: OnlineRecognizerConfig,
-) {
-    private val ptr: Long
-
-    init {
-        if (assetManager != null) {
-            ptr = new(assetManager, config)
-        } else {
-            ptr = newFromFile(config)
-        }
-    }
-
-    protected fun finalize() {
-        delete(ptr)
-    }
-
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
-
-    fun inputFinished() = inputFinished(ptr)
-    fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)
-    fun decode() = decode(ptr)
-    fun isEndpoint(): Boolean = isEndpoint(ptr)
-    fun isReady(): Boolean = isReady(ptr)
-
-    val text: String
-        get() = getText(ptr)
-
-    val tokens: Array<String>
-        get() = getTokens(ptr)
-
-    private external fun delete(ptr: Long)
-
-    private external fun new(
-        assetManager: AssetManager,
-        config: OnlineRecognizerConfig,
-    ): Long
-
-    private external fun newFromFile(
-        config: OnlineRecognizerConfig,
-    ): Long
-
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-    private external fun inputFinished(ptr: Long)
-    private external fun getText(ptr: Long): String
-    private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)
-    private external fun decode(ptr: Long)
-    private external fun isEndpoint(ptr: Long): Boolean
-    private external fun isReady(ptr: Long): Boolean
-    private external fun getTokens(ptr: Long): Array<String>
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-class SherpaOnnxOffline(
-    assetManager: AssetManager? = null,
-    var config: OfflineRecognizerConfig,
-) {
-    private val ptr: Long
-
-    init {
-        if (assetManager != null) {
-            ptr = new(assetManager, config)
-        } else {
-            ptr = newFromFile(config)
-        }
-    }
-
-    protected fun finalize() {
-        delete(ptr)
-    }
-
-    fun decode(samples: FloatArray, sampleRate: Int) = decode(ptr, samples, sampleRate)
-
-    private external fun delete(ptr: Long)
-
-    private external fun new(
-        assetManager: AssetManager,
-        config: OfflineRecognizerConfig,
-    ): Long
-
-    private external fun newFromFile(
-        config: OfflineRecognizerConfig,
-    ): Long
-
-    private external fun decode(ptr: Long, samples: FloatArray, sampleRate: Int): String
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
-    return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
-}
-
-/*
-Please see
-https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
-for a list of pre-trained models.
-
-We only add a few here. Please change the following code
-to add your own. (It should be straightforward to add a new model
-by following the code)
-
-@param type
-0 - csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23 (Chinese)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-zh-14m-2023-02-23
-    encoder/joiner int8, decoder float32
-
-1 - csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 (English)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-en-20m-2023-02-17-english
-    encoder/joiner int8, decoder fp32
-
- */
-fun getModelConfig(type: Int): OnlineModelConfig? {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23"
-            return OnlineModelConfig(
-                transducer = OnlineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
-                    joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-
-        1 -> {
-            val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
-            return OnlineModelConfig(
-                transducer = OnlineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
-                    joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-    }
-    return null
-}
-
-/*
-Please see
-https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
-for a list of pre-trained models.
-
-We only add a few here. Please change the following code
-to add your own LM model. (It should be straightforward to train a new NN LM model
-by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
-
-@param type
-0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
- */
-fun getOnlineLMConfig(type: Int): OnlineLMConfig {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
-            return OnlineLMConfig(
-                model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
-                scale = 0.5f,
-            )
-        }
-    }
-    return OnlineLMConfig()
-}
-
-// for English models, use a small value for rule2.minTrailingSilence, e.g., 0.8
-fun getEndpointConfig(): EndpointConfig {
-    return EndpointConfig(
-        rule1 = EndpointRule(false, 2.4f, 0.0f),
-        rule2 = EndpointRule(true, 0.8f, 0.0f),
-        rule3 = EndpointRule(false, 0.0f, 20.0f)
-    )
-}
-
-/*
-Please see
-https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
-for a list of pre-trained models.
-
-We only add a few here. Please change the following code
-to add your own. (It should be straightforward to add a new model
-by following the code)
-
-@param type
-
-0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
-    int8
-
-1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english
-    encoder int8, decoder/joiner float32
-
-2 - sherpa-onnx-whisper-tiny.en
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
-    encoder int8, decoder int8
-
-3 - sherpa-onnx-whisper-base.en
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
-    encoder int8, decoder int8
-
-4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
-    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
-    encoder/joiner int8, decoder fp32
-
- */
-fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
-            return OfflineModelConfig(
-                paraformer = OfflineParaformerModelConfig(
-                    model = "$modelDir/model.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "paraformer",
-            )
-        }
-
-        1 -> {
-            val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
-            return OfflineModelConfig(
-                transducer = OfflineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-30-avg-4.onnx",
-                    joiner = "$modelDir/joiner-epoch-30-avg-4.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-
-        2 -> {
-            val modelDir = "sherpa-onnx-whisper-tiny.en"
-            return OfflineModelConfig(
-                whisper = OfflineWhisperModelConfig(
-                    encoder = "$modelDir/tiny.en-encoder.int8.onnx",
-                    decoder = "$modelDir/tiny.en-decoder.int8.onnx",
-                ),
-                tokens = "$modelDir/tiny.en-tokens.txt",
-                modelType = "whisper",
-            )
-        }
-
-        3 -> {
-            val modelDir = "sherpa-onnx-whisper-base.en"
-            return OfflineModelConfig(
-                whisper = OfflineWhisperModelConfig(
-                    encoder = "$modelDir/base.en-encoder.int8.onnx",
-                    decoder = "$modelDir/base.en-decoder.int8.onnx",
-                ),
-                tokens = "$modelDir/base.en-tokens.txt",
-                modelType = "whisper",
-            )
-        }
-
-
-        4 -> {
-            val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
-            return OfflineModelConfig(
-                transducer = OfflineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
-                    joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer",
-            )
-        }
-
-        5 -> {
-            val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2"
-            return OfflineModelConfig(
-                transducer = OfflineTransducerModelConfig(
-                    encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx",
-                    decoder = "$modelDir/decoder-epoch-20-avg-1.onnx",
-                    joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx",
-                ),
-                tokens = "$modelDir/tokens.txt",
-                modelType = "zipformer2",
-            )
-        }
-
-    }
-    return null
-}
diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
deleted file mode 100644
index 3060450d6..000000000
--- a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
+++ /dev/null
@@ -1,28 +0,0 @@
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-
-class WaveReader {
-    companion object {
-        // Read a mono wave file asset
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromAsset(
-            assetManager: AssetManager,
-            filename: String,
-        ): Array<Any>
-
-        // Read a mono wave file from disk
-        // The returned array has two entries:
-        //  - the first entry contains an 1-D float array
-        //  - the second entry is the sample rate
-        external fun readWaveFromFile(
-            filename: String,
-        ): Array<Any>
-
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
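With the app-local `SherpaOnnx.kt` deleted, its `EndpointRule`/`EndpointConfig` data classes and the `get*Config` helpers now come from the shared `sherpa-onnx/kotlin-api` copy. Assuming they keep the shape shown above, a custom endpoint setup (the old comment recommends a smaller `rule2.minTrailingSilence`, e.g. 0.8, for English models) would look like the following sketch:

```kotlin
// Sketch, assuming the shared kotlin-api keeps these data classes as deleted above.
val englishEndpointConfig = EndpointConfig(
    rule1 = EndpointRule(false, 2.4f, 0.0f),  // long silence, speech not required
    rule2 = EndpointRule(true, 0.8f, 0.0f),   // short trailing silence after speech
    rule3 = EndpointRule(false, 0.0f, 20.0f), // hard cap on utterance length
)

val config = OnlineRecognizerConfig(
    featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
    modelConfig = getModelConfig(type = 1)!!, // type 1: the English streaming zipformer
    endpointConfig = englishEndpointConfig,
    enableEndpoint = true,
)
```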
com.k2fsa.sherpa.onnx - -import android.content.res.AssetManager - -const val TAG = "sherpa-onnx" - -data class OfflineZipformerAudioTaggingModelConfig( - var model: String = "", -) - -data class AudioTaggingModelConfig( - var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(), - var ced: String = "", - var numThreads: Int = 1, - var debug: Boolean = false, - var provider: String = "cpu", -) - -data class AudioTaggingConfig( - var model: AudioTaggingModelConfig, - var labels: String, - var topK: Int = 5, -) - -data class AudioEvent( - val name: String, - val index: Int, - val prob: Float, -) - -class AudioTagging( - assetManager: AssetManager? = null, - config: AudioTaggingConfig, -) { - private var ptr: Long - - init { - ptr = if (assetManager != null) { - newFromAsset(assetManager, config) - } else { - newFromFile(config) - } - } - - protected fun finalize() { - if (ptr != 0L) { - delete(ptr) - ptr = 0 - } - } - - fun release() = finalize() - - fun createStream(): OfflineStream { - val p = createStream(ptr) - return OfflineStream(p) - } - - @Suppress("UNCHECKED_CAST") - fun compute(stream: OfflineStream, topK: Int = -1): ArrayList { - val events: Array = compute(ptr, stream.ptr, topK) - val ans = ArrayList() - - for (e in events) { - val p: Array = e as Array - ans.add( - AudioEvent( - name = p[0] as String, - index = p[1] as Int, - prob = p[2] as Float, - ) - ) - } - - return ans - } - - private external fun newFromAsset( - assetManager: AssetManager, - config: AudioTaggingConfig, - ): Long - - private external fun newFromFile( - config: AudioTaggingConfig, - ): Long - - private external fun delete(ptr: Long) - - private external fun createStream(ptr: Long): Long - - private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array - - companion object { - init { - System.loadLibrary("sherpa-onnx-jni") - } - } -} - -// please refer to -// https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models -// to download more models -// -// See also -// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/ -fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? 
-fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
-    when (type) {
-        0 -> {
-            val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
-            return AudioTaggingConfig(
-                model = AudioTaggingModelConfig(
-                    zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
-                    numThreads = numThreads,
-                    debug = true,
-                ),
-                labels = "$modelDir/class_labels_indices.csv",
-                topK = 3,
-            )
-        }
-
-        1 -> {
-            val modelDir = "sherpa-onnx-zipformer-audio-tagging-2024-04-09"
-            return AudioTaggingConfig(
-                model = AudioTaggingModelConfig(
-                    zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
-                    numThreads = numThreads,
-                    debug = true,
-                ),
-                labels = "$modelDir/class_labels_indices.csv",
-                topK = 3,
-            )
-        }
-
-        2 -> {
-            val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
-            return AudioTaggingConfig(
-                model = AudioTaggingModelConfig(
-                    ced = "$modelDir/model.int8.onnx",
-                    numThreads = numThreads,
-                    debug = true,
-                ),
-                labels = "$modelDir/class_labels_indices.csv",
-                topK = 3,
-            )
-        }
-
-        3 -> {
-            val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
-            return AudioTaggingConfig(
-                model = AudioTaggingModelConfig(
-                    ced = "$modelDir/model.int8.onnx",
-                    numThreads = numThreads,
-                    debug = true,
-                ),
-                labels = "$modelDir/class_labels_indices.csv",
-                topK = 3,
-            )
-        }
-
-        4 -> {
-            val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
-            return AudioTaggingConfig(
-                model = AudioTaggingModelConfig(
-                    ced = "$modelDir/model.int8.onnx",
-                    numThreads = numThreads,
-                    debug = true,
-                ),
-                labels = "$modelDir/class_labels_indices.csv",
-                topK = 3,
-            )
-        }
-
-        5 -> {
-            val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
-            return AudioTaggingConfig(
-                model = AudioTaggingModelConfig(
-                    ced = "$modelDir/model.int8.onnx",
-                    numThreads = numThreads,
-                    debug = true,
-                ),
-                labels = "$modelDir/class_labels_indices.csv",
-                topK = 3,
-            )
-        }
-    }
-
-    return null
-}
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
new file mode 120000
index 000000000..176a8df8d
--- /dev/null
+++ b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
@@ -0,0 +1 @@
+../../../../../../../../../../../../sherpa-onnx/kotlin-api/AudioTagging.kt
\ No newline at end of file
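The audio-tagging API that replaces the deleted file keeps the same shape, it is just shared via the symlink above. A sketch of tagging one buffer of 16 kHz mono samples with it (`samples` is a placeholder for real audio):

```kotlin
// Sketch of the AudioTagging flow; all calls appear in the moved file above.
fun tagSamples(assets: AssetManager, samples: FloatArray): List<AudioEvent> {
    val config = getAudioTaggingConfig(type = 0, numThreads = 1)!!
    val tagger = AudioTagging(assetManager = assets, config = config)

    val stream = tagger.createStream()
    stream.acceptWaveform(samples, sampleRate = 16000)
    val events = tagger.compute(stream, topK = 3) // AudioEvent(name, index, prob)

    stream.release()
    tagger.release()   // both wrap native pointers
    return events
}
```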
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Home.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Home.kt
index b2239cee2..a1edc2554 100644
--- a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Home.kt
+++ b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Home.kt
@@ -46,7 +46,6 @@ import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.sp
 import androidx.core.app.ActivityCompat
 import com.k2fsa.sherpa.onnx.AudioEvent
-import com.k2fsa.sherpa.onnx.Tagger
 import kotlin.concurrent.thread
 
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/MainActivity.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/MainActivity.kt
index cb45f1005..c338a930e 100644
--- a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/MainActivity.kt
+++ b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/MainActivity.kt
@@ -13,13 +13,14 @@ import androidx.compose.material3.Surface
 import androidx.compose.runtime.Composable
 import androidx.compose.ui.Modifier
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.Tagger
 import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
 
 const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 
+// adb emu avd hostmicon
+// to enable the mic inside the emulator
 class MainActivity : ComponentActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
 
     override fun onCreate(savedInstanceState: Bundle?) {
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
deleted file mode 100644
index 49652e72d..000000000
--- a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.k2fsa.sherpa.onnx
-
-class OfflineStream(var ptr: Long) {
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
-
-    protected fun finalize() {
-        if (ptr != 0L) {
-            delete(ptr)
-            ptr = 0
-        }
-    }
-
-    fun release() = finalize()
-
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-    private external fun delete(ptr: Long)
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
new file mode 120000
index 000000000..f3faa5b76
--- /dev/null
+++ b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt
@@ -0,0 +1 @@
+../../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt
\ No newline at end of file
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Tagger.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Tagger.kt
index c714094aa..811c9e74f 100644
--- a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Tagger.kt
+++ b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/Tagger.kt
@@ -1,7 +1,9 @@
-package com.k2fsa.sherpa.onnx
+package com.k2fsa.sherpa.onnx.audio.tagging
 
 import android.content.res.AssetManager
 import android.util.Log
+import com.k2fsa.sherpa.onnx.AudioTagging
+import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
 
 object Tagger {
 
@@ -17,7 +19,7 @@ object Tagger {
             return
         }
 
-        Log.i(TAG, "Initializing audio tagger")
+        Log.i("sherpa-onnx", "Initializing audio tagger")
         val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
         _tagger = AudioTagging(assetManager, config)
     }
diff --git a/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/HomeScreen.kt b/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/HomeScreen.kt
index a4f1ba88d..9af2c571e 100644
--- a/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/HomeScreen.kt
+++ b/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/HomeScreen.kt
@@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
 import androidx.wear.compose.material.MaterialTheme
 import androidx.wear.compose.material.Text
 import com.k2fsa.sherpa.onnx.AudioEvent
-import com.k2fsa.sherpa.onnx.Tagger
+import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
 import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
 import kotlin.concurrent.thread
 
diff --git a/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/MainActivity.kt b/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/MainActivity.kt
index fd8b1f719..59a004fdc 100644
--- a/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/MainActivity.kt
+++ b/android/SherpaOnnxAudioTaggingWearOs/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/wear/os/presentation/MainActivity.kt
@@ -17,11 +17,14 @@ import androidx.activity.compose.setContent
 import androidx.compose.runtime.Composable
 import androidx.core.app.ActivityCompat
 import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
-import com.k2fsa.sherpa.onnx.Tagger
+import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
 
 const val TAG = "sherpa-onnx"
 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 
+// adb emu avd hostmicon
+// to enable the mic inside the emulator
+
 class MainActivity : ComponentActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
 
     override fun onCreate(savedInstanceState: Bundle?) {
diff --git a/android/SherpaOnnxKws/app/src/main/AndroidManifest.xml b/android/SherpaOnnxKws/app/src/main/AndroidManifest.xml
index 935fb0e95..d575b6b90 100644
--- a/android/SherpaOnnxKws/app/src/main/AndroidManifest.xml
+++ b/android/SherpaOnnxKws/app/src/main/AndroidManifest.xml
@@ -15,7 +15,8 @@
         android:theme="@style/Theme.SherpaOnnx"
         tools:targetApi="31">
diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
new file mode 120000
index 000000000..952fae878
--- /dev/null
+++ b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt
\ No newline at end of file
diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/KeywordSpotter.kt b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/KeywordSpotter.kt
new file mode 120000
index 000000000..4392376a1
--- /dev/null
+++ b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/KeywordSpotter.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/KeywordSpotter.kt
\ No newline at end of file
diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
index 83c8abe31..b17a6ea6c 100644
--- a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
+++ b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
@@ -1,4 +1,4 @@
-package com.k2fsa.sherpa.onnx
+package com.k2fsa.sherpa.onnx.kws
 
 import android.Manifest
 import android.content.pm.PackageManager
@@ -14,7 +14,13 @@ import android.widget.TextView
 import android.widget.Toast
 import androidx.appcompat.app.AppCompatActivity
 import androidx.core.app.ActivityCompat
-import com.k2fsa.sherpa.onnx.*
+import com.k2fsa.sherpa.onnx.KeywordSpotter
+import com.k2fsa.sherpa.onnx.KeywordSpotterConfig
+import com.k2fsa.sherpa.onnx.OnlineStream
+import com.k2fsa.sherpa.onnx.R
+import com.k2fsa.sherpa.onnx.getFeatureConfig
+import com.k2fsa.sherpa.onnx.getKeywordsFile
+import com.k2fsa.sherpa.onnx.getKwsModelConfig
 import kotlin.concurrent.thread
 
 private const val TAG = "sherpa-onnx"
@@ -23,7 +29,8 @@ private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
 
 class MainActivity : AppCompatActivity() {
     private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
-    private lateinit var model: SherpaOnnxKws
+    private lateinit var kws: KeywordSpotter
+    private lateinit var stream: OnlineStream
     private var audioRecord: AudioRecord? = null
     private lateinit var recordButton: Button
     private lateinit var textView: TextView
@@ -87,15 +94,18 @@ class MainActivity : AppCompatActivity() {
         Log.i(TAG, keywords)
 
         keywords = keywords.replace("\n", "/")
+        keywords = keywords.trim()
 
         // If keywords is an empty string, it just resets the decoding stream
         // always returns true in this case.
         // If keywords is not empty, it will create a new decoding stream with
         // the given keywords appended to the default keywords.
-        // Return false if errors occured when adding keywords, true otherwise.
-        val status = model.reset(keywords)
-        if (!status) {
-            Log.i(TAG, "Failed to reset with keywords.")
-            Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show();
+        // Return false if errors occurred when adding keywords, true otherwise.
+        stream.release()
+        stream = kws.createStream(keywords)
+        if (stream.ptr == 0L) {
+            Log.i(TAG, "Failed to create stream with keywords: $keywords")
+            Toast.makeText(this, "Failed to set keywords to $keywords.", Toast.LENGTH_LONG)
+                .show()
             return
         }
@@ -122,6 +132,7 @@ class MainActivity : AppCompatActivity() {
             audioRecord!!.release()
             audioRecord = null
             recordButton.setText(R.string.start)
+            stream.release()
             Log.i(TAG, "Stopped recording")
         }
     }
@@ -137,22 +148,22 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.acceptWaveform(samples, sampleRate=sampleRateInHz)
-                while (model.isReady()) {
-                    model.decode()
+                stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
+                while (kws.isReady(stream)) {
+                    kws.decode(stream)
                 }
 
-                val text = model.keyword
+                val text = kws.getResult(stream).keyword
 
-                var textToDisplay = lastText;
+                var textToDisplay = lastText
 
-                if(text.isNotBlank()) {
+                if (text.isNotBlank()) {
                     if (lastText.isBlank()) {
-                        textToDisplay = "${idx}: ${text}"
+                        textToDisplay = "$idx: $text"
                     } else {
-                        textToDisplay = "${idx}: ${text}\n${lastText}"
+                        textToDisplay = "$idx: $text\n$lastText"
                     }
-                    lastText = "${idx}: ${text}\n${lastText}"
+                    lastText = "$idx: $text\n$lastText"
                     idx += 1
                 }
 
@@ -188,20 +199,21 @@ class MainActivity : AppCompatActivity() {
     }
 
     private fun initModel() {
-        // Please change getModelConfig() to add new models
+        // Please change getKwsModelConfig() to add new models
         // See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
         // for a list of available models
         val type = 0
-        Log.i(TAG, "Select model type ${type}")
+        Log.i(TAG, "Select model type $type")
         val config = KeywordSpotterConfig(
             featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
-            modelConfig = getModelConfig(type = type)!!,
-            keywordsFile = getKeywordsFile(type = type)!!,
+            modelConfig = getKwsModelConfig(type = type)!!,
+            keywordsFile = getKeywordsFile(type = type),
         )
 
-        model = SherpaOnnxKws(
+        kws = KeywordSpotter(
             assetManager = application.assets,
             config = config,
         )
+        stream = kws.createStream()
     }
-}
+}
\ No newline at end of file
diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt
new file mode 120000
index 000000000..5bb19ee10
--- /dev/null
+++ b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineRecognizer.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineRecognizer.kt
\ No newline at end of file
diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt
new file mode 120000
index 000000000..d4518b89b
--- /dev/null
+++ b/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/OnlineStream.kt
@@ -0,0 +1 @@
+../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
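To summarize the keyword-spotting changes: custom keywords now go into `KeywordSpotter.createStream(keywords)` instead of the old `model.reset(keywords)`, with the lines of the text box joined by `/` as the activity does above. A sketch of the flow, using the two example keywords that the strings.xml change below ships:

```kotlin
// Sketch of the new KeywordSpotter API. An empty keywords string falls back
// to the keywords file configured in KeywordSpotterConfig.
fun spotKeywords(kws: KeywordSpotter, chunks: List<FloatArray>, sampleRateInHz: Int) {
    val keywords = "n ǐ h ǎo @你好/d àn g ē d àn g ē @蛋哥蛋哥" // "tokens @word", '/'-joined
    val stream = kws.createStream(keywords)
    for (samples in chunks) {
        stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
        while (kws.isReady(stream)) {
            kws.decode(stream)
        }
        val keyword = kws.getResult(stream).keyword
        if (keyword.isNotBlank()) {
            println(keyword)   // non-blank once per detection
        }
    }
    stream.release()
}
```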
diff --git a/android/SherpaOnnxKws/app/src/main/res/values/strings.xml b/android/SherpaOnnxKws/app/src/main/res/values/strings.xml
index 1fba032f9..484977db0 100644
--- a/android/SherpaOnnxKws/app/src/main/res/values/strings.xml
+++ b/android/SherpaOnnxKws/app/src/main/res/values/strings.xml
@@ -1,12 +1,12 @@
-    KWS with Next-gen Kaldi
+    Keyword spotting
     Click the Start button to play keyword spotting with Next-gen Kaldi.\n \n\n\n The source code and pre-trained models are publicly available. Please see https://github.com/k2-fsa/sherpa-onnx for details.
-    Input your keywords here, one keyword perline.
+    Input your keywords here, one keyword per line.\nTwo example keywords are given below:\n\nn ǐ h ǎo @你好\nd àn g ē d àn g ē @蛋哥蛋哥
     Start
     Stop
diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/BarItem.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/BarItem.kt
index 7c3a56dda..620f4f0c5 100644
--- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/BarItem.kt
+++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/BarItem.kt
@@ -2,7 +2,7 @@ package com.k2fsa.sherpa.onnx.speaker.identification
 
 import androidx.compose.ui.graphics.vector.ImageVector
 
-data class BarItem (
+data class BarItem(
     val title: String,
 
     // see https://www.composables.com/icons
diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/NavRoutes.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/NavRoutes.kt
index 118396645..e00abc95a 100644
--- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/NavRoutes.kt
+++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/NavRoutes.kt
@@ -1,8 +1,8 @@ package com.k2fsa.sherpa.onnx.speaker.identification
 
 sealed class NavRoutes(val route: String) {
-    object Home: NavRoutes("home")
-    object Register: NavRoutes("register")
-    object View: NavRoutes("view")
-    object Help: NavRoutes("help")
+    object Home : NavRoutes("home")
+    object Register : NavRoutes("register")
+    object View : NavRoutes("view")
+    object Help : NavRoutes("help")
 }
\ No newline at end of file
diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/OnlineStream.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/OnlineStream.kt
new file mode 120000
index 000000000..3211155f6
--- /dev/null
+++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/OnlineStream.kt
@@ -0,0 +1 @@
+../../../../../../../../../../../../sherpa-onnx/kotlin-api/OnlineStream.kt
\ No newline at end of file
diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt
deleted file mode 100644
index 4c9bd06fa..000000000
--- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt
+++ /dev/null
@@ -1,188 +0,0 @@
-package com.k2fsa.sherpa.onnx
-
-import android.content.res.AssetManager
-import android.util.Log
-
-private val TAG = "sherpa-onnx"
-
-data class SpeakerEmbeddingExtractorConfig(
-    val model: String,
-    var numThreads: Int = 1,
-    var debug: Boolean = false,
-    var provider: String = "cpu",
-)
-
-class SpeakerEmbeddingExtractorStream(var ptr: Long) {
-    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
-        acceptWaveform(ptr, samples, sampleRate)
-
-    fun inputFinished() = inputFinished(ptr)
-
-    protected fun finalize() {
-        delete(ptr)
-        ptr = 0
-    }
-
-    private external fun myTest(ptr: Long, v: Array)
-
-    fun release() = finalize()
-    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
-
-    private external fun inputFinished(ptr: Long)
-
-    private external fun delete(ptr: Long)
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-class SpeakerEmbeddingExtractor(
-    assetManager: AssetManager? = null,
-    config: SpeakerEmbeddingExtractorConfig,
-) {
-    private var ptr: Long
-
-    init {
-        ptr = if (assetManager != null) {
-            new(assetManager, config)
-        } else {
-            newFromFile(config)
-        }
-    }
-
-    protected fun finalize() {
-        delete(ptr)
-        ptr = 0
-    }
-
-    fun release() = finalize()
-
-    fun createStream(): SpeakerEmbeddingExtractorStream {
-        val p = createStream(ptr)
-        return SpeakerEmbeddingExtractorStream(p)
-    }
-
-    fun isReady(stream: SpeakerEmbeddingExtractorStream) = isReady(ptr, stream.ptr)
-    fun compute(stream: SpeakerEmbeddingExtractorStream) = compute(ptr, stream.ptr)
-    fun dim() = dim(ptr)
-
-    private external fun new(
-        assetManager: AssetManager,
-        config: SpeakerEmbeddingExtractorConfig,
-    ): Long
-
-    private external fun newFromFile(
-        config: SpeakerEmbeddingExtractorConfig,
-    ): Long
-
-    private external fun delete(ptr: Long)
-
-    private external fun createStream(ptr: Long): Long
-
-    private external fun isReady(ptr: Long, streamPtr: Long): Boolean
-
-    private external fun compute(ptr: Long, streamPtr: Long): FloatArray
-
-    private external fun dim(ptr: Long): Int
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-class SpeakerEmbeddingManager(val dim: Int) {
-    private var ptr: Long
-
-    init {
-        ptr = new(dim)
-    }
-
-    protected fun finalize() {
-        delete(ptr)
-        ptr = 0
-    }
-
-    fun release() = finalize()
-    fun add(name: String, embedding: FloatArray) = add(ptr, name, embedding)
-    fun add(name: String, embedding: Array<FloatArray>) = addList(ptr, name, embedding)
-    fun remove(name: String) = remove(ptr, name)
-    fun search(embedding: FloatArray, threshold: Float) = search(ptr, embedding, threshold)
-    fun verify(name: String, embedding: FloatArray, threshold: Float) =
-        verify(ptr, name, embedding, threshold)
-
-    fun contains(name: String) = contains(ptr, name)
-    fun numSpeakers() = numSpeakers(ptr)
-
-    fun allSpeakerNames() = allSpeakerNames(ptr)
-
-    private external fun new(dim: Int): Long
-    private external fun delete(ptr: Long): Unit
-    private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean
-    private external fun addList(ptr: Long, name: String, embedding: Array<FloatArray>): Boolean
-    private external fun remove(ptr: Long, name: String): Boolean
-    private external fun search(ptr: Long, embedding: FloatArray, threshold: Float): String
-    private external fun verify(
-        ptr: Long,
-        name: String,
-        embedding: FloatArray,
-        threshold: Float
-    ): Boolean
-
-    private external fun contains(ptr: Long, name: String): Boolean
-    private external fun numSpeakers(ptr: Long): Int
-
-    private external fun allSpeakerNames(ptr: Long): Array<String>
-
-    companion object {
-        init {
-            System.loadLibrary("sherpa-onnx-jni")
-        }
-    }
-}
-
-// Please download the model file from
-// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
-// and put it inside the assets directory.
-//
-// Please don't put it in a subdirectory of assets
-private val modelName = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
-
-object SpeakerRecognition {
= null - var _manager: SpeakerEmbeddingManager? = null - - val extractor: SpeakerEmbeddingExtractor - get() { - return _extractor!! - } - - val manager: SpeakerEmbeddingManager - get() { - return _manager!! - } - - fun initExtractor(assetManager: AssetManager? = null) { - synchronized(this) { - if (_extractor != null) { - return - } - Log.i(TAG, "Initializing speaker embedding extractor") - - _extractor = SpeakerEmbeddingExtractor( - assetManager = assetManager, - config = SpeakerEmbeddingExtractorConfig( - model = modelName, - numThreads = 2, - debug = false, - provider = "cpu", - ) - ) - - _manager = SpeakerEmbeddingManager(dim = _extractor!!.dim()) - } - } -} diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt new file mode 120000 index 000000000..b7307bc21 --- /dev/null +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt @@ -0,0 +1 @@ +../../../../../../../../../../../../sherpa-onnx/kotlin-api/Speaker.kt \ No newline at end of file diff --git a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/Home.kt b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/Home.kt index 018e39134..5a994e9a3 100644 --- a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/Home.kt +++ b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/Home.kt @@ -1,4 +1,4 @@ -@file:OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class) +@file:OptIn(ExperimentalMaterial3Api::class) package com.k2fsa.sherpa.onnx.slid @@ -9,11 +9,9 @@ import android.media.AudioFormat import android.media.AudioRecord import android.media.MediaRecorder import android.util.Log -import androidx.compose.foundation.ExperimentalFoundationApi import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.PaddingValues -import androidx.compose.ui.Modifier import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.height @@ -31,6 +29,7 @@ import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp @@ -63,13 +62,13 @@ fun Home() { } private var audioRecord: AudioRecord? 
= null -private val sampleRateInHz = 16000 +private const val sampleRateInHz = 16000 @Composable fun MyApp(padding: PaddingValues) { val activity = LocalContext.current as Activity var isStarted by remember { mutableStateOf(false) } - var result by remember { mutableStateOf("") } + var result by remember { mutableStateOf("") } val onButtonClick: () -> Unit = { isStarted = !isStarted @@ -114,12 +113,12 @@ fun MyApp(padding: PaddingValues) { } Log.i(TAG, "Stop recording") Log.i(TAG, "Start recognition") - val samples = Flatten(sampleList) + val samples = flatten(sampleList) val stream = Slid.slid.createStream() stream.acceptWaveform(samples, sampleRateInHz) val lang = Slid.slid.compute(stream) - result = Slid.localeMap.get(lang) ?: lang + result = Slid.localeMap[lang] ?: lang stream.release() } @@ -152,7 +151,7 @@ fun MyApp(padding: PaddingValues) { } } -fun Flatten(sampleList: ArrayList): FloatArray { +fun flatten(sampleList: ArrayList): FloatArray { var totalSamples = 0 for (a in sampleList) { totalSamples += a.size diff --git a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/MainActivity.kt b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/MainActivity.kt index dfbcba160..705f431ee 100644 --- a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/MainActivity.kt +++ b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/MainActivity.kt @@ -10,12 +10,9 @@ import androidx.activity.compose.setContent import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Surface -import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier -import androidx.compose.ui.tooling.preview.Preview import androidx.core.app.ActivityCompat -import com.k2fsa.sherpa.onnx.SpokenLanguageIdentification import com.k2fsa.sherpa.onnx.slid.ui.theme.SherpaOnnxSpokenLanguageIdentificationTheme const val TAG = "sherpa-onnx" @@ -32,6 +29,7 @@ class MainActivity : ComponentActivity() { ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION) Slid.initSlid(this.assets) } + @Suppress("DEPRECATION") @Deprecated("Deprecated in Java") override fun onRequestPermissionsResult( diff --git a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/OfflineStream.kt b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/OfflineStream.kt index 1a5dfc316..c8c06085c 120000 --- a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/OfflineStream.kt +++ b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/OfflineStream.kt @@ -1 +1 @@ -../../../../../../../../../../SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt \ No newline at end of file +../../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt \ No newline at end of file diff --git a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt deleted file mode 100644 index fedf9d65b..000000000 --- 
a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt +++ /dev/null @@ -1,102 +0,0 @@ -package com.k2fsa.sherpa.onnx - -import android.content.res.AssetManager -import android.util.Log - -private val TAG = "sherpa-onnx" - -data class SpokenLanguageIdentificationWhisperConfig ( - var encoder: String, - var decoder: String, - var tailPaddings: Int = -1, -) - -data class SpokenLanguageIdentificationConfig ( - var whisper: SpokenLanguageIdentificationWhisperConfig, - var numThreads: Int = 1, - var debug: Boolean = false, - var provider: String = "cpu", -) - -class SpokenLanguageIdentification ( - assetManager: AssetManager? = null, - config: SpokenLanguageIdentificationConfig, -) { - private var ptr: Long - - init { - ptr = if (assetManager != null) { - newFromAsset(assetManager, config) - } else { - newFromFile(config) - } - } - - protected fun finalize() { - if (ptr != 0L) { - delete(ptr) - ptr = 0 - } - } - - fun release() = finalize() - - fun createStream(): OfflineStream { - val p = createStream(ptr) - return OfflineStream(p) - } - - fun compute(stream: OfflineStream) = compute(ptr, stream.ptr) - - private external fun newFromAsset( - assetManager: AssetManager, - config: SpokenLanguageIdentificationConfig, - ): Long - - private external fun newFromFile( - config: SpokenLanguageIdentificationConfig, - ): Long - - private external fun delete(ptr: Long) - - private external fun createStream(ptr: Long): Long - - private external fun compute(ptr: Long, streamPtr: Long): String - - companion object { - init { - System.loadLibrary("sherpa-onnx-jni") - } - } -} -// please refer to -// https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper -// to download more models -fun getSpokenLanguageIdentificationConfig(type: Int, numThreads: Int=1): SpokenLanguageIdentificationConfig? 
{ - when (type) { - 0 -> { - val modelDir = "sherpa-onnx-whisper-tiny" - return SpokenLanguageIdentificationConfig( - whisper = SpokenLanguageIdentificationWhisperConfig( - encoder = "$modelDir/tiny-encoder.int8.onnx", - decoder = "$modelDir/tiny-decoder.int8.onnx", - ), - numThreads = numThreads, - debug = true, - ) - } - - 1 -> { - val modelDir = "sherpa-onnx-whisper-base" - return SpokenLanguageIdentificationConfig( - whisper = SpokenLanguageIdentificationWhisperConfig( - encoder = "$modelDir/tiny-encoder.int8.onnx", - decoder = "$modelDir/tiny-decoder.int8.onnx", - ), - numThreads = 1, - debug = true, - ) - } - } - return null -} diff --git a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt new file mode 120000 index 000000000..b5cd3eb98 --- /dev/null +++ b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt @@ -0,0 +1 @@ +../../../../../../../../../../../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt \ No newline at end of file diff --git a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/slid.kt b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/slid.kt index 60c511704..ed9439db9 100644 --- a/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/slid.kt +++ b/android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/slid.kt @@ -15,10 +15,10 @@ object Slid { get() { return _slid!! } - val localeMap : Map - get() { - return _localeMap - } + val localeMap: Map + get() { + return _localeMap + } fun initSlid(assetManager: AssetManager? = null, numThreads: Int = 1) { synchronized(this) { @@ -31,7 +31,7 @@ object Slid { } if (_localeMap.isEmpty()) { - val allLang = Locale.getISOLanguages(); + val allLang = Locale.getISOLanguages() for (lang in allLang) { val locale = Locale(lang) _localeMap[lang] = locale.displayName diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index c342b9f61..f44bef8eb 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -1,7 +1,11 @@ package com.k2fsa.sherpa.onnx import android.content.res.AssetManager -import android.media.* +import android.media.AudioAttributes +import android.media.AudioFormat +import android.media.AudioManager +import android.media.AudioTrack +import android.media.MediaPlayer import android.net.Uri import android.os.Bundle import android.util.Log @@ -212,7 +216,7 @@ class MainActivity : AppCompatActivity() { } if (dictDir != null) { - val newDir = copyDataDir( modelDir!!) + val newDir = copyDataDir(modelDir!!) 
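// copyDataDir() copies the assets under modelDir into the app's external
// files directory and returns that directory's absolute path; modelDir and
// dictDir are re-rooted below so the dict and rule-FST files are read from
// the copied on-disk locations.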
modelDir = newDir + "/" + modelDir dictDir = modelDir + "/" + "dict" ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst" @@ -220,7 +224,9 @@ class MainActivity : AppCompatActivity() { } val config = getOfflineTtsConfig( - modelDir = modelDir!!, modelName = modelName!!, lexicon = lexicon ?: "", + modelDir = modelDir!!, + modelName = modelName!!, + lexicon = lexicon ?: "", dataDir = dataDir ?: "", dictDir = dictDir ?: "", ruleFsts = ruleFsts ?: "", @@ -232,11 +238,11 @@ class MainActivity : AppCompatActivity() { private fun copyDataDir(dataDir: String): String { - println("data dir is $dataDir") + Log.i(TAG, "data dir is $dataDir") copyAssets(dataDir) val newDataDir = application.getExternalFilesDir(null)!!.absolutePath - println("newDataDir: $newDataDir") + Log.i(TAG, "newDataDir: $newDataDir") return newDataDir } @@ -256,7 +262,7 @@ class MainActivity : AppCompatActivity() { } } } catch (ex: IOException) { - Log.e(TAG, "Failed to copy $path. ${ex.toString()}") + Log.e(TAG, "Failed to copy $path. $ex") } } @@ -276,7 +282,7 @@ class MainActivity : AppCompatActivity() { ostream.flush() ostream.close() } catch (ex: Exception) { - Log.e(TAG, "Failed to copy $filename, ${ex.toString()}") + Log.e(TAG, "Failed to copy $filename, $ex") } } } diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt index e0f95166c..b25869d07 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt @@ -49,10 +49,10 @@ class OfflineTts( private var ptr: Long init { - if (assetManager != null) { - ptr = newFromAsset(assetManager, config) + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) } else { - ptr = newFromFile(config) + newFromFile(config) } } @@ -65,7 +65,7 @@ class OfflineTts( sid: Int = 0, speed: Float = 1.0f ): GeneratedAudio { - var objArray = generateImpl(ptr, text = text, sid = sid, speed = speed) + val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed) return GeneratedAudio( samples = objArray[0] as FloatArray, sampleRate = objArray[1] as Int @@ -78,7 +78,13 @@ class OfflineTts( speed: Float = 1.0f, callback: (samples: FloatArray) -> Unit ): GeneratedAudio { - var objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed, callback=callback) + val objArray = generateWithCallbackImpl( + ptr, + text = text, + sid = sid, + speed = speed, + callback = callback + ) return GeneratedAudio( samples = objArray[0] as FloatArray, sampleRate = objArray[1] as Int @@ -87,10 +93,10 @@ class OfflineTts( fun allocate(assetManager: AssetManager? = null) { if (ptr == 0L) { - if (assetManager != null) { - ptr = newFromAsset(assetManager, config) + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) } else { - ptr = newFromFile(config) + newFromFile(config) } } } @@ -103,9 +109,14 @@ class OfflineTts( } protected fun finalize() { - delete(ptr) + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } } + fun release() = finalize() + private external fun newFromAsset( assetManager: AssetManager, config: OfflineTtsConfig, @@ -123,14 +134,14 @@ class OfflineTts( // - the first entry is an 1-D float array containing audio samples. 
// Each sample is normalized to the range [-1, 1] // - the second entry is the sample rate - external fun generateImpl( + private external fun generateImpl( ptr: Long, text: String, sid: Int = 0, speed: Float = 1.0f ): Array - external fun generateWithCallbackImpl( + private external fun generateWithCallbackImpl( ptr: Long, text: String, sid: Int = 0, @@ -156,7 +167,7 @@ fun getOfflineTtsConfig( dictDir: String, ruleFsts: String, ruleFars: String -): OfflineTtsConfig? { +): OfflineTtsConfig { return OfflineTtsConfig( model = OfflineTtsModelConfig( vits = OfflineTtsVitsModelConfig( diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/CheckVoiceData.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/CheckVoiceData.kt index 9ddc13820..78a7de9e3 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/CheckVoiceData.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/CheckVoiceData.kt @@ -1,15 +1,18 @@ package com.k2fsa.sherpa.onnx.tts.engine import android.content.Intent -import androidx.appcompat.app.AppCompatActivity import android.os.Bundle import android.speech.tts.TextToSpeech +import androidx.appcompat.app.AppCompatActivity class CheckVoiceData : AppCompatActivity() { override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) val intent = Intent().apply { - putStringArrayListExtra(TextToSpeech.Engine.EXTRA_AVAILABLE_VOICES, arrayListOf(TtsEngine.lang)) + putStringArrayListExtra( + TextToSpeech.Engine.EXTRA_AVAILABLE_VOICES, + arrayListOf(TtsEngine.lang) + ) putStringArrayListExtra(TextToSpeech.Engine.EXTRA_UNAVAILABLE_VOICES, arrayListOf()) } setResult(TextToSpeech.Engine.CHECK_VOICE_DATA_PASS, intent) diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/GetSampleText.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/GetSampleText.kt index bd5b9ea5e..61e683738 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/GetSampleText.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/GetSampleText.kt @@ -2,7 +2,6 @@ package com.k2fsa.sherpa.onnx.tts.engine import android.app.Activity import android.content.Intent -import androidx.appcompat.app.AppCompatActivity import android.os.Bundle import android.speech.tts.TextToSpeech @@ -12,120 +11,168 @@ fun getSampleText(lang: String): String { "ara" -> { text = "هذا هو محرك تحويل النص إلى كلام باستخدام الجيل القادم من كالدي" } + "ben" -> { text = "এটি একটি টেক্সট-টু-স্পীচ ইঞ্জিন যা পরবর্তী প্রজন্মের কালডি ব্যবহার করে" } + "bul" -> { - text = "Това е машина за преобразуване на текст в реч, използваща Kaldi от следващо поколение" + text = + "Това е машина за преобразуване на текст в реч, използваща Kaldi от следващо поколение" } + "cat" -> { text = "Aquest és un motor de text a veu que utilitza Kaldi de nova generació" } + "ces" -> { text = "Toto je převodník textu na řeč využívající novou generaci kaldi" } + "dan" -> { text = "Dette er en tekst til tale-motor, der bruger næste generation af kaldi" } + "deu" -> { - text = "Dies ist eine Text-to-Speech-Engine, die Kaldi der nächsten Generation verwendet" + text = + "Dies ist eine Text-to-Speech-Engine, die Kaldi der nächsten Generation verwendet" } + "ell" -> { text = "Αυτή είναι μια μηχανή κειμένου σε ομιλία που χρησιμοποιεί kaldi επόμενης γενιάς" } + "eng" -> { text = "This is a 
text-to-speech engine using next generation Kaldi" } + "est" -> { text = "See on teksti kõneks muutmise mootor, mis kasutab järgmise põlvkonna Kaldi" } + "fin" -> { text = "Tämä on tekstistä puheeksi -moottori, joka käyttää seuraavan sukupolven kaldia" } + "fra" -> { text = "Il s'agit d'un moteur de synthèse vocale utilisant Kaldi de nouvelle génération" } + "gle" -> { text = "Is inneall téacs-go-hurlabhra é seo a úsáideann Kaldi den chéad ghlúin eile" } + "hrv" -> { - text = "Ovo je mehanizam za pretvaranje teksta u govor koji koristi Kaldi sljedeće generacije" + text = + "Ovo je mehanizam za pretvaranje teksta u govor koji koristi Kaldi sljedeće generacije" } + "hun" -> { text = "Ez egy szövegfelolvasó motor a következő generációs kaldi használatával" } + "isl" -> { text = "Þetta er texta í tal vél sem notar næstu kynslóð kaldi" } + "ita" -> { text = "Questo è un motore di sintesi vocale che utilizza kaldi di nuova generazione" } + "kat" -> { text = "ეს არის ტექსტიდან მეტყველების ძრავა შემდეგი თაობის კალდის გამოყენებით" } + "kaz" -> { text = "Бұл келесі буын kaldi көмегімен мәтіннен сөйлеуге арналған қозғалтқыш" } + "mlt" -> { text = "Din hija magna text-to-speech li tuża Kaldi tal-ġenerazzjoni li jmiss" } + "lav" -> { text = "Šis ir teksta pārvēršanas runā dzinējs, kas izmanto nākamās paaudzes Kaldi" } + "lit" -> { text = "Tai teksto į kalbą variklis, kuriame naudojamas naujos kartos Kaldi" } + "ltz" -> { text = "Dëst ass en Text-zu-Speech-Motor mat der nächster Generatioun Kaldi" } + "nep" -> { text = "यो अर्को पुस्ता काल्डी प्रयोग गरेर स्पीच इन्जिनको पाठ हो" } + "nld" -> { - text = "Dit is een tekst-naar-spraak-engine die gebruik maakt van Kaldi van de volgende generatie" + text = + "Dit is een tekst-naar-spraak-engine die gebruik maakt van Kaldi van de volgende generatie" } + "nor" -> { text = "Dette er en tekst til tale-motor som bruker neste generasjons kaldi" } + "pol" -> { text = "Jest to silnik syntezatora mowy wykorzystujący Kaldi nowej generacji" } + "por" -> { - text = "Este é um mecanismo de conversão de texto em fala usando Kaldi de próxima geração" + text = + "Este é um mecanismo de conversão de texto em fala usando Kaldi de próxima geração" } + "ron" -> { text = "Acesta este un motor text to speech care folosește generația următoare de kadi" } + "rus" -> { - text = "Это движок преобразования текста в речь, использующий Kaldi следующего поколения." + text = + "Это движок преобразования текста в речь, использующий Kaldi следующего поколения." } + "slk" -> { text = "Toto je nástroj na prevod textu na reč využívajúci kaldi novej generácie" } + "slv" -> { - text = "To je mehanizem za pretvorbo besedila v govor, ki uporablja Kaldi naslednje generacije" + text = + "To je mehanizem za pretvorbo besedila v govor, ki uporablja Kaldi naslednje generacije" } + "spa" -> { text = "Este es un motor de texto a voz que utiliza kaldi de próxima generación." 
} + "srp" -> { - text = "Ово је механизам за претварање текста у говор који користи калди следеће генерације" + text = + "Ово је механизам за претварање текста у говор који користи калди следеће генерације" } + "swa" -> { text = "Haya ni maandishi kwa injini ya hotuba kwa kutumia kizazi kijacho kaldi" } + "swe" -> { text = "Detta är en text till tal-motor som använder nästa generations kaldi" } + "tur" -> { text = "Bu, yeni nesil kaldi'yi kullanan bir metinden konuşmaya motorudur" } + "ukr" -> { - text = "Це механізм перетворення тексту на мовлення, який використовує kaldi нового покоління" + text = + "Це механізм перетворення тексту на мовлення, який використовує kaldi нового покоління" } + "vie" -> { text = "Đây là công cụ chuyển văn bản thành giọng nói sử dụng kaldi thế hệ tiếp theo" } + "zho", "cmn" -> { text = "使用新一代卡尔迪的语音合成引擎" } @@ -137,13 +184,13 @@ class GetSampleText : Activity() { override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) var result = TextToSpeech.LANG_AVAILABLE - var text: String = getSampleText(TtsEngine.lang ?: "") + val text: String = getSampleText(TtsEngine.lang ?: "") if (text.isEmpty()) { result = TextToSpeech.LANG_NOT_SUPPORTED } - val intent = Intent().apply{ - if(result == TextToSpeech.LANG_AVAILABLE) { + val intent = Intent().apply { + if (result == TextToSpeech.LANG_AVAILABLE) { putExtra(TextToSpeech.Engine.EXTRA_SAMPLE_TEXT, text) } else { putExtra("sampleText", text) diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt index d42b1b1f7..28ce449a0 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt @@ -26,20 +26,16 @@ import androidx.compose.material3.Scaffold import androidx.compose.material3.Slider import androidx.compose.material3.Surface import androidx.compose.material3.Text -import androidx.compose.material3.TextField import androidx.compose.material3.TopAppBar -import androidx.compose.runtime.Composable import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier import androidx.compose.ui.text.input.KeyboardType -import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import com.k2fsa.sherpa.onnx.tts.engine.ui.theme.SherpaOnnxTtsEngineTheme import java.io.File -import java.lang.NumberFormatException const val TAG = "sherpa-onnx-tts-engine" @@ -76,7 +72,7 @@ class MainActivity : ComponentActivity() { val testTextContent = getSampleText(TtsEngine.lang ?: "") var testText by remember { mutableStateOf(testTextContent) } - + val numSpeakers = TtsEngine.tts!!.numSpeakers() if (numSpeakers > 1) { OutlinedTextField( @@ -88,7 +84,7 @@ class MainActivity : ComponentActivity() { try { TtsEngine.speakerId = it.toString().toInt() } catch (ex: NumberFormatException) { - Log.i(TAG, "Invalid input: ${it}") + Log.i(TAG, "Invalid input: $it") TtsEngine.speakerId = 0 } } @@ -119,7 +115,7 @@ class MainActivity : ComponentActivity() { Button( modifier = Modifier.padding(20.dp), onClick = { - Log.i(TAG, "Clicked, text: ${testText}") + Log.i(TAG, "Clicked, text: $testText") if (testText.isBlank() || testText.isEmpty()) { Toast.makeText( applicationContext, @@ 
-136,7 +132,7 @@ class MainActivity : ComponentActivity() { val filename = application.filesDir.absolutePath + "/generated.wav" val ok = - audio.samples.size > 0 && audio.save(filename) + audio.samples.isNotEmpty() && audio.save(filename) if (ok) { stopMediaPlayer() diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt index e02cc069c..1bf92972e 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt @@ -4,8 +4,10 @@ import android.content.Context import android.content.res.AssetManager import android.util.Log import androidx.compose.runtime.MutableState -import androidx.compose.runtime.mutableStateOf -import com.k2fsa.sherpa.onnx.* +import androidx.compose.runtime.mutableFloatStateOf +import androidx.compose.runtime.mutableIntStateOf +import com.k2fsa.sherpa.onnx.OfflineTts +import com.k2fsa.sherpa.onnx.getOfflineTtsConfig import java.io.File import java.io.FileOutputStream import java.io.IOException @@ -21,8 +23,8 @@ object TtsEngine { var lang: String? = null - val speedState: MutableState = mutableStateOf(1.0F) - val speakerIdState: MutableState = mutableStateOf(0) + val speedState: MutableState = mutableFloatStateOf(1.0F) + val speakerIdState: MutableState = mutableIntStateOf(0) var speed: Float get() = speedState.value @@ -113,15 +115,15 @@ object TtsEngine { if (dataDir != null) { val newDir = copyDataDir(context, modelDir!!) - modelDir = newDir + "/" + modelDir - dataDir = newDir + "/" + dataDir + modelDir = "$newDir/$modelDir" + dataDir = "$newDir/$dataDir" assets = null } if (dictDir != null) { val newDir = copyDataDir(context, modelDir!!) - modelDir = newDir + "/" + modelDir - dictDir = modelDir + "/" + "dict" + modelDir = "$newDir/$modelDir" + dictDir = "$modelDir/dict" ruleFsts = "$modelDir/phone.fst,$modelDir/date.fst,$modelDir/number.fst" assets = null } @@ -132,18 +134,18 @@ object TtsEngine { dictDir = dictDir ?: "", ruleFsts = ruleFsts ?: "", ruleFars = ruleFars ?: "" - )!! + ) tts = OfflineTts(assetManager = assets, config = config) } private fun copyDataDir(context: Context, dataDir: String): String { - println("data dir is $dataDir") + Log.i(TAG, "data dir is $dataDir") copyAssets(context, dataDir) val newDataDir = context.getExternalFilesDir(null)!!.absolutePath - println("newDataDir: $newDataDir") + Log.i(TAG, "newDataDir: $newDataDir") return newDataDir } @@ -158,12 +160,12 @@ object TtsEngine { val dir = File(fullPath) dir.mkdirs() for (asset in assets.iterator()) { - val p: String = if (path == "") "" else path + "/" + val p: String = if (path == "") "" else "$path/" copyAssets(context, p + asset) } } } catch (ex: IOException) { - Log.e(TAG, "Failed to copy $path. ${ex.toString()}") + Log.e(TAG, "Failed to copy $path. 
$ex") } } @@ -183,7 +185,7 @@ object TtsEngine { ostream.flush() ostream.close() } catch (ex: Exception) { - Log.e(TAG, "Failed to copy $filename, ${ex.toString()}") + Log.e(TAG, "Failed to copy $filename, $ex") } } } diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsService.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsService.kt index c89f29cc4..3cfc66a61 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsService.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsService.kt @@ -6,7 +6,6 @@ import android.speech.tts.SynthesisRequest import android.speech.tts.TextToSpeech import android.speech.tts.TextToSpeechService import android.util.Log -import com.k2fsa.sherpa.onnx.* /* https://developer.android.com/reference/java/util/Locale#getISO3Language() diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsViewModel.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsViewModel.kt index 2226c6b93..3ccf19703 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsViewModel.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsViewModel.kt @@ -1,7 +1,6 @@ package com.k2fsa.sherpa.onnx.tts.engine import android.app.Application -import android.os.FileUtils.ProgressListener import android.speech.tts.TextToSpeech import android.speech.tts.TextToSpeech.OnInitListener import android.speech.tts.UtteranceProgressListener @@ -27,7 +26,7 @@ class TtsViewModel : ViewModel() { private val onInitListener = object : OnInitListener { override fun onInit(status: Int) { when (status) { - TextToSpeech.SUCCESS -> Log.i(TAG, "Init tts succeded") + TextToSpeech.SUCCESS -> Log.i(TAG, "Init tts succeeded") TextToSpeech.ERROR -> Log.i(TAG, "Init tts failed") else -> Log.i(TAG, "Unknown status $status") } diff --git a/android/SherpaOnnxVad/app/src/main/AndroidManifest.xml b/android/SherpaOnnxVad/app/src/main/AndroidManifest.xml index 4c591cc53..7e4c36002 100644 --- a/android/SherpaOnnxVad/app/src/main/AndroidManifest.xml +++ b/android/SherpaOnnxVad/app/src/main/AndroidManifest.xml @@ -15,7 +15,7 @@ android:theme="@style/Theme.SherpaOnnxVad" tools:targetApi="31"> diff --git a/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 4d5ce7e74..2f8e3a95c 100644 --- a/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -1,4 +1,4 @@ -package com.k2fsa.sherpa.onnx +package com.k2fsa.sherpa.onnx.vad import android.Manifest import android.content.pm.PackageManager @@ -11,6 +11,9 @@ import android.view.View import android.widget.Button import androidx.appcompat.app.AppCompatActivity import androidx.core.app.ActivityCompat +import com.k2fsa.sherpa.onnx.R +import com.k2fsa.sherpa.onnx.Vad +import com.k2fsa.sherpa.onnx.getVadModelConfig import kotlin.concurrent.thread @@ -116,7 +119,7 @@ class MainActivity : AppCompatActivity() { private fun initVadModel() { val type = 0 - println("Select VAD model type ${type}") + Log.i(TAG, "Select VAD model type ${type}") val config = getVadModelConfig(type) vad = Vad( @@ -171,4 +174,4 @@ class MainActivity : AppCompatActivity() { } } } -} \ No newline at end of 
file +} diff --git a/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt b/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt deleted file mode 100644 index 081ae3e8a..000000000 --- a/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2023 Xiaomi Corporation -package com.k2fsa.sherpa.onnx - -import android.content.res.AssetManager - -data class SileroVadModelConfig( - var model: String, - var threshold: Float = 0.5F, - var minSilenceDuration: Float = 0.25F, - var minSpeechDuration: Float = 0.25F, - var windowSize: Int = 512, -) - -data class VadModelConfig( - var sileroVadModelConfig: SileroVadModelConfig, - var sampleRate: Int = 16000, - var numThreads: Int = 1, - var provider: String = "cpu", - var debug: Boolean = false, -) - -class Vad( - assetManager: AssetManager? = null, - var config: VadModelConfig, -) { - private val ptr: Long - - init { - if (assetManager != null) { - ptr = new(assetManager, config) - } else { - ptr = newFromFile(config) - } - } - - protected fun finalize() { - delete(ptr) - } - - fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples) - - fun empty(): Boolean = empty(ptr) - fun pop() = pop(ptr) - - // return an array containing - // [start: Int, samples: FloatArray] - fun front() = front(ptr) - - fun clear() = clear(ptr) - - fun isSpeechDetected(): Boolean = isSpeechDetected(ptr) - - fun reset() = reset(ptr) - - private external fun delete(ptr: Long) - - private external fun new( - assetManager: AssetManager, - config: VadModelConfig, - ): Long - - private external fun newFromFile( - config: VadModelConfig, - ): Long - - private external fun acceptWaveform(ptr: Long, samples: FloatArray) - private external fun empty(ptr: Long): Boolean - private external fun pop(ptr: Long) - private external fun clear(ptr: Long) - private external fun front(ptr: Long): Array - private external fun isSpeechDetected(ptr: Long): Boolean - private external fun reset(ptr: Long) - - companion object { - init { - System.loadLibrary("sherpa-onnx-jni") - } - } -} - -// Please visit -// https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx -// to download silero_vad.onnx -// and put it inside the assets/ -// directory -fun getVadModelConfig(type: Int): VadModelConfig? 
{ - when (type) { - 0 -> { - return VadModelConfig( - sileroVadModelConfig = SileroVadModelConfig( - model = "silero_vad.onnx", - threshold = 0.5F, - minSilenceDuration = 0.25F, - minSpeechDuration = 0.25F, - windowSize = 512, - ), - sampleRate = 16000, - numThreads = 1, - provider = "cpu", - ) - } - } - return null; -} diff --git a/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt b/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt new file mode 120000 index 000000000..761b158ce --- /dev/null +++ b/android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt @@ -0,0 +1 @@ +../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt \ No newline at end of file diff --git a/android/SherpaOnnxVad/app/src/main/res/layout/activity_main.xml b/android/SherpaOnnxVad/app/src/main/res/layout/activity_main.xml index cb8294da1..cd1754ac1 100644 --- a/android/SherpaOnnxVad/app/src/main/res/layout/activity_main.xml +++ b/android/SherpaOnnxVad/app/src/main/res/layout/activity_main.xml @@ -4,7 +4,7 @@ xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" - tools:context=".MainActivity"> + tools:context="com.k2fsa.sherpa.onnx.vad.MainActivity"> \ No newline at end of file + diff --git a/android/SherpaOnnxVadAsr/app/src/main/AndroidManifest.xml b/android/SherpaOnnxVadAsr/app/src/main/AndroidManifest.xml index 986a17d50..7657dba60 100644 --- a/android/SherpaOnnxVadAsr/app/src/main/AndroidManifest.xml +++ b/android/SherpaOnnxVadAsr/app/src/main/AndroidManifest.xml @@ -15,7 +15,7 @@ android:theme="@style/Theme.SherpaOnnxVadAsr" tools:targetApi="31"> diff --git a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt new file mode 120000 index 000000000..952fae878 --- /dev/null +++ b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/FeatureConfig.kt @@ -0,0 +1 @@ +../../../../../../../../../../sherpa-onnx/kotlin-api/FeatureConfig.kt \ No newline at end of file diff --git a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 6668bb37d..a7d051853 100644 --- a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -1,4 +1,4 @@ -package com.k2fsa.sherpa.onnx +package com.k2fsa.sherpa.onnx.vad.asr import android.Manifest import android.content.pm.PackageManager @@ -13,6 +13,13 @@ import android.widget.Button import android.widget.TextView import androidx.appcompat.app.AppCompatActivity import androidx.core.app.ActivityCompat +import com.k2fsa.sherpa.onnx.OfflineRecognizer +import com.k2fsa.sherpa.onnx.OfflineRecognizerConfig +import com.k2fsa.sherpa.onnx.R +import com.k2fsa.sherpa.onnx.Vad +import com.k2fsa.sherpa.onnx.getFeatureConfig +import com.k2fsa.sherpa.onnx.getOfflineModelConfig +import com.k2fsa.sherpa.onnx.getVadModelConfig import kotlin.concurrent.thread @@ -40,7 +47,7 @@ class MainActivity : AppCompatActivity() { private val permissions: Array = arrayOf(Manifest.permission.RECORD_AUDIO) // Non-streaming ASR - private lateinit var offlineRecognizer: SherpaOnnxOffline + private lateinit var offlineRecognizer: OfflineRecognizer private var idx: Int = 0 private var lastText: String = "" @@ -122,7 +129,7 @@ class MainActivity : 
AppCompatActivity() { private fun initVadModel() { val type = 0 - println("Select VAD model type ${type}") + Log.i(TAG, "Select VAD model type ${type}") val config = getVadModelConfig(type) vad = Vad( @@ -194,20 +201,25 @@ class MainActivity : AppCompatActivity() { // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html // for a list of available models val secondType = 0 - println("Select model type ${secondType} for the second pass") + Log.i(TAG, "Select model type ${secondType} for the second pass") val config = OfflineRecognizerConfig( featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), modelConfig = getOfflineModelConfig(type = secondType)!!, ) - offlineRecognizer = SherpaOnnxOffline( + offlineRecognizer = OfflineRecognizer( assetManager = application.assets, config = config, ) } private fun runSecondPass(samples: FloatArray): String { - return offlineRecognizer.decode(samples, sampleRateInHz) + val stream = offlineRecognizer.createStream() + stream.acceptWaveform(samples, sampleRateInHz) + offlineRecognizer.decode(stream) + val result = offlineRecognizer.getResult(stream) + stream.release() + return result.text } -} \ No newline at end of file +} diff --git a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt new file mode 120000 index 000000000..faa3ab4ac --- /dev/null +++ b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineRecognizer.kt @@ -0,0 +1 @@ +../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineRecognizer.kt \ No newline at end of file diff --git a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt new file mode 120000 index 000000000..2a3aff864 --- /dev/null +++ b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/OfflineStream.kt @@ -0,0 +1 @@ +../../../../../../../../../../sherpa-onnx/kotlin-api/OfflineStream.kt \ No newline at end of file diff --git a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt deleted file mode 120000 index 57ba3e85a..000000000 --- a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../../../SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt \ No newline at end of file diff --git a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt index f430a1056..761b158ce 120000 --- a/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt +++ b/android/SherpaOnnxVadAsr/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt @@ -1 +1 @@ -../../../../../../../../../SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt \ No newline at end of file +../../../../../../../../../../sherpa-onnx/kotlin-api/Vad.kt \ No newline at end of file diff --git a/android/SherpaOnnxVadAsr/app/src/main/res/layout/activity_main.xml b/android/SherpaOnnxVadAsr/app/src/main/res/layout/activity_main.xml index f9b35e862..fc89d4257 100644 --- a/android/SherpaOnnxVadAsr/app/src/main/res/layout/activity_main.xml +++ b/android/SherpaOnnxVadAsr/app/src/main/res/layout/activity_main.xml @@ -4,7 +4,7 @@ xmlns:tools="http://schemas.android.com/tools" 
android:layout_width="match_parent" android:layout_height="match_parent" - tools:context=".MainActivity"> + tools:context=".vad.asr.MainActivity"> - VAD-ASR + VAD+ASR Click the Start button to play speech-to-text with Next-gen Kaldi. \n \n\n\n diff --git a/build-android-arm64-v8a.sh b/build-android-arm64-v8a.sh index 181e70d4c..2f9be5a4b 100755 --- a/build-android-arm64-v8a.sh +++ b/build-android-arm64-v8a.sh @@ -59,7 +59,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then + SHERPA_ONNX_ENABLE_TTS=ON +fi + +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then + SHERPA_ONNX_ENABLE_BINARY=OFF +fi + cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \ + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \ -DBUILD_PIPER_PHONMIZE_EXE=OFF \ -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ -DBUILD_ESPEAK_NG_EXE=OFF \ diff --git a/build-android-armv7-eabi.sh b/build-android-armv7-eabi.sh index 5c8bcd132..a574b8f66 100755 --- a/build-android-armv7-eabi.sh +++ b/build-android-armv7-eabi.sh @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then + SHERPA_ONNX_ENABLE_TTS=ON +fi + +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then + SHERPA_ONNX_ENABLE_BINARY=OFF +fi + cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \ + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \ -DBUILD_PIPER_PHONMIZE_EXE=OFF \ -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ -DBUILD_ESPEAK_NG_EXE=OFF \ diff --git a/build-android-x86-64.sh b/build-android-x86-64.sh index 15241f050..8119834d5 100755 --- a/build-android-x86-64.sh +++ b/build-android-x86-64.sh @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then + SHERPA_ONNX_ENABLE_TTS=ON +fi + +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then + SHERPA_ONNX_ENABLE_BINARY=OFF +fi + cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \ + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \ -DBUILD_PIPER_PHONMIZE_EXE=OFF \ -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ -DBUILD_ESPEAK_NG_EXE=OFF \ diff --git a/build-android-x86.sh b/build-android-x86.sh index c02d9fc5e..4499fc396 100755 --- a/build-android-x86.sh +++ b/build-android-x86.sh @@ -60,7 +60,17 @@ export SHERPA_ONNXRUNTIME_INCLUDE_DIR=$dir/$onnxruntime_version/headers/ echo "SHERPA_ONNXRUNTIME_LIB_DIR: $SHERPA_ONNXRUNTIME_LIB_DIR" echo "SHERPA_ONNXRUNTIME_INCLUDE_DIR $SHERPA_ONNXRUNTIME_INCLUDE_DIR" +if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then + SHERPA_ONNX_ENABLE_TTS=ON +fi + +if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then + SHERPA_ONNX_ENABLE_BINARY=OFF +fi + cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ + -DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \ + -DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \ -DBUILD_PIPER_PHONMIZE_EXE=OFF \ -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ 
-DBUILD_ESPEAK_NG_EXE=OFF \ diff --git a/kotlin-api-examples/AudioTagging.kt b/kotlin-api-examples/AudioTagging.kt index ff59d8d34..746902e6d 120000 --- a/kotlin-api-examples/AudioTagging.kt +++ b/kotlin-api-examples/AudioTagging.kt @@ -1 +1 @@ -../android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt \ No newline at end of file +../sherpa-onnx/kotlin-api/AudioTagging.kt \ No newline at end of file diff --git a/kotlin-api-examples/FeatureConfig.kt b/kotlin-api-examples/FeatureConfig.kt new file mode 120000 index 000000000..706de75ca --- /dev/null +++ b/kotlin-api-examples/FeatureConfig.kt @@ -0,0 +1 @@ +../sherpa-onnx/kotlin-api/FeatureConfig.kt \ No newline at end of file diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt deleted file mode 100644 index 479ce3428..000000000 --- a/kotlin-api-examples/Main.kt +++ /dev/null @@ -1,245 +0,0 @@ -package com.k2fsa.sherpa.onnx - -import android.content.res.AssetManager - -fun callback(samples: FloatArray): Unit { - println("callback got called with ${samples.size} samples"); -} - -fun main() { - testSpokenLanguageIdentifcation() - testAudioTagging() - testSpeakerRecognition() - testTts() - testAsr("transducer") - testAsr("zipformer2-ctc") -} - -fun testSpokenLanguageIdentifcation() { - val config = SpokenLanguageIdentificationConfig( - whisper = SpokenLanguageIdentificationWhisperConfig( - encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx", - decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx", - tailPaddings = 33, - ), - numThreads=1, - debug=true, - provider="cpu", - ) - val slid = SpokenLanguageIdentification(assetManager=null, config=config) - - val testFiles = arrayOf( - "./spoken-language-identification-test-wavs/ar-arabic.wav", - "./spoken-language-identification-test-wavs/bg-bulgarian.wav", - "./spoken-language-identification-test-wavs/de-german.wav", - ) - - for (waveFilename in testFiles) { - val objArray = WaveReader.readWaveFromFile( - filename = waveFilename, - ) - val samples: FloatArray = objArray[0] as FloatArray - val sampleRate: Int = objArray[1] as Int - - val stream = slid.createStream() - stream.acceptWaveform(samples, sampleRate = sampleRate) - val lang = slid.compute(stream) - stream.release() - println(waveFilename) - println(lang) - } -} - -fun testAudioTagging() { - val config = AudioTaggingConfig( - model=AudioTaggingModelConfig( - zipformer=OfflineZipformerAudioTaggingModelConfig( - model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx", - ), - numThreads=1, - debug=true, - provider="cpu", - ), - labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv", - topK=5, - ) - val tagger = AudioTagging(assetManager=null, config=config) - - val testFiles = arrayOf( - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav", - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav", - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav", - "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav", - ) - println("----------") - for (waveFilename in testFiles) { - val stream = tagger.createStream() - - val objArray = WaveReader.readWaveFromFile( - filename = waveFilename, - ) - val samples: FloatArray = objArray[0] as FloatArray - val sampleRate: Int = objArray[1] as Int - - stream.acceptWaveform(samples, sampleRate = sampleRate) - val events = tagger.compute(stream) - stream.release() - - println(waveFilename) - println(events) - println("----------") - } - 
- tagger.release() -} - -fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray { - var objArray = WaveReader.readWaveFromFile( - filename = filename, - ) - var samples: FloatArray = objArray[0] as FloatArray - var sampleRate: Int = objArray[1] as Int - - val stream = extractor.createStream() - stream.acceptWaveform(sampleRate = sampleRate, samples=samples) - stream.inputFinished() - check(extractor.isReady(stream)) - - val embedding = extractor.compute(stream) - - stream.release() - - return embedding -} - -fun testSpeakerRecognition() { - val config = SpeakerEmbeddingExtractorConfig( - model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx", - ) - val extractor = SpeakerEmbeddingExtractor(config = config) - - val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav") - val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav") - val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav") - - var manager = SpeakerEmbeddingManager(extractor.dim()) - var ok = manager.add(name = "speaker1", embedding=embedding1a) - check(ok) - - manager.add(name = "speaker2", embedding=embedding2a) - check(ok) - - var name = manager.search(embedding=embedding1b, threshold=0.5f) - check(name == "speaker1") - - manager.release() - - manager = SpeakerEmbeddingManager(extractor.dim()) - val embeddingList = mutableListOf(embedding1a, embedding1b) - ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray()) - check(ok) - - name = manager.search(embedding=embedding1b, threshold=0.5f) - check(name == "s1") - - name = manager.search(embedding=embedding2a, threshold=0.5f) - check(name.length == 0) - - manager.release() -} - -fun testTts() { - // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models - // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 - var config = OfflineTtsConfig( - model=OfflineTtsModelConfig( - vits=OfflineTtsVitsModelConfig( - model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx", - tokens="./vits-piper-en_US-amy-low/tokens.txt", - dataDir="./vits-piper-en_US-amy-low/espeak-ng-data", - ), - numThreads=1, - debug=true, - ) - ) - val tts = OfflineTts(config=config) - val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. 
Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback) - audio.save(filename="test-en.wav") -} - -fun testAsr(type: String) { - var featConfig = FeatureConfig( - sampleRate = 16000, - featureDim = 80, - ) - - var waveFilename: String - var modelConfig: OnlineModelConfig = when (type) { - "transducer" -> { - waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav" - // please refer to - // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html - // to dowload pre-trained models - OnlineModelConfig( - transducer = OnlineTransducerModelConfig( - encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", - decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", - joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", - ), - tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", - numThreads = 1, - debug = false, - ) - } - "zipformer2-ctc" -> { - waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav" - OnlineModelConfig( - zipformer2Ctc = OnlineZipformer2CtcModelConfig( - model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx", - ), - tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt", - numThreads = 1, - debug = false, - ) - } - else -> throw IllegalArgumentException(type) - } - - var endpointConfig = EndpointConfig() - - var lmConfig = OnlineLMConfig() - - var config = OnlineRecognizerConfig( - modelConfig = modelConfig, - lmConfig = lmConfig, - featConfig = featConfig, - endpointConfig = endpointConfig, - enableEndpoint = true, - decodingMethod = "greedy_search", - maxActivePaths = 4, - ) - - var model = SherpaOnnx( - config = config, - ) - - var objArray = WaveReader.readWaveFromFile( - filename = waveFilename, - ) - var samples: FloatArray = objArray[0] as FloatArray - var sampleRate: Int = objArray[1] as Int - - model.acceptWaveform(samples, sampleRate = sampleRate) - while (model.isReady()) { - model.decode() - } - - var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds - model.acceptWaveform(tailPaddings, sampleRate = sampleRate) - model.inputFinished() - while (model.isReady()) { - model.decode() - } - - println("results: ${model.text}") -} diff --git a/kotlin-api-examples/OfflineRecognizer.kt b/kotlin-api-examples/OfflineRecognizer.kt new file mode 120000 index 000000000..68dac9f18 --- /dev/null +++ b/kotlin-api-examples/OfflineRecognizer.kt @@ -0,0 +1 @@ +../sherpa-onnx/kotlin-api/OfflineRecognizer.kt \ No newline at end of file diff --git a/kotlin-api-examples/OfflineStream.kt b/kotlin-api-examples/OfflineStream.kt index 6304bfdf8..1344a9835 120000 --- a/kotlin-api-examples/OfflineStream.kt +++ b/kotlin-api-examples/OfflineStream.kt @@ -1 +1 @@ -../android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/OfflineStream.kt \ No newline at end of file +../sherpa-onnx/kotlin-api/OfflineStream.kt \ No newline at end of file diff --git a/kotlin-api-examples/OnlineRecognizer.kt b/kotlin-api-examples/OnlineRecognizer.kt new file mode 120000 index 000000000..0fdf68577 --- /dev/null +++ b/kotlin-api-examples/OnlineRecognizer.kt @@ -0,0 +1 @@ +../sherpa-onnx/kotlin-api/OnlineRecognizer.kt \ No newline at end of file diff --git 
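a/kotlin-api-examples/OnlineStream.kt b/kotlin-api-examples/OnlineStream.kt new file mode 120000 index 000000000..1c948adbe --- /dev/null +++ b/kotlin-api-examples/OnlineStream.kt @@ -0,0 +1 @@ +../sherpa-onnx/kotlin-api/OnlineStream.kt \ No newline at end of file

The deleted Main.kt above bundled every API test into a single program; this PR splits it into one small test_*.kt program per feature, each compiled and run on its own by run.sh below. The new test_offline_asr.kt itself is not shown in this diff; the following is only a minimal sketch of such an entry point, built from the OfflineRecognizer calls visible in runSecondPass above — the model type, wave path, and no-AssetManager construction are illustrative assumptions, not the file's actual contents:

package com.k2fsa.sherpa.onnx

fun main() {
    // getOfflineModelConfig() maps an integer to a bundled model
    // definition, as in MainActivity above; type 0 is an assumed value.
    val config = OfflineRecognizerConfig(
        featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
        modelConfig = getOfflineModelConfig(type = 0)!!,
    )
    // No AssetManager on the JVM; run.sh links in faked-asset-manager.kt.
    val recognizer = OfflineRecognizer(config = config)

    // Hypothetical test wave; WaveReader returns [samples, sampleRate].
    val objArray = WaveReader.readWaveFromFile(
        filename = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav",
    )
    val samples = objArray[0] as FloatArray
    val sampleRate = objArray[1] as Int

    // Non-streaming decode: feed all samples, decode once, read the result.
    val stream = recognizer.createStream()
    stream.acceptWaveform(samples, sampleRate)
    recognizer.decode(stream)
    println("results: ${recognizer.getResult(stream).text}")
    stream.release()
}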
diff --git a/kotlin-api-examples/SherpaOnnx.kt b/kotlin-api-examples/SherpaOnnx.kt deleted file mode 120000 index 6ebb58f64..000000000 --- a/kotlin-api-examples/SherpaOnnx.kt +++ /dev/null @@ -1 +0,0 @@ -../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt \ No newline at end of file diff --git a/kotlin-api-examples/SherpaOnnx2Pass.kt b/kotlin-api-examples/SherpaOnnx2Pass.kt deleted file mode 120000 index 48756bcbc..000000000 --- a/kotlin-api-examples/SherpaOnnx2Pass.kt +++ /dev/null @@ -1 +0,0 @@ -../android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt \ No newline at end of file diff --git a/kotlin-api-examples/Speaker.kt b/kotlin-api-examples/Speaker.kt index 5a1f0d51c..3f0dd5a8a 120000 --- a/kotlin-api-examples/Speaker.kt +++ b/kotlin-api-examples/Speaker.kt @@ -1 +1 @@ -../android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt \ No newline at end of file +../sherpa-onnx/kotlin-api/Speaker.kt \ No newline at end of file diff --git a/kotlin-api-examples/SpokenLanguageIdentification.kt b/kotlin-api-examples/SpokenLanguageIdentification.kt index 702a54c97..576a26422 120000 --- a/kotlin-api-examples/SpokenLanguageIdentification.kt +++ b/kotlin-api-examples/SpokenLanguageIdentification.kt @@ -1 +1 @@ -../android/SherpaOnnxSpokenLanguageIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/slid/SpokenLanguageIdentification.kt \ No newline at end of file +../sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt \ No newline at end of file diff --git a/kotlin-api-examples/Vad.kt b/kotlin-api-examples/Vad.kt index 8e553dbe5..0f70f9883 120000 --- a/kotlin-api-examples/Vad.kt +++ b/kotlin-api-examples/Vad.kt @@ -1 +1 @@ -../android/SherpaOnnxVad/app/src/main/java/com/k2fsa/sherpa/onnx/Vad.kt \ No newline at end of file +../sherpa-onnx/kotlin-api/Vad.kt \ No newline at end of file diff --git a/kotlin-api-examples/WaveReader.kt b/kotlin-api-examples/WaveReader.kt index cd487a6cb..d24443934 120000 --- a/kotlin-api-examples/WaveReader.kt +++ b/kotlin-api-examples/WaveReader.kt @@ -1 +1 @@ -../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt \ No newline at end of file +../sherpa-onnx/kotlin-api/WaveReader.kt \ No newline at end of file diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index f14e169cd..cb9c04b55 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -44,9 +44,23 @@ function testSpeakerEmbeddingExtractor() { if [ ! -f ./speaker2_a_cn_16k.wav ]; then curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav fi + + out_filename=test_speaker_id.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_speaker_id.kt \ + OnlineStream.kt \ + Speaker.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename } -function testAsr() { + +function testOnlineAsr() { if [ !
-f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then git lfs install git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 @@ -57,6 +71,20 @@ function testAsr() { tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 fi + + out_filename=test_online_asr.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_online_asr.kt \ + FeatureConfig.kt \ + OnlineRecognizer.kt \ + OnlineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename } function testTts() { @@ -65,16 +93,42 @@ function testTts() { tar xf vits-piper-en_US-amy-low.tar.bz2 rm vits-piper-en_US-amy-low.tar.bz2 fi + + out_filename=test_tts.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_tts.kt \ + Tts.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename } + function testAudioTagging() { if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 fi + + out_filename=test_audio_tagging.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_audio_tagging.kt \ + AudioTagging.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename } + function testSpokenLanguageIdentification() { if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 @@ -87,50 +141,44 @@ function testSpokenLanguageIdentification() { tar xvf spoken-language-identification-test-wavs.tar.bz2 rm spoken-language-identification-test-wavs.tar.bz2 fi -} -function test() { - testSpokenLanguageIdentification - testAudioTagging - testSpeakerEmbeddingExtractor - testAsr - testTts -} + out_filename=test_language_id.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_language_id.kt \ + SpokenLanguageIdentification.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt -test - -kotlinc-jvm -include-runtime -d main.jar \ - AudioTagging.kt \ - Main.kt \ - OfflineStream.kt \ - SherpaOnnx.kt \ - Speaker.kt \ - SpokenLanguageIdentification.kt \ - Tts.kt \ - WaveReader.kt \ - faked-asset-manager.kt \ - faked-log.kt - -ls -lh main.jar - -java -Djava.library.path=../build/lib -jar main.jar - -function testTwoPass() { - if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then - curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 - tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 - rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 - fi + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} +function testOfflineAsr() { if [ ! 
-f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 rm sherpa-onnx-whisper-tiny.en.tar.bz2 fi - kotlinc-jvm -include-runtime -d 2pass.jar test-2pass.kt WaveReader.kt SherpaOnnx2Pass.kt faked-asset-manager.kt - ls -lh 2pass.jar - java -Djava.library.path=../build/lib -jar 2pass.jar + out_filename=test_offline_asr.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_offline_asr.kt \ + FeatureConfig.kt \ + OfflineRecognizer.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt + + ls -lh $out_filename + java -Djava.library.path=../build/lib -jar $out_filename } -testTwoPass +testSpeakerEmbeddingExtractor +testOnlineAsr +testTts +testAudioTagging +testSpokenLanguageIdentification +testOfflineAsr diff --git a/kotlin-api-examples/test-2pass.kt b/kotlin-api-examples/test-2pass.kt deleted file mode 100644 index 7ce5e4569..000000000 --- a/kotlin-api-examples/test-2pass.kt +++ /dev/null @@ -1,49 +0,0 @@ -package com.k2fsa.sherpa.onnx - -fun main() { - test2Pass() -} - -fun test2Pass() { - val firstPass = createFirstPass() - val secondPass = createSecondPass() - - val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav" - - var objArray = WaveReader.readWaveFromFile( - filename = waveFilename, - ) - var samples: FloatArray = objArray[0] as FloatArray - var sampleRate: Int = objArray[1] as Int - - firstPass.acceptWaveform(samples, sampleRate = sampleRate) - while (firstPass.isReady()) { - firstPass.decode() - } - - var text = firstPass.text - println("First pass text: $text") - - text = secondPass.decode(samples, sampleRate) - println("Second pass text: $text") -} - -fun createFirstPass(): SherpaOnnx { - val config = OnlineRecognizerConfig( - featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), - modelConfig = getModelConfig(type = 1)!!, - endpointConfig = getEndpointConfig(), - enableEndpoint = true, - ) - - return SherpaOnnx(config = config) -} - -fun createSecondPass(): SherpaOnnxOffline { - val config = OfflineRecognizerConfig( - featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), - modelConfig = getOfflineModelConfig(type = 2)!!, - ) - - return SherpaOnnxOffline(config = config) -} diff --git a/kotlin-api-examples/test_audio_tagging.kt b/kotlin-api-examples/test_audio_tagging.kt new file mode 100644 index 000000000..7bd7fd127 --- /dev/null +++ b/kotlin-api-examples/test_audio_tagging.kt @@ -0,0 +1,49 @@ +package com.k2fsa.sherpa.onnx + +fun main() { + testAudioTagging() +} + +fun testAudioTagging() { + val config = AudioTaggingConfig( + model=AudioTaggingModelConfig( + zipformer=OfflineZipformerAudioTaggingModelConfig( + model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx", + ), + numThreads=1, + debug=true, + provider="cpu", + ), + labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv", + topK=5, + ) + val tagger = AudioTagging(config=config) + + val testFiles = arrayOf( + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav", + ) + println("----------") + for (waveFilename in testFiles) { + val stream = tagger.createStream() + + val objArray = WaveReader.readWaveFromFile( + 
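+        // WaveReader.readWaveFromFile returns a two-element array:
+        // element 0 is the audio samples (FloatArray), element 1 is the sample rate (Int)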
filename = waveFilename,
+        )
+        val samples: FloatArray = objArray[0] as FloatArray
+        val sampleRate: Int = objArray[1] as Int
+
+        stream.acceptWaveform(samples, sampleRate = sampleRate)
+        val events = tagger.compute(stream)
+        stream.release()
+
+        println(waveFilename)
+        println(events)
+        println("----------")
+    }
+
+    tagger.release()
+}
+
diff --git a/kotlin-api-examples/test_language_id.kt b/kotlin-api-examples/test_language_id.kt
new file mode 100644
index 000000000..7e1dcda1b
--- /dev/null
+++ b/kotlin-api-examples/test_language_id.kt
@@ -0,0 +1,43 @@
+package com.k2fsa.sherpa.onnx
+
+fun main() {
+    testSpokenLanguageIdentification()
+}
+
+fun testSpokenLanguageIdentification() {
+    val config = SpokenLanguageIdentificationConfig(
+        whisper = SpokenLanguageIdentificationWhisperConfig(
+            encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
+            decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
+            tailPaddings = 33,
+        ),
+        numThreads=1,
+        debug=true,
+        provider="cpu",
+    )
+    val slid = SpokenLanguageIdentification(config=config)
+
+    val testFiles = arrayOf(
+        "./spoken-language-identification-test-wavs/ar-arabic.wav",
+        "./spoken-language-identification-test-wavs/bg-bulgarian.wav",
+        "./spoken-language-identification-test-wavs/de-german.wav",
+    )
+
+    for (waveFilename in testFiles) {
+        val objArray = WaveReader.readWaveFromFile(
+            filename = waveFilename,
+        )
+        val samples: FloatArray = objArray[0] as FloatArray
+        val sampleRate: Int = objArray[1] as Int
+
+        val stream = slid.createStream()
+        stream.acceptWaveform(samples, sampleRate = sampleRate)
+        val lang = slid.compute(stream)
+        stream.release()
+        println(waveFilename)
+        println(lang)
+    }
+
+    slid.release()
+}
+
diff --git a/kotlin-api-examples/test_offline_asr.kt b/kotlin-api-examples/test_offline_asr.kt
new file mode 100644
index 000000000..d218e4b6a
--- /dev/null
+++ b/kotlin-api-examples/test_offline_asr.kt
@@ -0,0 +1,32 @@
+package com.k2fsa.sherpa.onnx
+
+fun main() {
+    val recognizer = createOfflineRecognizer()
+
+    val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"
+
+    val objArray = WaveReader.readWaveFromFile(
+        filename = waveFilename,
+    )
+    val samples: FloatArray = objArray[0] as FloatArray
+    val sampleRate: Int = objArray[1] as Int
+
+    val stream = recognizer.createStream()
+    stream.acceptWaveform(samples, sampleRate=sampleRate)
+    recognizer.decode(stream)
+
+    val result = recognizer.getResult(stream)
+    println(result)
+
+    stream.release()
+    recognizer.release()
+}
+
+fun createOfflineRecognizer(): OfflineRecognizer {
+    val config = OfflineRecognizerConfig(
+        featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
+        modelConfig = getOfflineModelConfig(type = 2)!!,
+    )
+
+    return OfflineRecognizer(config = config)
+}
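Aside: test-2pass.kt is deleted above without a direct replacement in this patch. For reference, the same two-pass flow can be sketched on top of the new split API, with OnlineRecognizer for the streaming first pass and OfflineRecognizer for the non-streaming second pass. The sketch below is an illustration, not part of the patch; it assumes the getFeatureConfig, getModelConfig, getEndpointConfig, and getOfflineModelConfig helpers used by the removed example are still available, and that the offline result exposes a text field like the online one.

// Illustrative only; mirrors the deleted test-2pass.kt with the new classes.
fun twoPassDecode(waveFilename: String): Pair<String, String> {
    // First pass: streaming recognizer (config helpers assumed, see note above).
    val online = OnlineRecognizer(
        config = OnlineRecognizerConfig(
            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
            modelConfig = getModelConfig(type = 1)!!,
            endpointConfig = getEndpointConfig(),
            enableEndpoint = true,
        )
    )
    // Second pass: non-streaming recognizer re-decodes the same samples.
    val offline = OfflineRecognizer(
        config = OfflineRecognizerConfig(
            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
            modelConfig = getOfflineModelConfig(type = 2)!!,
        )
    )

    val objArray = WaveReader.readWaveFromFile(filename = waveFilename)
    val samples = objArray[0] as FloatArray
    val sampleRate = objArray[1] as Int

    val onlineStream = online.createStream()
    onlineStream.acceptWaveform(samples, sampleRate = sampleRate)
    while (online.isReady(onlineStream)) {
        online.decode(onlineStream)
    }
    val firstPass = online.getResult(onlineStream).text

    val offlineStream = offline.createStream()
    offlineStream.acceptWaveform(samples, sampleRate = sampleRate)
    offline.decode(offlineStream)
    val secondPass = offline.getResult(offlineStream).text

    onlineStream.release()
    offlineStream.release()
    online.release()
    offline.release()

    return Pair(firstPass, secondPass)
}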
diff --git a/kotlin-api-examples/test_online_asr.kt b/kotlin-api-examples/test_online_asr.kt
new file mode 100644
index 000000000..d6236f8af
--- /dev/null
+++ b/kotlin-api-examples/test_online_asr.kt
@@ -0,0 +1,87 @@
+package com.k2fsa.sherpa.onnx
+
+fun main() {
+    testOnlineAsr("transducer")
+    testOnlineAsr("zipformer2-ctc")
+}
+
+fun testOnlineAsr(type: String) {
+    val featConfig = FeatureConfig(
+        sampleRate = 16000,
+        featureDim = 80,
+    )
+
+    val waveFilename: String
+    val modelConfig: OnlineModelConfig = when (type) {
+        "transducer" -> {
+            waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
+            // please refer to
+            // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
+            // to download pre-trained models
+            OnlineModelConfig(
+                transducer = OnlineTransducerModelConfig(
+                    encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
+                    decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
+                    joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
+                ),
+                tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
+                numThreads = 1,
+                debug = false,
+            )
+        }
+        "zipformer2-ctc" -> {
+            waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
+            OnlineModelConfig(
+                zipformer2Ctc = OnlineZipformer2CtcModelConfig(
+                    model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
+                ),
+                tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
+                numThreads = 1,
+                debug = false,
+            )
+        }
+        else -> throw IllegalArgumentException(type)
+    }
+
+    val endpointConfig = EndpointConfig()
+
+    val lmConfig = OnlineLMConfig()
+
+    val config = OnlineRecognizerConfig(
+        modelConfig = modelConfig,
+        lmConfig = lmConfig,
+        featConfig = featConfig,
+        endpointConfig = endpointConfig,
+        enableEndpoint = true,
+        decodingMethod = "greedy_search",
+        maxActivePaths = 4,
+    )
+
+    val recognizer = OnlineRecognizer(
+        config = config,
+    )
+
+    val objArray = WaveReader.readWaveFromFile(
+        filename = waveFilename,
+    )
+    val samples: FloatArray = objArray[0] as FloatArray
+    val sampleRate: Int = objArray[1] as Int
+
+    val stream = recognizer.createStream()
+    stream.acceptWaveform(samples, sampleRate = sampleRate)
+    while (recognizer.isReady(stream)) {
+        recognizer.decode(stream)
+    }
+
+    val tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
+    stream.acceptWaveform(tailPaddings, sampleRate = sampleRate)
+    stream.inputFinished()
+    while (recognizer.isReady(stream)) {
+        recognizer.decode(stream)
+    }
+
+    println("results: ${recognizer.getResult(stream).text}")
+
+    stream.release()
+    recognizer.release()
+}
diff --git a/kotlin-api-examples/test_speaker_id.kt b/kotlin-api-examples/test_speaker_id.kt
new file mode 100644
index 000000000..e7126ae17
--- /dev/null
+++ b/kotlin-api-examples/test_speaker_id.kt
@@ -0,0 +1,62 @@
+package com.k2fsa.sherpa.onnx
+
+fun main() {
+    testSpeakerRecognition()
+}
+
+fun testSpeakerRecognition() {
+    val config = SpeakerEmbeddingExtractorConfig(
+        model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
+    )
+    val extractor = SpeakerEmbeddingExtractor(config = config)
+
+    val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav")
+    val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav")
+    val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav")
+
+    var manager = SpeakerEmbeddingManager(extractor.dim())
+    var ok = manager.add(name = "speaker1", embedding=embedding1a)
+    check(ok)
+
+    ok = manager.add(name = "speaker2", embedding=embedding2a)
+    check(ok)
+
+    var name = manager.search(embedding=embedding1b, threshold=0.5f)
+    check(name == "speaker1")
+
+    manager.release()
+
+    manager = SpeakerEmbeddingManager(extractor.dim())
+    val embeddingList = mutableListOf(embedding1a, embedding1b)
+    ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray())
+    check(ok)
+
+    name = manager.search(embedding=embedding1b, threshold=0.5f)
+    check(name == "s1")
+
+    name = manager.search(embedding=embedding2a, threshold=0.5f)
+    check(name.length == 0) // search() returns an empty string when no speaker matches
+
+    manager.release()
+    extractor.release()
+
println("Speaker ID test done!") +} + +fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray { + var objArray = WaveReader.readWaveFromFile( + filename = filename, + ) + var samples: FloatArray = objArray[0] as FloatArray + var sampleRate: Int = objArray[1] as Int + + val stream = extractor.createStream() + stream.acceptWaveform(sampleRate = sampleRate, samples=samples) + stream.inputFinished() + check(extractor.isReady(stream)) + + val embedding = extractor.compute(stream) + + stream.release() + + return embedding +} diff --git a/kotlin-api-examples/test_tts.kt b/kotlin-api-examples/test_tts.kt new file mode 100644 index 000000000..22bcd8c2a --- /dev/null +++ b/kotlin-api-examples/test_tts.kt @@ -0,0 +1,30 @@ +package com.k2fsa.sherpa.onnx + +fun main() { + testTts() +} + +fun testTts() { + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 + var config = OfflineTtsConfig( + model=OfflineTtsModelConfig( + vits=OfflineTtsVitsModelConfig( + model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx", + tokens="./vits-piper-en_US-amy-low/tokens.txt", + dataDir="./vits-piper-en_US-amy-low/espeak-ng-data", + ), + numThreads=1, + debug=true, + ) + ) + val tts = OfflineTts(config=config) + val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback) + audio.save(filename="test-en.wav") + tts.release() + println("Saved to test-en.wav") +} + +fun callback(samples: FloatArray): Unit { + println("callback got called with ${samples.size} samples"); +} diff --git a/scripts/apk/build-apk-asr.sh.in b/scripts/apk/build-apk-asr.sh.in new file mode 100644 index 000000000..468959f08 --- /dev/null +++ b/scripts/apk/build-apk-asr.sh.in @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# +# Auto generated! Please DO NOT EDIT! + +# Please set the environment variable ANDROID_NDK +# before running this script + +# Inside the $ANDROID_NDK directory, you can find a binary ndk-build +# and some other files like the file "build/cmake/android.toolchain.cmake" + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) + +log "Building streaming ASR APK for sherpa-onnx v${SHERPA_ONNX_VERSION}" + +export SHERPA_ONNX_ENABLE_TTS=OFF + +log "====================arm64-v8a=================" +./build-android-arm64-v8a.sh +log "====================armv7-eabi================" +./build-android-armv7-eabi.sh +log "====================x86-64====================" +./build-android-x86-64.sh +log "====================x86====================" +./build-android-x86.sh + +mkdir -p apks + +{% for model in model_list %} +pushd ./android/SherpaOnnx/app/src/main/assets/ +model_name={{ model.model_name }} +type={{ model.idx }} +lang={{ model.lang }} +short_name={{ model.short_name }} + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/${model_name}.tar.bz2 +tar xvf ${model_name}.tar.bz2 + +{{ model.cmd }} + +rm -rf *.tar.bz2 +ls -lh $model_name + +popd +# Now we are at the project root directory + +git checkout . 
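+# MainActivity.kt ships with a hardcoded "type = 0"; the sed command below
+# rewrites it to this model's type index (the Model.idx value rendered by
+# scripts/apk/generate-asr-apk-script.py) so the APK loads the bundled model.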
+pushd android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx +sed -i.bak s/"type = 0/type = $type/" ./MainActivity.kt +git diff +popd + +for arch in arm64-v8a armeabi-v7a x86_64 x86; do + log "------------------------------------------------------------" + log "build ASR apk for $arch" + log "------------------------------------------------------------" + src_arch=$arch + if [ $arch == "armeabi-v7a" ]; then + src_arch=armv7-eabi + elif [ $arch == "x86_64" ]; then + src_arch=x86-64 + fi + + ls -lh ./build-android-$src_arch/install/lib/*.so + + cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnx/app/src/main/jniLibs/$arch/ + + pushd ./android/SherpaOnnx + sed -i.bak s/2048/9012/g ./gradle.properties + git diff ./gradle.properties + ./gradlew assembleRelease + popd + + mv android/SherpaOnnx/app/build/outputs/apk/release/app-release-unsigned.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-asr-$lang-$short_name.apk + ls -lh apks + rm -v ./android/SherpaOnnx/app/src/main/jniLibs/$arch/*.so +done + +rm -rf ./android/SherpaOnnx/app/src/main/assets/$model_name +{% endfor %} + +git checkout . + +ls -lh apks/ diff --git a/scripts/apk/build-apk-audio-tagging-wearos.sh.in b/scripts/apk/build-apk-audio-tagging-wearos.sh.in index bc28f5268..7d127a21b 100644 --- a/scripts/apk/build-apk-audio-tagging-wearos.sh.in +++ b/scripts/apk/build-apk-audio-tagging-wearos.sh.in @@ -29,6 +29,8 @@ log "====================x86-64====================" log "====================x86====================" ./build-android-x86.sh +export SHERPA_ONNX_ENABLE_TTS=OFF + mkdir -p apks {% for model in model_list %} diff --git a/scripts/apk/build-apk-audio-tagging.sh.in b/scripts/apk/build-apk-audio-tagging.sh.in index 2c7024644..8cb17f3bb 100644 --- a/scripts/apk/build-apk-audio-tagging.sh.in +++ b/scripts/apk/build-apk-audio-tagging.sh.in @@ -29,6 +29,8 @@ log "====================x86-64====================" log "====================x86====================" ./build-android-x86.sh +export SHERPA_ONNX_ENABLE_TTS=OFF + mkdir -p apks {% for model in model_list %} diff --git a/scripts/apk/build-apk-slid.sh.in b/scripts/apk/build-apk-slid.sh.in index d5c424d9a..27b56593b 100644 --- a/scripts/apk/build-apk-slid.sh.in +++ b/scripts/apk/build-apk-slid.sh.in @@ -29,6 +29,8 @@ log "====================x86-64====================" log "====================x86====================" ./build-android-x86.sh +export SHERPA_ONNX_ENABLE_TTS=OFF + mkdir -p apks {% for model in model_list %} diff --git a/scripts/apk/build-apk-speaker-identification.sh.in b/scripts/apk/build-apk-speaker-identification.sh.in index b4dcf2d16..11ac2b747 100644 --- a/scripts/apk/build-apk-speaker-identification.sh.in +++ b/scripts/apk/build-apk-speaker-identification.sh.in @@ -29,6 +29,8 @@ log "====================x86-64====================" log "====================x86====================" ./build-android-x86.sh +export SHERPA_ONNX_ENABLE_TTS=OFF + mkdir -p apks {% for model in model_list %} diff --git a/scripts/apk/build-apk-tts-engine.sh.in b/scripts/apk/build-apk-tts-engine.sh.in index 80e34df3a..902f6f477 100644 --- a/scripts/apk/build-apk-tts-engine.sh.in +++ b/scripts/apk/build-apk-tts-engine.sh.in @@ -29,6 +29,8 @@ log "====================x86-64====================" log "====================x86====================" ./build-android-x86.sh +export SHERPA_ONNX_ENABLE_TTS=ON + mkdir -p apks {% for tts_model in tts_model_list %} diff --git a/scripts/apk/build-apk-tts.sh.in b/scripts/apk/build-apk-tts.sh.in index 2caf3788a..73139790f 100644 --- 
a/scripts/apk/build-apk-tts.sh.in
+++ b/scripts/apk/build-apk-tts.sh.in
@@ -29,6 +29,8 @@ log "====================x86-64===================="
 log "====================x86===================="
 ./build-android-x86.sh
 
+export SHERPA_ONNX_ENABLE_TTS=ON
+
 mkdir -p apks
 
 {% for tts_model in tts_model_list %}
diff --git a/scripts/apk/generate-asr-apk-script.py b/scripts/apk/generate-asr-apk-script.py
new file mode 100755
index 000000000..77cad1250
--- /dev/null
+++ b/scripts/apk/generate-asr-apk-script.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+
+import argparse
+from dataclasses import dataclass
+from typing import List, Optional
+
+import jinja2
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--total",
+        type=int,
+        default=1,
+        help="Number of runners",
+    )
+    parser.add_argument(
+        "--index",
+        type=int,
+        default=0,
+        help="Index of the current runner",
+    )
+    return parser.parse_args()
+
+
+@dataclass
+class Model:
+    # We will download
+    # https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/{model_name}.tar.bz2
+    model_name: str
+
+    # The type of the model, e.g., 0, 1, 2. It is hardcoded in the kotlin code
+    idx: int
+
+    # e.g., zh, en, zh_en
+    lang: str
+
+    # e.g., whisper, paraformer, zipformer
+    short_name: str = ""
+
+    # cmd is used to remove extra files from the model directory
+    cmd: str = ""
+
+
+def get_models():
+    models = [
+        Model(
+            model_name="sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20",
+            idx=8,
+            lang="bilingual_zh_en",
+            short_name="zipformer",
+            cmd="""
+            pushd $model_name
+            rm -v decoder-epoch-99-avg-1.int8.onnx
+            rm -v encoder-epoch-99-avg-1.onnx
+            rm -v joiner-epoch-99-avg-1.onnx
+
+            rm -v *.sh
+            rm -v .gitattributes
+            rm -v *state*
+            rm -rfv test_wavs
+
+            ls -lh
+
+            popd
+            """,
+        ),
+    ]
+
+    return models
+
+
+def main():
+    args = get_args()
+    index = args.index
+    total = args.total
+    assert 0 <= index < total, (index, total)
+
+    all_model_list = get_models()
+
+    num_models = len(all_model_list)
+
+    num_per_runner = num_models // total
+    if num_per_runner <= 0:
+        raise ValueError(f"num_models: {num_models}, num_runners: {total}")
+
+    start = index * num_per_runner
+    end = start + num_per_runner
+
+    remaining = num_models - args.total * num_per_runner  # models left over after the even split
+
+    print(f"{index}/{total}: {start}-{end}/{num_models}")
+
+    d = dict()
+    d["model_list"] = all_model_list[start:end]
+    if index < remaining:  # the first `remaining` runners each take one leftover model
+        s = args.total * num_per_runner + index
+        d["model_list"].append(all_model_list[s])
+        print(f"{s}/{num_models}")
+
+    filename_list = [
+        "./build-apk-asr.sh",
+    ]
+    for filename in filename_list:
+        environment = jinja2.Environment()
+        with open(f"{filename}.in") as f:
+            s = f.read()
+            template = environment.from_string(s)
+
+        s = template.render(**d)
+        with open(filename, "w") as f:
+            print(s, file=f)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc
index d380ec18b..9eb5b64a3 100644
--- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc
+++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc
@@ -82,7 +82,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
 
   for (const auto &f : required_files) {
     if (!FileExists(dict_dir + "/" + f)) {
-      SHERPA_ONNX_LOGE("'%s/%s' does not exist.", data_dir.c_str(),
+      SHERPA_ONNX_LOGE("'%s/%s' does not exist.", dict_dir.c_str(),
                        f.c_str());
       return false;
     }
diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt
index bb08bbf35..339e945a5 100644
---
a/sherpa-onnx/jni/CMakeLists.txt +++ b/sherpa-onnx/jni/CMakeLists.txt @@ -12,8 +12,15 @@ endif() set(sources audio-tagging.cc jni.cc + keyword-spotter.cc + offline-recognizer.cc offline-stream.cc + online-recognizer.cc + online-stream.cc + speaker-embedding-extractor.cc + speaker-embedding-manager.cc spoken-language-identification.cc + voice-activity-detector.cc ) if(SHERPA_ONNX_ENABLE_TTS) diff --git a/sherpa-onnx/jni/common.h b/sherpa-onnx/jni/common.h index d06350f86..fede5421c 100644 --- a/sherpa-onnx/jni/common.h +++ b/sherpa-onnx/jni/common.h @@ -6,6 +6,8 @@ #define SHERPA_ONNX_JNI_COMMON_H_ #if __ANDROID_API__ >= 9 +#include + #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index a8b2e4b6d..e70f5e608 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -4,1530 +4,43 @@ // 2022 Pingfeng Luo // 2023 Zhaoming -// TODO(fangjun): Add documentation to functions/methods in this file -// and also show how to use them with kotlin, possibly with java. - -#include -#include -#include -#include - -#include "sherpa-onnx/csrc/keyword-spotter.h" -#include "sherpa-onnx/csrc/macros.h" -#include "sherpa-onnx/csrc/offline-recognizer.h" -#include "sherpa-onnx/csrc/online-recognizer.h" -#include "sherpa-onnx/csrc/onnx-utils.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" -#include "sherpa-onnx/csrc/speaker-embedding-manager.h" -#include "sherpa-onnx/csrc/voice-activity-detector.h" -#include "sherpa-onnx/csrc/wave-reader.h" -#include "sherpa-onnx/csrc/wave-writer.h" -#include "sherpa-onnx/jni/common.h" - -namespace sherpa_onnx { - -class SherpaOnnx { - public: -#if __ANDROID_API__ >= 9 - SherpaOnnx(AAssetManager *mgr, const OnlineRecognizerConfig &config) - : recognizer_(mgr, config), stream_(recognizer_.CreateStream()) {} -#endif - - explicit SherpaOnnx(const OnlineRecognizerConfig &config) - : recognizer_(config), stream_(recognizer_.CreateStream()) {} - - void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) { - if (input_sample_rate_ == -1) { - input_sample_rate_ = sample_rate; - } - - stream_->AcceptWaveform(sample_rate, samples, n); - } - - void InputFinished() const { - std::vector tail_padding(input_sample_rate_ * 0.6, 0); - stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), - tail_padding.size()); - stream_->InputFinished(); - } - - std::string GetText() const { - auto result = recognizer_.GetResult(stream_.get()); - return result.text; - } - - const std::vector GetTokens() const { - auto result = recognizer_.GetResult(stream_.get()); - return result.tokens; - } - - bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } - - bool IsReady() const { return recognizer_.IsReady(stream_.get()); } - - // If keywords is an empty string, it just recreates the decoding stream - // If keywords is not empty, it will create a new decoding stream with - // the given keywords appended to the default keywords. - void Reset(bool recreate, const std::string &keywords = {}) { - if (keywords.empty()) { - if (recreate) { - stream_ = recognizer_.CreateStream(); - } else { - recognizer_.Reset(stream_.get()); - } - } else { - auto stream = recognizer_.CreateStream(keywords); - // Set new keywords failed, the stream_ will not be updated. 
- if (stream != nullptr) { - stream_ = std::move(stream); - } else { - SHERPA_ONNX_LOGE("Failed to set keywords: %s", keywords.c_str()); - } - } - } - - void Decode() const { recognizer_.DecodeStream(stream_.get()); } - - private: - OnlineRecognizer recognizer_; - std::unique_ptr stream_; - int32_t input_sample_rate_ = -1; -}; - -class SherpaOnnxOffline { - public: -#if __ANDROID_API__ >= 9 - SherpaOnnxOffline(AAssetManager *mgr, const OfflineRecognizerConfig &config) - : recognizer_(mgr, config) {} -#endif - - explicit SherpaOnnxOffline(const OfflineRecognizerConfig &config) - : recognizer_(config) {} - - std::string Decode(int32_t sample_rate, const float *samples, int32_t n) { - auto stream = recognizer_.CreateStream(); - stream->AcceptWaveform(sample_rate, samples, n); - - recognizer_.DecodeStream(stream.get()); - return stream->GetResult().text; - } - - private: - OfflineRecognizer recognizer_; -}; - -class SherpaOnnxVad { - public: -#if __ANDROID_API__ >= 9 - SherpaOnnxVad(AAssetManager *mgr, const VadModelConfig &config) - : vad_(mgr, config) {} -#endif - - explicit SherpaOnnxVad(const VadModelConfig &config) : vad_(config) {} - - void AcceptWaveform(const float *samples, int32_t n) { - vad_.AcceptWaveform(samples, n); - } - - bool Empty() const { return vad_.Empty(); } - - void Pop() { vad_.Pop(); } - - void Clear() { vad_.Clear(); } - - const SpeechSegment &Front() const { return vad_.Front(); } - - bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); } - - void Reset() { vad_.Reset(); } - - private: - VoiceActivityDetector vad_; -}; - -class SherpaOnnxKws { - public: -#if __ANDROID_API__ >= 9 - SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config) - : keyword_spotter_(mgr, config), - stream_(keyword_spotter_.CreateStream()) {} -#endif - - explicit SherpaOnnxKws(const KeywordSpotterConfig &config) - : keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {} - - void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) { - if (input_sample_rate_ == -1) { - input_sample_rate_ = sample_rate; - } - - stream_->AcceptWaveform(sample_rate, samples, n); - } - - void InputFinished() const { - std::vector tail_padding(input_sample_rate_ * 0.6, 0); - stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), - tail_padding.size()); - stream_->InputFinished(); - } - - // If keywords is an empty string, it just recreates the decoding stream - // always returns true in this case. - // If keywords is not empty, it will create a new decoding stream with - // the given keywords appended to the default keywords. - // Return false if errors occurred when adding keywords, true otherwise. - bool Reset(const std::string &keywords = {}) { - if (keywords.empty()) { - stream_ = keyword_spotter_.CreateStream(); - return true; - } else { - auto stream = keyword_spotter_.CreateStream(keywords); - // Set new keywords failed, the stream_ will not be updated. 
- if (stream == nullptr) { - return false; - } else { - stream_ = std::move(stream); - return true; - } - } - } - - std::string GetKeyword() const { - auto result = keyword_spotter_.GetResult(stream_.get()); - return result.keyword; - } - - std::vector GetTokens() const { - auto result = keyword_spotter_.GetResult(stream_.get()); - return result.tokens; - } - - bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); } - - void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); } - - private: - KeywordSpotter keyword_spotter_; - std::unique_ptr stream_; - int32_t input_sample_rate_ = -1; -}; - -class SherpaOnnxSpeakerEmbeddingExtractorStream { - public: - explicit SherpaOnnxSpeakerEmbeddingExtractorStream( - std::unique_ptr stream) - : stream_(std::move(stream)) {} - - void AcceptWaveform(int32_t sample_rate, const float *samples, - int32_t n) const { - stream_->AcceptWaveform(sample_rate, samples, n); - } - - void InputFinished() const { stream_->InputFinished(); } - - OnlineStream *Get() const { return stream_.get(); } - - private: - std::unique_ptr stream_; -}; - -class SherpaOnnxSpeakerEmbeddingExtractor { - public: -#if __ANDROID_API__ >= 9 - SherpaOnnxSpeakerEmbeddingExtractor( - AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) - : extractor_(mgr, config) {} -#endif - - explicit SherpaOnnxSpeakerEmbeddingExtractor( - const SpeakerEmbeddingExtractorConfig &config) - : extractor_(config) {} - - int32_t Dim() const { return extractor_.Dim(); } - - bool IsReady(const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const { - return extractor_.IsReady(stream->Get()); - } - - SherpaOnnxSpeakerEmbeddingExtractorStream *CreateStream() const { - return new SherpaOnnxSpeakerEmbeddingExtractorStream( - extractor_.CreateStream()); - } - - std::vector Compute( - const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const { - return extractor_.Compute(stream->Get()); - } - - private: - SpeakerEmbeddingExtractor extractor_; -}; - -static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig( - JNIEnv *env, jobject config) { - SpeakerEmbeddingExtractorConfig ans; - - jclass cls = env->GetObjectClass(config); - - jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;"); - jstring s = (jstring)env->GetObjectField(config, fid); - const char *p = env->GetStringUTFChars(s, nullptr); - - ans.model = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "numThreads", "I"); - ans.num_threads = env->GetIntField(config, fid); - - fid = env->GetFieldID(cls, "debug", "Z"); - ans.debug = env->GetBooleanField(config, fid); - - fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.provider = p; - env->ReleaseStringUTFChars(s, p); - - return ans; -} - -static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { - OnlineRecognizerConfig ans; - - jclass cls = env->GetObjectClass(config); - jfieldID fid; - - // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html - // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html - - //---------- decoding ---------- - fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); - jstring s = (jstring)env->GetObjectField(config, fid); - const char *p = env->GetStringUTFChars(s, nullptr); - ans.decoding_method = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "maxActivePaths", 
"I"); - ans.max_active_paths = env->GetIntField(config, fid); - - fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.hotwords_file = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "hotwordsScore", "F"); - ans.hotwords_score = env->GetFloatField(config, fid); - - //---------- feat config ---------- - fid = env->GetFieldID(cls, "featConfig", - "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); - jobject feat_config = env->GetObjectField(config, fid); - jclass feat_config_cls = env->GetObjectClass(feat_config); - - fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); - ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); - - fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); - ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); - - //---------- enable endpoint ---------- - fid = env->GetFieldID(cls, "enableEndpoint", "Z"); - ans.enable_endpoint = env->GetBooleanField(config, fid); - - //---------- endpoint_config ---------- - - fid = env->GetFieldID(cls, "endpointConfig", - "Lcom/k2fsa/sherpa/onnx/EndpointConfig;"); - jobject endpoint_config = env->GetObjectField(config, fid); - jclass endpoint_config_cls = env->GetObjectClass(endpoint_config); - - fid = env->GetFieldID(endpoint_config_cls, "rule1", - "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); - jobject rule1 = env->GetObjectField(endpoint_config, fid); - jclass rule_class = env->GetObjectClass(rule1); - - fid = env->GetFieldID(endpoint_config_cls, "rule2", - "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); - jobject rule2 = env->GetObjectField(endpoint_config, fid); - - fid = env->GetFieldID(endpoint_config_cls, "rule3", - "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); - jobject rule3 = env->GetObjectField(endpoint_config, fid); - - fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z"); - ans.endpoint_config.rule1.must_contain_nonsilence = - env->GetBooleanField(rule1, fid); - ans.endpoint_config.rule2.must_contain_nonsilence = - env->GetBooleanField(rule2, fid); - ans.endpoint_config.rule3.must_contain_nonsilence = - env->GetBooleanField(rule3, fid); - - fid = env->GetFieldID(rule_class, "minTrailingSilence", "F"); - ans.endpoint_config.rule1.min_trailing_silence = - env->GetFloatField(rule1, fid); - ans.endpoint_config.rule2.min_trailing_silence = - env->GetFloatField(rule2, fid); - ans.endpoint_config.rule3.min_trailing_silence = - env->GetFloatField(rule3, fid); - - fid = env->GetFieldID(rule_class, "minUtteranceLength", "F"); - ans.endpoint_config.rule1.min_utterance_length = - env->GetFloatField(rule1, fid); - ans.endpoint_config.rule2.min_utterance_length = - env->GetFloatField(rule2, fid); - ans.endpoint_config.rule3.min_utterance_length = - env->GetFloatField(rule3, fid); - - //---------- model config ---------- - fid = env->GetFieldID(cls, "modelConfig", - "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); - jobject model_config = env->GetObjectField(config, fid); - jclass model_config_cls = env->GetObjectClass(model_config); - - // transducer - fid = env->GetFieldID(model_config_cls, "transducer", - "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); - jobject transducer_config = env->GetObjectField(model_config, fid); - jclass transducer_config_cls = env->GetObjectClass(transducer_config); - - fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); 
- ans.model_config.transducer.encoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.decoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.joiner = p; - env->ReleaseStringUTFChars(s, p); - - // paraformer - fid = env->GetFieldID(model_config_cls, "paraformer", - "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); - jobject paraformer_config = env->GetObjectField(model_config, fid); - jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); - - fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(paraformer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.paraformer.encoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(paraformer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.paraformer.decoder = p; - env->ReleaseStringUTFChars(s, p); - - // streaming zipformer2 CTC - fid = - env->GetFieldID(model_config_cls, "zipformer2Ctc", - "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;"); - jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid); - jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config); - - fid = - env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.zipformer2_ctc.model = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.tokens = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "numThreads", "I"); - ans.model_config.num_threads = env->GetIntField(model_config, fid); - - fid = env->GetFieldID(model_config_cls, "debug", "Z"); - ans.model_config.debug = env->GetBooleanField(model_config, fid); - - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.provider = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.model_type = p; - env->ReleaseStringUTFChars(s, p); - - //---------- rnn lm model config ---------- - fid = env->GetFieldID(cls, "lmConfig", - "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); - jobject lm_model_config = env->GetObjectField(config, fid); - jclass lm_model_config_cls = env->GetObjectClass(lm_model_config); - - fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(lm_model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.lm_config.model = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); 
- ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); - - return ans; -} - -static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { - OfflineRecognizerConfig ans; - - jclass cls = env->GetObjectClass(config); - jfieldID fid; - - //---------- decoding ---------- - fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); - jstring s = (jstring)env->GetObjectField(config, fid); - const char *p = env->GetStringUTFChars(s, nullptr); - ans.decoding_method = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "maxActivePaths", "I"); - ans.max_active_paths = env->GetIntField(config, fid); - - fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.hotwords_file = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "hotwordsScore", "F"); - ans.hotwords_score = env->GetFloatField(config, fid); - - //---------- feat config ---------- - fid = env->GetFieldID(cls, "featConfig", - "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); - jobject feat_config = env->GetObjectField(config, fid); - jclass feat_config_cls = env->GetObjectClass(feat_config); - - fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); - ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); - - fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); - ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); - - //---------- model config ---------- - fid = env->GetFieldID(cls, "modelConfig", - "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;"); - jobject model_config = env->GetObjectField(config, fid); - jclass model_config_cls = env->GetObjectClass(model_config); - - fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.tokens = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "numThreads", "I"); - ans.model_config.num_threads = env->GetIntField(model_config, fid); - - fid = env->GetFieldID(model_config_cls, "debug", "Z"); - ans.model_config.debug = env->GetBooleanField(model_config, fid); - - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.provider = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.model_type = p; - env->ReleaseStringUTFChars(s, p); - - // transducer - fid = env->GetFieldID(model_config_cls, "transducer", - "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); - jobject transducer_config = env->GetObjectField(model_config, fid); - jclass transducer_config_cls = env->GetObjectClass(transducer_config); - - fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.encoder_filename = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.decoder_filename = p; - 
env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.joiner_filename = p; - env->ReleaseStringUTFChars(s, p); - - // paraformer - fid = env->GetFieldID(model_config_cls, "paraformer", - "Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;"); - jobject paraformer_config = env->GetObjectField(model_config, fid); - jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); - - fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;"); - - s = (jstring)env->GetObjectField(paraformer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.paraformer.model = p; - env->ReleaseStringUTFChars(s, p); - - // whisper - fid = env->GetFieldID(model_config_cls, "whisper", - "Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;"); - jobject whisper_config = env->GetObjectField(model_config, fid); - jclass whisper_config_cls = env->GetObjectClass(whisper_config); - - fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(whisper_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.whisper.encoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(whisper_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.whisper.decoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(whisper_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.whisper.language = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(whisper_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.whisper.task = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I"); - ans.model_config.whisper.tail_paddings = - env->GetIntField(whisper_config, fid); - - return ans; -} - -static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { - KeywordSpotterConfig ans; - - jclass cls = env->GetObjectClass(config); - jfieldID fid; - - // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html - // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html - - //---------- decoding ---------- - fid = env->GetFieldID(cls, "maxActivePaths", "I"); - ans.max_active_paths = env->GetIntField(config, fid); - - fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;"); - jstring s = (jstring)env->GetObjectField(config, fid); - const char *p = env->GetStringUTFChars(s, nullptr); - ans.keywords_file = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "keywordsScore", "F"); - ans.keywords_score = env->GetFloatField(config, fid); - - fid = env->GetFieldID(cls, "keywordsThreshold", "F"); - ans.keywords_threshold = env->GetFloatField(config, fid); - - fid = env->GetFieldID(cls, "numTrailingBlanks", "I"); - ans.num_trailing_blanks = env->GetIntField(config, fid); - - //---------- feat config ---------- - fid = env->GetFieldID(cls, "featConfig", - "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); - jobject feat_config = env->GetObjectField(config, 
fid); - jclass feat_config_cls = env->GetObjectClass(feat_config); - - fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); - ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); - - fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); - ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); - - //---------- model config ---------- - fid = env->GetFieldID(cls, "modelConfig", - "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); - jobject model_config = env->GetObjectField(config, fid); - jclass model_config_cls = env->GetObjectClass(model_config); - - // transducer - fid = env->GetFieldID(model_config_cls, "transducer", - "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); - jobject transducer_config = env->GetObjectField(model_config, fid); - jclass transducer_config_cls = env->GetObjectClass(transducer_config); - - fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.encoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.decoder = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.transducer.joiner = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.tokens = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "numThreads", "I"); - ans.model_config.num_threads = env->GetIntField(model_config, fid); - - fid = env->GetFieldID(model_config_cls, "debug", "Z"); - ans.model_config.debug = env->GetBooleanField(model_config, fid); - - fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.provider = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.model_config.model_type = p; - env->ReleaseStringUTFChars(s, p); - - return ans; -} - -static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { - VadModelConfig ans; - - jclass cls = env->GetObjectClass(config); - jfieldID fid; - - // silero_vad - fid = env->GetFieldID(cls, "sileroVadModelConfig", - "Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;"); - jobject silero_vad_config = env->GetObjectField(config, fid); - jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config); - - fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;"); - auto s = (jstring)env->GetObjectField(silero_vad_config, fid); - auto p = env->GetStringUTFChars(s, nullptr); - ans.silero_vad.model = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F"); - ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid); - - fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", 
"F"); - ans.silero_vad.min_silence_duration = - env->GetFloatField(silero_vad_config, fid); - - fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F"); - ans.silero_vad.min_speech_duration = - env->GetFloatField(silero_vad_config, fid); - - fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I"); - ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid); - - fid = env->GetFieldID(cls, "sampleRate", "I"); - ans.sample_rate = env->GetIntField(config, fid); - - fid = env->GetFieldID(cls, "numThreads", "I"); - ans.num_threads = env->GetIntField(config, fid); - - fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(config, fid); - p = env->GetStringUTFChars(s, nullptr); - ans.provider = p; - env->ReleaseStringUTFChars(s, p); - - fid = env->GetFieldID(cls, "debug", "Z"); - ans.debug = env->GetBooleanField(config, fid); - - return ans; -} - -} // namespace sherpa_onnx - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env, - jobject /*obj*/, - jobject asset_manager, - jobject _config) { -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - } -#endif - auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); - SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str()); - - auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor( -#if __ANDROID_API__ >= 9 - mgr, -#endif - config); - - return (jlong)extractor; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile( - JNIEnv *env, jobject /*obj*/, jobject _config) { - auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); - SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str()); - - if (!config.Validate()) { - SHERPA_ONNX_LOGE("Errors found in config!"); - } - - auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(config); - - return (jlong)extractor; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env, - jobject /*obj*/, - jlong ptr) { - delete reinterpret_cast( - ptr); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto stream = - reinterpret_cast(ptr) - ->CreateStream(); - - return (jlong)stream; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jboolean JNICALL -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env, - jobject /*obj*/, - jlong ptr, - jlong stream_ptr) { - auto extractor = - reinterpret_cast(ptr); - auto stream = reinterpret_cast< - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr); - return extractor->IsReady(stream); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jfloatArray JNICALL -Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env, - jobject /*obj*/, - jlong ptr, - jlong stream_ptr) { - auto extractor = - reinterpret_cast(ptr); - auto stream = reinterpret_cast< - sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr); - - std::vector embedding = extractor->Compute(stream); - jfloatArray embedding_arr = env->NewFloatArray(embedding.size()); - env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(), - embedding.data()); - return embedding_arr; -} - -SHERPA_ONNX_EXTERN_C 
-JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
-    JNIEnv *env, jobject /*obj*/, jlong ptr) {
-  auto extractor =
-      reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
-  return extractor->Dim();
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_delete(
-    JNIEnv *env, jobject /*obj*/, jlong ptr) {
-  delete reinterpret_cast<
-      sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_acceptWaveform(
-    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
-    jint sample_rate) {
-  auto stream = reinterpret_cast<
-      sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
-
-  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
-  jsize n = env->GetArrayLength(samples);
-  stream->AcceptWaveform(sample_rate, p, n);
-  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_inputFinished(
-    JNIEnv *env, jobject /*obj*/, jlong ptr) {
-  auto stream = reinterpret_cast<
-      sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
-  stream->InputFinished();
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_new(
-    JNIEnv *env, jobject /*obj*/, jint dim) {
-  auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
-  return (jlong)p;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
-                                                          jobject /*obj*/,
-                                                          jlong ptr) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-  delete manager;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jboolean JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
-                                                       jobject /*obj*/,
-                                                       jlong ptr, jstring name,
-                                                       jfloatArray embedding) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-
-  jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
-  jsize n = env->GetArrayLength(embedding);
-
-  if (n != manager->Dim()) {
-    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
-                     static_cast<int32_t>(n));
-    exit(-1);
-  }
-
-  const char *p_name = env->GetStringUTFChars(name, nullptr);
-
-  jboolean ok = manager->Add(p_name, p);
-  env->ReleaseStringUTFChars(name, p_name);
-  env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
-
-  return ok;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jboolean JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
-    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
-    jobjectArray embedding_arr) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-
-  int num_embeddings = env->GetArrayLength(embedding_arr);
-  if (num_embeddings == 0) {
-    return false;
-  }
-
-  std::vector<std::vector<float>> embedding_list;
-  embedding_list.reserve(num_embeddings);
-  for (int32_t i = 0; i != num_embeddings; ++i) {
-    jfloatArray embedding =
-        (jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
-
-    jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
-    jsize n = env->GetArrayLength(embedding);
-
-    if (n != manager->Dim()) {
-      SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
-                       static_cast<int32_t>(n));
-      exit(-1);
-    }
-
-    embedding_list.push_back({p, p + n});
-    env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
-  }
-
-  const char *p_name = env->GetStringUTFChars(name, nullptr);
-
-  jboolean ok = manager->Add(p_name, embedding_list);
-
-  env->ReleaseStringUTFChars(name, p_name);
-
-  return ok;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jboolean JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
-                                                          jobject /*obj*/,
-                                                          jlong ptr,
-                                                          jstring name) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-
-  const char *p_name = env->GetStringUTFChars(name, nullptr);
-
-  jboolean ok = manager->Remove(p_name);
-
-  env->ReleaseStringUTFChars(name, p_name);
-
-  return ok;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jstring JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
-                                                          jobject /*obj*/,
-                                                          jlong ptr,
-                                                          jfloatArray embedding,
-                                                          jfloat threshold) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-
-  jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
-  jsize n = env->GetArrayLength(embedding);
-
-  if (n != manager->Dim()) {
-    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
-                     static_cast<int32_t>(n));
-    exit(-1);
-  }
-
-  std::string name = manager->Search(p, threshold);
-
-  env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
-
-  return env->NewStringUTF(name.c_str());
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jboolean JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
-    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
-    jfloatArray embedding, jfloat threshold) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-
-  jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
-  jsize n = env->GetArrayLength(embedding);
-
-  if (n != manager->Dim()) {
-    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
-                     static_cast<int32_t>(n));
-    exit(-1);
-  }
-
-  const char *p_name = env->GetStringUTFChars(name, nullptr);
-
-  jboolean ok = manager->Verify(p_name, p, threshold);
-
-  env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
-
-  env->ReleaseStringUTFChars(name, p_name);
-
-  return ok;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jboolean JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
-                                                            jobject /*obj*/,
-                                                            jlong ptr,
-                                                            jstring name) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-
-  const char *p_name = env->GetStringUTFChars(name, nullptr);
-
-  jboolean ok = manager->Contains(p_name);
-
-  env->ReleaseStringUTFChars(name, p_name);
-
-  return ok;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jint JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
-                                                               jobject /*obj*/,
-                                                               jlong ptr) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-  return manager->NumSpeakers();
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jobjectArray JNICALL
-Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
-    JNIEnv *env, jobject /*obj*/, jlong ptr) {
-  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
-  std::vector<std::string> all_speakers = manager->GetAllSpeakers();
-
-  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
-      all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
-
-  int32_t i = 0;
-  for (auto &s : all_speakers) {
-    jstring js = env->NewStringUTF(s.c_str());
-    env->SetObjectArrayElement(obj_arr, i, js);
-
-    ++i;
-  }
-
-  return obj_arr;
-}
-
-// see
-// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
-jobject NewInteger(JNIEnv *env, int32_t value) {
-  jclass cls = env->FindClass("java/lang/Integer");
-  jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
-  return env->NewObject(cls, constructor, value);
-}
-
-jobject NewFloat(JNIEnv *env, float value) {
-  jclass cls = env->FindClass("java/lang/Float");
-  jmethodID constructor = env->GetMethodID(cls, "<init>", "(F)V");
-  return env->NewObject(cls, constructor, value);
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
-    JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
-    jint sample_rate) {
-  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
-
-  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
-  jsize n = env->GetArrayLength(samples);
-
-  bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
-
-  env->ReleaseStringUTFChars(filename, p_filename);
-  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
-
-  return ok;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new(
-    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
-#if __ANDROID_API__ >= 9
-  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
-  if (!mgr) {
-    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
-  }
-#endif
-  auto config = sherpa_onnx::GetVadModelConfig(env, _config);
-  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
-  auto model = new sherpa_onnx::SherpaOnnxVad(
-#if __ANDROID_API__ >= 9
-      mgr,
-#endif
-      config);
-
-  return (jlong)model;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
-    JNIEnv *env, jobject /*obj*/, jobject _config) {
-  auto config = sherpa_onnx::GetVadModelConfig(env, _config);
-  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
-  auto model = new sherpa_onnx::SherpaOnnxVad(config);
-
-  return (jlong)model;
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
-                                                             jobject /*obj*/,
-                                                             jlong ptr) {
-  delete reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
-    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
-  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
-
-  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
-  jsize n = env->GetArrayLength(samples);
-
-  model->AcceptWaveform(p, n);
-
-  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
-                                                            jobject /*obj*/,
-                                                            jlong ptr) {
-  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
-  return model->Empty();
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
-                                                          jobject /*obj*/,
-                                                          jlong ptr) {
-  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
-  model->Pop();
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
-                                                            jobject /*obj*/,
-                                                            jlong ptr) {
-  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
-  model->Clear();
-}
-
-SHERPA_ONNX_EXTERN_C
-JNIEXPORT jobjectArray JNICALL
-Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
-  const auto &front =
-      reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr)->Front();
-
-  jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
-  env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
-                           front.samples.data());
-
-  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
-      2, env->FindClass("java/lang/Object"), nullptr);
-
-  env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));
-  env->SetObjectArrayElement(obj_arr, 1,
samples_arr); - - return obj_arr; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto model = reinterpret_cast(ptr); - return model->IsSpeechDetected(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env, - jobject /*obj*/, - jlong ptr) { - auto model = reinterpret_cast(ptr); - model->Reset(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new( - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - } -#endif - auto config = sherpa_onnx::GetConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto model = new sherpa_onnx::SherpaOnnx( -#if __ANDROID_API__ >= 9 - mgr, -#endif - config); - - return (jlong)model; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_newFromFile( - JNIEnv *env, jobject /*obj*/, jobject _config) { - auto config = sherpa_onnx::GetConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto model = new sherpa_onnx::SherpaOnnx(config); - - return (jlong)model; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - delete reinterpret_cast(ptr); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_new( - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - } -#endif - auto config = sherpa_onnx::GetOfflineConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto model = new sherpa_onnx::SherpaOnnxOffline( -#if __ANDROID_API__ >= 9 - mgr, -#endif - config); - - return (jlong)model; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL -Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_newFromFile(JNIEnv *env, - jobject /*obj*/, - jobject _config) { - auto config = sherpa_onnx::GetOfflineConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto model = new sherpa_onnx::SherpaOnnxOffline(config); - - return (jlong)model; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - delete reinterpret_cast(ptr); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset( - JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate, - jstring keywords) { - auto model = reinterpret_cast(ptr); - const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); - model->Reset(recreate, p_keywords); - env->ReleaseStringUTFChars(keywords, p_keywords); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto model = reinterpret_cast(ptr); - return model->IsReady(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto model = reinterpret_cast(ptr); - return model->IsEndpoint(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void 
JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto model = reinterpret_cast(ptr); - model->Decode(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform( - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, - jint sample_rate) { - auto model = reinterpret_cast(ptr); - - jfloat *p = env->GetFloatArrayElements(samples, nullptr); - jsize n = env->GetArrayLength(samples); - - model->AcceptWaveform(sample_rate, p, n); - - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_decode( - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, - jint sample_rate) { - auto model = reinterpret_cast(ptr); - - jfloat *p = env->GetFloatArrayElements(samples, nullptr); - jsize n = env->GetArrayLength(samples); - - auto text = model->Decode(sample_rate, p, n); - - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); - - return env->NewStringUTF(text.c_str()); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - reinterpret_cast(ptr)->InputFinished(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - // see - // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni - auto text = reinterpret_cast(ptr)->GetText(); - return env->NewStringUTF(text.c_str()); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto tokens = reinterpret_cast(ptr)->GetTokens(); - int32_t size = tokens.size(); - jclass stringClass = env->FindClass("java/lang/String"); - - // convert C++ list into jni string array - jobjectArray result = env->NewObjectArray(size, stringClass, nullptr); - for (int32_t i = 0; i < size; i++) { - // Convert the C++ string to a C string - const char *cstr = tokens[i].c_str(); - - // Convert the C string to a jstring - jstring jstr = env->NewStringUTF(cstr); - - // Set the array element - env->SetObjectArrayElement(result, i, jstr); - } - - return result; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new( - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - } -#endif - auto config = sherpa_onnx::GetKwsConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto model = new sherpa_onnx::SherpaOnnxKws( -#if __ANDROID_API__ >= 9 - mgr, -#endif - config); - - return (jlong)model; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile( - JNIEnv *env, jobject /*obj*/, jobject _config) { - auto config = sherpa_onnx::GetKwsConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto model = new sherpa_onnx::SherpaOnnxKws(config); - - return (jlong)model; -} +#include -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - delete reinterpret_cast(ptr); -} +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include 
"sherpa-onnx/csrc/wave-reader.h" +#include "sherpa-onnx/csrc/wave-writer.h" +#include "sherpa-onnx/jni/common.h" -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto model = reinterpret_cast(ptr); - return model->IsReady(); +// see +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables +jobject NewInteger(JNIEnv *env, int32_t value) { + jclass cls = env->FindClass("java/lang/Integer"); + jmethodID constructor = env->GetMethodID(cls, "", "(I)V"); + return env->NewObject(cls, constructor, value); } -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - auto model = reinterpret_cast(ptr); - model->Decode(); +jobject NewFloat(JNIEnv *env, float value) { + jclass cls = env->FindClass("java/lang/Float"); + jmethodID constructor = env->GetMethodID(cls, "", "(F)V"); + return env->NewObject(cls, constructor, value); } SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform( - JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, jint sample_rate) { - auto model = reinterpret_cast(ptr); + const char *p_filename = env->GetStringUTFChars(filename, nullptr); jfloat *p = env->GetFloatArrayElements(samples, nullptr); jsize n = env->GetArrayLength(samples); - model->AcceptWaveform(sample_rate, p, n); + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); + env->ReleaseStringUTFChars(filename, p_filename); env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - reinterpret_cast(ptr)->InputFinished(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - // see - // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni - auto text = reinterpret_cast(ptr)->GetKeyword(); - return env->NewStringUTF(text.c_str()); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset( - JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { - const char *p_keywords = env->GetStringUTFChars(keywords, nullptr); - - std::string keywords_str = p_keywords; - - bool status = - reinterpret_cast(ptr)->Reset(keywords_str); - env->ReleaseStringUTFChars(keywords, p_keywords); - return status; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL -Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/, - jlong ptr) { - auto tokens = - reinterpret_cast(ptr)->GetTokens(); - int32_t size = tokens.size(); - jclass stringClass = env->FindClass("java/lang/String"); - - // convert C++ list into jni string array - jobjectArray result = env->NewObjectArray(size, stringClass, nullptr); - for (int32_t i = 0; i < size; i++) { - // Convert the C++ string to a C string - const char *cstr = tokens[i].c_str(); - - // Convert the C string to a jstring - jstring jstr = env->NewStringUTF(cstr); - - // Set the array element - env->SetObjectArrayElement(result, i, jstr); - } - return result; + return ok; } static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is, @@ -1593,81 +106,7 @@ 
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset( return obj_arr; } -// ******warpper for OnlineRecognizer******* - -// wav reader for java interface -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_readWave(JNIEnv *env, - jclass /*cls*/, - jstring filename) { - auto data = - Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset( - env, nullptr, nullptr, filename); - return data; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL - -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createOnlineRecognizer( - - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - } -#endif - sherpa_onnx::OnlineRecognizerConfig config = - sherpa_onnx::GetConfig(env, _config); - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - auto p_recognizer = new sherpa_onnx::OnlineRecognizer( -#if __ANDROID_API__ >= 9 - mgr, -#endif - config); - return (jlong)p_recognizer; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_deleteOnlineRecognizer( - JNIEnv *env, jobject /*obj*/, jlong ptr) { - delete reinterpret_cast(ptr); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jlong JNICALL -Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env, - jobject /*obj*/, - jlong ptr) { - std::unique_ptr s = - reinterpret_cast(ptr)->CreateStream(); - sherpa_onnx::OnlineStream *p_stream = s.release(); - return reinterpret_cast(p_stream); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady( - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) { - sherpa_onnx::OnlineRecognizer *model = - reinterpret_cast(ptr); - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - return model->IsReady(s); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStream( - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) { - sherpa_onnx::OnlineRecognizer *model = - reinterpret_cast(ptr); - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - model->DecodeStream(s); -} - +#if 0 SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env, @@ -1687,92 +126,4 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env, model->DecodeStreams(p_ss.data(), n); env->ReleaseLongArrayElements(ss_ptr, p, JNI_ABORT); } - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult( - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) { - sherpa_onnx::OnlineRecognizer *model = - reinterpret_cast(ptr); - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s); - return env->NewStringUTF(result.text.c_str()); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint( - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) { - sherpa_onnx::OnlineRecognizer *model = - reinterpret_cast(ptr); - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - return model->IsEndpoint(s); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reSet( - JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) { - sherpa_onnx::OnlineRecognizer *model = - reinterpret_cast(ptr); - 
sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - model->Reset(s); -} - -// *********for OnlineStream ********* -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform( - JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint sample_rate, - jfloatArray waveform) { - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - jfloat *p = env->GetFloatArrayElements(waveform, nullptr); - jsize n = env->GetArrayLength(waveform); - s->AcceptWaveform(sample_rate, p, n); - env->ReleaseFloatArrayElements(waveform, p, JNI_ABORT); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished( - JNIEnv *env, jobject /*obj*/, jlong s_ptr) { - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - s->InputFinished(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_deleteStream( - JNIEnv *env, jobject /*obj*/, jlong s_ptr) { - delete reinterpret_cast(s_ptr); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_numFramesReady( - JNIEnv *env, jobject /*obj*/, jlong s_ptr) { - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - return s->NumFramesReady(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_isLastFrame( - JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint frame) { - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - return s->IsLastFrame(frame); -} -SHERPA_ONNX_EXTERN_C -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_reSet( - JNIEnv *env, jobject /*obj*/, jlong s_ptr) { - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - s->Reset(); -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_featureDim( - JNIEnv *env, jobject /*obj*/, jlong s_ptr) { - sherpa_onnx::OnlineStream *s = - reinterpret_cast(s_ptr); - return s->FeatureDim(); -} +#endif diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc new file mode 100644 index 000000000..72b263046 --- /dev/null +++ b/sherpa-onnx/jni/keyword-spotter.cc @@ -0,0 +1,233 @@ +// sherpa-onnx/jni/keyword-spotter.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/keyword-spotter.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { + +static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { + KeywordSpotterConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html + + //---------- decoding ---------- + fid = env->GetFieldID(cls, "maxActivePaths", "I"); + ans.max_active_paths = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.keywords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "keywordsScore", "F"); + ans.keywords_score = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "keywordsThreshold", "F"); + ans.keywords_threshold = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "numTrailingBlanks", "I"); + ans.num_trailing_blanks = env->GetIntField(config, fid); + + //---------- feat config ---------- + fid = 
env->GetFieldID(cls, "featConfig", + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); + jobject feat_config = env->GetObjectField(config, fid); + jclass feat_config_cls = env->GetObjectClass(feat_config); + + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); + + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); + + //---------- model config ---------- + fid = env->GetFieldID(cls, "modelConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.joiner = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model_config.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model_config.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + } +#endif + auto config = sherpa_onnx::GetKwsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + auto kws = new sherpa_onnx::KeywordSpotter( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)kws; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromFile( + JNIEnv *env, jobject 
/*obj*/, jobject _config) {
+  auto config = sherpa_onnx::GetKwsConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
+
+  if (!config.Validate()) {
+    SHERPA_ONNX_LOGE("Errors found in config!");
+    return 0;
+  }
+
+  auto kws = new sherpa_onnx::KeywordSpotter(config);
+
+  return (jlong)kws;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_delete(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  delete reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
+  auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+
+  kws->DecodeStream(stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
+  auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
+
+  const char *p = env->GetStringUTFChars(keywords, nullptr);
+  std::unique_ptr<sherpa_onnx::OnlineStream> stream;
+
+  if (strlen(p) == 0) {
+    stream = kws->CreateStream();
+  } else {
+    stream = kws->CreateStream(p);
+  }
+
+  env->ReleaseStringUTFChars(keywords, p);
+
+  // The user is responsible to free the returned pointer.
+  //
+  // See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from
+  // ./online-stream.cc
+  sherpa_onnx::OnlineStream *ans = stream.release();
+  return (jlong)ans;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_isReady(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
+  auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+
+  return kws->IsReady(stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_onnx_KeywordSpotter_getResult(JNIEnv *env,
+                                                    jobject /*obj*/, jlong ptr,
+                                                    jlong stream_ptr) {
+  auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+
+  sherpa_onnx::KeywordResult result = kws->GetResult(stream);
+
+  // [0]: keyword, jstring
+  // [1]: tokens, array of jstring
+  // [2]: timestamps, array of float
+  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
+      3, env->FindClass("java/lang/Object"), nullptr);
+
+  jstring keyword = env->NewStringUTF(result.keyword.c_str());
+  env->SetObjectArrayElement(obj_arr, 0, keyword);
+
+  jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
+      result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
+
+  int32_t i = 0;
+  for (const auto &t : result.tokens) {
+    jstring jtext = env->NewStringUTF(t.c_str());
+    env->SetObjectArrayElement(tokens_arr, i, jtext);
+    i += 1;
+  }
+
+  env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
+
+  jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
+  env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
+                           result.timestamps.data());
+
+  env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
+
+  return obj_arr;
+}
diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc
new file mode 100644
index 000000000..0103417a2
--- /dev/null
+++ b/sherpa-onnx/jni/offline-recognizer.cc
@@ -0,0 +1,263 @@
+// sherpa-onnx/jni/offline-recognizer.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/jni/common.h"
+
+namespace sherpa_onnx {
+
+static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
+
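+  // The field-by-field reflection below follows the same idiom as
+  // GetKwsConfig() in keyword-spotter.cc: look up a jfieldID on the Kotlin
+  // config class, read the field, and copy the value into the C++ struct.
+  // A minimal sketch of the recurring pattern for one string field
+  // ("someField" / some_field are hypothetical names, not real fields):
+  //
+  //   jfieldID fid = env->GetFieldID(cls, "someField", "Ljava/lang/String;");
+  //   jstring s = (jstring)env->GetObjectField(config, fid);
+  //   const char *p = env->GetStringUTFChars(s, nullptr);
+  //   ans.some_field = p;  // std::string copies the bytes
+  //   env->ReleaseStringUTFChars(s, p);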
OfflineRecognizerConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + //---------- decoding ---------- + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.decoding_method = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "maxActivePaths", "I"); + ans.max_active_paths = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + ans.hotwords_score = env->GetFloatField(config, fid); + + //---------- feat config ---------- + fid = env->GetFieldID(cls, "featConfig", + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); + jobject feat_config = env->GetObjectField(config, fid); + jclass feat_config_cls = env->GetObjectClass(feat_config); + + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); + + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); + + //---------- model config ---------- + fid = env->GetFieldID(cls, "modelConfig", + "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model_config.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model_config.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.encoder_filename = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.decoder_filename = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = 
env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.joiner_filename = p; + env->ReleaseStringUTFChars(s, p); + + // paraformer + fid = env->GetFieldID(model_config_cls, "paraformer", + "Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;"); + jobject paraformer_config = env->GetObjectField(model_config, fid); + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); + + fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;"); + + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.model = p; + env->ReleaseStringUTFChars(s, p); + + // whisper + fid = env->GetFieldID(model_config_cls, "whisper", + "Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;"); + jobject whisper_config = env->GetObjectField(model_config, fid); + jclass whisper_config_cls = env->GetObjectClass(whisper_config); + + fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.language = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.task = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I"); + ans.model_config.whisper.tail_paddings = + env->GetIntField(whisper_config, fid); + + return ans; +} + +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env, + jobject /*obj*/, + jobject asset_manager, + jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + } +#endif + auto config = sherpa_onnx::GetOfflineConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + auto model = new sherpa_onnx::OfflineRecognizer( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env, + jobject /*obj*/, + jobject _config) { + auto config = sherpa_onnx::GetOfflineConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto model = new sherpa_onnx::OfflineRecognizer(config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_createStream(JNIEnv *env, + jobject /*obj*/, + jlong ptr) 
{ + auto recognizer = reinterpret_cast(ptr); + std::unique_ptr s = recognizer->CreateStream(); + + // The user is responsible to free the returned pointer. + // + // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_onnx::OfflineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_decode( + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr) { + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(streamPtr); + + recognizer->DecodeStream(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env, + jobject /*obj*/, + jlong streamPtr) { + auto stream = reinterpret_cast(streamPtr); + sherpa_onnx::OfflineRecognitionResult result = stream->GetResult(); + + // [0]: text, jstring + // [1]: tokens, array of jstring + // [2]: timestamps, array of float + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 3, env->FindClass("java/lang/Object"), nullptr); + + jstring text = env->NewStringUTF(result.text.c_str()); + env->SetObjectArrayElement(obj_arr, 0, text); + + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray( + result.tokens.size(), env->FindClass("java/lang/String"), nullptr); + + int32_t i = 0; + for (const auto &t : result.tokens) { + jstring jtext = env->NewStringUTF(t.c_str()); + env->SetObjectArrayElement(tokens_arr, i, jtext); + i += 1; + } + + env->SetObjectArrayElement(obj_arr, 1, tokens_arr); + + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size()); + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(), + result.timestamps.data()); + + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr); + + return obj_arr; +} diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc new file mode 100644 index 000000000..8fa069c05 --- /dev/null +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -0,0 +1,352 @@ +// sherpa-onnx/jni/online-recognizer.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-recognizer.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { + +static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { + OnlineRecognizerConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html + + //---------- decoding ---------- + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.decoding_method = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "maxActivePaths", "I"); + ans.max_active_paths = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + ans.hotwords_score = env->GetFloatField(config, fid); + + //---------- feat config ---------- + fid = env->GetFieldID(cls, "featConfig", + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); + jobject feat_config = env->GetObjectField(config, fid); + 
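+  // Nested Kotlin config objects (featConfig, endpointConfig, modelConfig,
+  // lmConfig) are unpacked recursively: fetch the member object, then
+  // reflect on its own class, as below. The "Lcom/k2fsa/sherpa/onnx/...;"
+  // strings are JNI type signatures: fully-qualified class names with '/'
+  // as the package separator.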
jclass feat_config_cls = env->GetObjectClass(feat_config); + + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); + + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); + + //---------- enable endpoint ---------- + fid = env->GetFieldID(cls, "enableEndpoint", "Z"); + ans.enable_endpoint = env->GetBooleanField(config, fid); + + //---------- endpoint_config ---------- + + fid = env->GetFieldID(cls, "endpointConfig", + "Lcom/k2fsa/sherpa/onnx/EndpointConfig;"); + jobject endpoint_config = env->GetObjectField(config, fid); + jclass endpoint_config_cls = env->GetObjectClass(endpoint_config); + + fid = env->GetFieldID(endpoint_config_cls, "rule1", + "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); + jobject rule1 = env->GetObjectField(endpoint_config, fid); + jclass rule_class = env->GetObjectClass(rule1); + + fid = env->GetFieldID(endpoint_config_cls, "rule2", + "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); + jobject rule2 = env->GetObjectField(endpoint_config, fid); + + fid = env->GetFieldID(endpoint_config_cls, "rule3", + "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); + jobject rule3 = env->GetObjectField(endpoint_config, fid); + + fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z"); + ans.endpoint_config.rule1.must_contain_nonsilence = + env->GetBooleanField(rule1, fid); + ans.endpoint_config.rule2.must_contain_nonsilence = + env->GetBooleanField(rule2, fid); + ans.endpoint_config.rule3.must_contain_nonsilence = + env->GetBooleanField(rule3, fid); + + fid = env->GetFieldID(rule_class, "minTrailingSilence", "F"); + ans.endpoint_config.rule1.min_trailing_silence = + env->GetFloatField(rule1, fid); + ans.endpoint_config.rule2.min_trailing_silence = + env->GetFloatField(rule2, fid); + ans.endpoint_config.rule3.min_trailing_silence = + env->GetFloatField(rule3, fid); + + fid = env->GetFieldID(rule_class, "minUtteranceLength", "F"); + ans.endpoint_config.rule1.min_utterance_length = + env->GetFloatField(rule1, fid); + ans.endpoint_config.rule2.min_utterance_length = + env->GetFloatField(rule2, fid); + ans.endpoint_config.rule3.min_utterance_length = + env->GetFloatField(rule3, fid); + + //---------- model config ---------- + fid = env->GetFieldID(cls, "modelConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.joiner = p; + 
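+  // Note: ans.model_config.transducer.joiner is a std::string, so the
+  // assignment above copies the UTF-8 bytes out of the JNI buffer, which
+  // is why it is safe to release the buffer immediately afterwards.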
env->ReleaseStringUTFChars(s, p); + + // paraformer + fid = env->GetFieldID(model_config_cls, "paraformer", + "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); + jobject paraformer_config = env->GetObjectField(model_config, fid); + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); + + fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.decoder = p; + env->ReleaseStringUTFChars(s, p); + + // streaming zipformer2 CTC + fid = + env->GetFieldID(model_config_cls, "zipformer2Ctc", + "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;"); + jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid); + jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config); + + fid = + env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.zipformer2_ctc.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model_config.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model_config.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + + //---------- rnn lm model config ---------- + fid = env->GetFieldID(cls, "lmConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); + jobject lm_model_config = env->GetObjectField(config, fid); + jclass lm_model_config_cls = env->GetObjectClass(lm_model_config); + + fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(lm_model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.lm_config.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); + ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); + + return ans; +} +} // namespace sherpa_onnx + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env, + jobject /*obj*/, + jobject asset_manager, + jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + } +#endif + auto config = sherpa_onnx::GetConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", 
config.ToString().c_str());
+
+  auto recognizer = new sherpa_onnx::OnlineRecognizer(
+#if __ANDROID_API__ >= 9
+      mgr,
+#endif
+      config);
+
+  return (jlong)recognizer;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromFile(
+    JNIEnv *env, jobject /*obj*/, jobject _config) {
+  auto config = sherpa_onnx::GetConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
+
+  if (!config.Validate()) {
+    SHERPA_ONNX_LOGE("Errors found in config!");
+    return 0;
+  }
+
+  auto recognizer = new sherpa_onnx::OnlineRecognizer(config);
+
+  return (jlong)recognizer;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_delete(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reset(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
+  auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+  recognizer->Reset(stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
+  auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+
+  return recognizer->IsReady(stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
+  auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+
+  return recognizer->IsEndpoint(stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decode(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
+  auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
+
+  recognizer->DecodeStream(stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,
+                                                         jobject /*obj*/,
+                                                         jlong ptr,
+                                                         jstring hotwords) {
+  auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+
+  const char *p = env->GetStringUTFChars(hotwords, nullptr);
+  std::unique_ptr<sherpa_onnx::OnlineStream> stream;
+
+  if (strlen(p) == 0) {
+    stream = recognizer->CreateStream();
+  } else {
+    stream = recognizer->CreateStream(p);
+  }
+
+  env->ReleaseStringUTFChars(hotwords, p);
+
+  // The user is responsible to free the returned pointer.
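+  //
+  // stream.release() below gives up ownership without destroying the
+  // object: the raw pointer stays valid after the unique_ptr goes out of
+  // scope, and it must eventually be handed back to the stream's JNI
+  // delete function, otherwise the stream leaks.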
+ // + // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_onnx::OnlineStream *ans = stream.release(); + return (jlong)ans; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + sherpa_onnx::OnlineRecognizerResult result = recognizer->GetResult(stream); + + // [0]: text, jstring + // [1]: tokens, array of jstring + // [2]: timestamps, array of float + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 3, env->FindClass("java/lang/Object"), nullptr); + + jstring text = env->NewStringUTF(result.text.c_str()); + env->SetObjectArrayElement(obj_arr, 0, text); + + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray( + result.tokens.size(), env->FindClass("java/lang/String"), nullptr); + + int32_t i = 0; + for (const auto &t : result.tokens) { + jstring jtext = env->NewStringUTF(t.c_str()); + env->SetObjectArrayElement(tokens_arr, i, jtext); + i += 1; + } + + env->SetObjectArrayElement(obj_arr, 1, tokens_arr); + + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size()); + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(), + result.timestamps.data()); + + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr); + + return obj_arr; +} diff --git a/sherpa-onnx/jni/online-stream.cc b/sherpa-onnx/jni/online-stream.cc new file mode 100644 index 000000000..2ff3a0fe2 --- /dev/null +++ b/sherpa-onnx/jni/online-stream.cc @@ -0,0 +1,32 @@ +// sherpa-onnx/jni/online-stream.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-stream.h" + +#include "sherpa-onnx/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jint sample_rate) { + auto stream = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + stream->AcceptWaveform(sample_rate, p, n); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto stream = reinterpret_cast(ptr); + stream->InputFinished(); +} diff --git a/sherpa-onnx/jni/speaker-embedding-extractor.cc b/sherpa-onnx/jni/speaker-embedding-extractor.cc new file mode 100644 index 000000000..49598e77d --- /dev/null +++ b/sherpa-onnx/jni/speaker-embedding-extractor.cc @@ -0,0 +1,137 @@ +// sherpa-onnx/jni/speaker-embedding-extractor.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +#include "sherpa-onnx/jni/common.h" + +namespace sherpa_onnx { + +static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig( + JNIEnv *env, jobject config) { + SpeakerEmbeddingExtractorConfig ans; + + jclass cls = env->GetObjectClass(config); + + jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + + ans.model = p; + env->ReleaseStringUTFChars(s, 
p);
+
+  fid = env->GetFieldID(cls, "numThreads", "I");
+  ans.num_threads = env->GetIntField(config, fid);
+
+  fid = env->GetFieldID(cls, "debug", "Z");
+  ans.debug = env->GetBooleanField(config, fid);
+
+  fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
+  s = (jstring)env->GetObjectField(config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
+  ans.provider = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  return ans;
+}
+
+}  // namespace sherpa_onnx
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
+    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
+#if __ANDROID_API__ >= 9
+  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
+  if (!mgr) {
+    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
+  }
+#endif
+  auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
+  SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());
+
+  auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(
+#if __ANDROID_API__ >= 9
+      mgr,
+#endif
+      config);
+
+  return (jlong)extractor;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
+    JNIEnv *env, jobject /*obj*/, jobject _config) {
+  auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
+  SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());
+
+  if (!config.Validate()) {
+    SHERPA_ONNX_LOGE("Errors found in config!");
+  }
+
+  auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(config);
+
+  return (jlong)extractor;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
+                                                            jobject /*obj*/,
+                                                            jlong ptr) {
+  delete reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  std::unique_ptr<sherpa_onnx::OnlineStream> s =
+      reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)
+          ->CreateStream();
+
+  // The user is responsible to free the returned pointer.
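+  //
+  // The jlong handed back to the JVM is just this pointer value widened to
+  // 64 bits; every later call recovers the object with a cast. A minimal
+  // sketch of the round trip, using the names from this file:
+  //
+  //   jlong handle = (jlong)s.release();  // C++ pointer -> JVM handle
+  //   auto *os = reinterpret_cast<sherpa_onnx::OnlineStream *>(handle);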
+ // + // See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from + // ./online-stream.cc + sherpa_onnx::OnlineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto extractor = + reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + return extractor->IsReady(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jfloatArray JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto extractor = + reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + std::vector embedding = extractor->Compute(stream); + jfloatArray embedding_arr = env->NewFloatArray(embedding.size()); + env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(), + embedding.data()); + return embedding_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto extractor = + reinterpret_cast(ptr); + return extractor->Dim(); +} diff --git a/sherpa-onnx/jni/speaker-embedding-manager.cc b/sherpa-onnx/jni/speaker-embedding-manager.cc new file mode 100644 index 000000000..10ac285d1 --- /dev/null +++ b/sherpa-onnx/jni/speaker-embedding-manager.cc @@ -0,0 +1,207 @@ +// sherpa-onnx/jni/speaker-embedding-manager.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_create(JNIEnv *env, + jobject /*obj*/, + jint dim) { + auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env, + jobject /*obj*/, + jlong ptr) { + auto manager = reinterpret_cast(ptr); + delete manager; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env, + jobject /*obj*/, + jlong ptr, jstring name, + jfloatArray embedding) { + auto manager = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(embedding, nullptr); + jsize n = env->GetArrayLength(embedding); + + if (n != manager->Dim()) { + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(), + static_cast(n)); + exit(-1); + } + + const char *p_name = env->GetStringUTFChars(name, nullptr); + + jboolean ok = manager->Add(p_name, p); + env->ReleaseStringUTFChars(name, p_name); + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT); + + return ok; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList( + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name, + jobjectArray embedding_arr) { + auto manager = reinterpret_cast(ptr); + + int num_embeddings = env->GetArrayLength(embedding_arr); + if (num_embeddings == 0) { + return false; + } + + std::vector> embedding_list; + embedding_list.reserve(num_embeddings); + for (int32_t i = 0; i != num_embeddings; ++i) { + jfloatArray embedding = + (jfloatArray)env->GetObjectArrayElement(embedding_arr, i); + + jfloat *p = env->GetFloatArrayElements(embedding, nullptr); + jsize n = env->GetArrayLength(embedding); + + if (n != manager->Dim()) { + SHERPA_ONNX_LOGE("i: %d. 
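Again for orientation (not in the diff): how the extractor bindings above are driven from the Kotlin side. A minimal sketch, assuming `samples` holds 16 kHz mono audio and the eres2net model file named later in this PR is shipped in the APK assets:

```kotlin
import com.k2fsa.sherpa.onnx.*

fun embed(assets: android.content.res.AssetManager, samples: FloatArray): FloatArray {
    val extractor = SpeakerEmbeddingExtractor(
        assetManager = assets,
        config = SpeakerEmbeddingExtractorConfig(
            model = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
            numThreads = 2,
        ),
    )
    val stream = extractor.createStream()  // JNI: SpeakerEmbeddingExtractor_createStream
    stream.acceptWaveform(samples, sampleRate = 16000)
    stream.inputFinished()                 // no more audio; lets isReady() become true
    check(extractor.isReady(stream))
    val embedding = extractor.compute(stream)  // embedding.size == extractor.dim()
    stream.release()
    extractor.release()
    return embedding
}
```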
diff --git a/sherpa-onnx/jni/speaker-embedding-manager.cc b/sherpa-onnx/jni/speaker-embedding-manager.cc
new file mode 100644
index 000000000..10ac285d1
--- /dev/null
+++ b/sherpa-onnx/jni/speaker-embedding-manager.cc
@@ -0,0 +1,207 @@
+// sherpa-onnx/jni/speaker-embedding-manager.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/jni/common.h"
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_create(JNIEnv *env,
+                                                          jobject /*obj*/,
+                                                          jint dim) {
+  auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
+  return (jlong)p;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
+                                                          jobject /*obj*/,
+                                                          jlong ptr) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+  delete manager;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
+                                                       jobject /*obj*/,
+                                                       jlong ptr, jstring name,
+                                                       jfloatArray embedding) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+
+  jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
+  jsize n = env->GetArrayLength(embedding);
+
+  if (n != manager->Dim()) {
+    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
+                     static_cast<int32_t>(n));
+    exit(-1);
+  }
+
+  const char *p_name = env->GetStringUTFChars(name, nullptr);
+
+  jboolean ok = manager->Add(p_name, p);
+  env->ReleaseStringUTFChars(name, p_name);
+  env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
+
+  return ok;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
+    jobjectArray embedding_arr) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+
+  int num_embeddings = env->GetArrayLength(embedding_arr);
+  if (num_embeddings == 0) {
+    return false;
+  }
+
+  std::vector<std::vector<float>> embedding_list;
+  embedding_list.reserve(num_embeddings);
+  for (int32_t i = 0; i != num_embeddings; ++i) {
+    jfloatArray embedding =
+        (jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
+
+    jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
+    jsize n = env->GetArrayLength(embedding);
+
+    if (n != manager->Dim()) {
+      SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
+                       static_cast<int32_t>(n));
+      exit(-1);
+    }
+
+    embedding_list.push_back({p, p + n});
+    env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
+  }
+
+  const char *p_name = env->GetStringUTFChars(name, nullptr);
+
+  jboolean ok = manager->Add(p_name, embedding_list);
+
+  env->ReleaseStringUTFChars(name, p_name);
+
+  return ok;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
+                                                          jobject /*obj*/,
+                                                          jlong ptr,
+                                                          jstring name) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+
+  const char *p_name = env->GetStringUTFChars(name, nullptr);
+
+  jboolean ok = manager->Remove(p_name);
+
+  env->ReleaseStringUTFChars(name, p_name);
+
+  return ok;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jstring JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
+                                                          jobject /*obj*/,
+                                                          jlong ptr,
+                                                          jfloatArray embedding,
+                                                          jfloat threshold) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+
+  jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
+  jsize n = env->GetArrayLength(embedding);
+
+  if (n != manager->Dim()) {
+    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
+                     static_cast<int32_t>(n));
+    exit(-1);
+  }
+
+  std::string name = manager->Search(p, threshold);
+
+  env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
+
+  return env->NewStringUTF(name.c_str());
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
+    jfloatArray embedding, jfloat threshold) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+
+  jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
+  jsize n = env->GetArrayLength(embedding);
+
+  if (n != manager->Dim()) {
+    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
+                     static_cast<int32_t>(n));
+    exit(-1);
+  }
+
+  const char *p_name = env->GetStringUTFChars(name, nullptr);
+
+  jboolean ok = manager->Verify(p_name, p, threshold);
+
+  env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
+
+  env->ReleaseStringUTFChars(name, p_name);
+
+  return ok;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
+                                                            jobject /*obj*/,
+                                                            jlong ptr,
+                                                            jstring name) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+
+  const char *p_name = env->GetStringUTFChars(name, nullptr);
+
+  jboolean ok = manager->Contains(p_name);
+
+  env->ReleaseStringUTFChars(name, p_name);
+
+  return ok;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jint JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
+                                                               jobject /*obj*/,
+                                                               jlong ptr) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+  return manager->NumSpeakers();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
+  std::vector<std::string> all_speakers = manager->GetAllSpeakers();
+
+  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
+      all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
+
+  int32_t i = 0;
+  for (auto &s : all_speakers) {
+    jstring js = env->NewStringUTF(s.c_str());
+    env->SetObjectArrayElement(obj_arr, i, js);
+
+    ++i;
+  }
+
+  return obj_arr;
+}
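A sketch (not in the diff) of how the manager pairs with the extractor for enrollment and lookup. The 0.6f threshold is an illustrative value, and an empty string from search() is assumed to mean no speaker matched:

```kotlin
import com.k2fsa.sherpa.onnx.SpeakerEmbeddingManager

fun demo(dim: Int, alice: FloatArray, query: FloatArray) {
    val manager = SpeakerEmbeddingManager(dim)
    manager.add(name = "alice", embedding = alice)         // JNI: SpeakerEmbeddingManager_add
    val hit = manager.search(query, threshold = 0.6f)      // "" assumed when nothing matches
    val same = manager.verify("alice", query, threshold = 0.6f)
    println("search -> '$hit', verify -> $same, speakers: ${manager.numSpeakers()}")
    manager.release()
}
```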
diff --git a/sherpa-onnx/jni/voice-activity-detector.cc b/sherpa-onnx/jni/voice-activity-detector.cc
new file mode 100644
index 000000000..bfa31204b
--- /dev/null
+++ b/sherpa-onnx/jni/voice-activity-detector.cc
@@ -0,0 +1,175 @@
+// sherpa-onnx/jni/voice-activity-detector.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#include "sherpa-onnx/csrc/voice-activity-detector.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/jni/common.h"
+
+namespace sherpa_onnx {
+
+static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
+  VadModelConfig ans;
+
+  jclass cls = env->GetObjectClass(config);
+  jfieldID fid;
+
+  // silero_vad
+  fid = env->GetFieldID(cls, "sileroVadModelConfig",
+                        "Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");
+  jobject silero_vad_config = env->GetObjectField(config, fid);
+  jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config);
+
+  fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;");
+  auto s = (jstring)env->GetObjectField(silero_vad_config, fid);
+  auto p = env->GetStringUTFChars(s, nullptr);
+  ans.silero_vad.model = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F");
+  ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid);
+
+  fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F");
+  ans.silero_vad.min_silence_duration =
+      env->GetFloatField(silero_vad_config, fid);
+
+  fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F");
+  ans.silero_vad.min_speech_duration =
+      env->GetFloatField(silero_vad_config, fid);
+
+  fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I");
+  ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid);
+
+  fid = env->GetFieldID(cls, "sampleRate", "I");
+  ans.sample_rate = env->GetIntField(config, fid);
+
+  fid = env->GetFieldID(cls, "numThreads", "I");
+  ans.num_threads = env->GetIntField(config, fid);
+
+  fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
+  s = (jstring)env->GetObjectField(config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
+  ans.provider = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(cls, "debug", "Z");
+  ans.debug = env->GetBooleanField(config, fid);
+
+  return ans;
+}
+
+}  // namespace sherpa_onnx
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
+    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
+#if __ANDROID_API__ >= 9
+  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
+  if (!mgr) {
+    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
+  }
+#endif
+  auto config = sherpa_onnx::GetVadModelConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
+  auto model = new sherpa_onnx::VoiceActivityDetector(
+#if __ANDROID_API__ >= 9
+      mgr,
+#endif
+      config);
+
+  return (jlong)model;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
+    JNIEnv *env, jobject /*obj*/, jobject _config) {
+  auto config = sherpa_onnx::GetVadModelConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
+
+  if (!config.Validate()) {
+    SHERPA_ONNX_LOGE("Errors found in config!");
+    return 0;
+  }
+
+  auto model = new sherpa_onnx::VoiceActivityDetector(config);
+
+  return (jlong)model;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
+                                                             jobject /*obj*/,
+                                                             jlong ptr) {
+  delete reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
+  auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+
+  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
+  jsize n = env->GetArrayLength(samples);
+
+  model->AcceptWaveform(p, n);
+
+  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
+                                                                jobject /*obj*/,
+                                                                jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+  return model->Empty();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
+                                                          jobject /*obj*/,
+                                                          jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+  model->Pop();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
+                                                            jobject /*obj*/,
+                                                            jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+  model->Clear();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  const auto &front =
+      reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Front();
+
+  jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
+  env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
+                           front.samples.data());
+
+  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
+      2, env->FindClass("java/lang/Object"), nullptr);
+
+  env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));
+  env->SetObjectArrayElement(obj_arr, 1, samples_arr);
+
+  return obj_arr;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+  return model->IsSpeechDetected();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,
+                                                            jobject /*obj*/,
+                                                            jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+  model->Reset();
+}
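A sketch (not in the diff) of driving these VAD bindings through the Kotlin Vad class defined later in this PR. It assumes silero_vad.onnx is in the APK assets and that `chunks` yields 512-sample FloatArray windows of 16 kHz audio, matching the default windowSize:

```kotlin
import com.k2fsa.sherpa.onnx.*

fun segments(assets: android.content.res.AssetManager, chunks: Sequence<FloatArray>) {
    val vad = Vad(assetManager = assets, config = getVadModelConfig(type = 0)!!)
    for (chunk in chunks) {
        vad.acceptWaveform(chunk)          // JNI: Vad_acceptWaveform
        while (!vad.empty()) {
            val seg = vad.front()          // [start: Int, samples: FloatArray]
            val start = seg[0] as Int
            val samples = seg[1] as FloatArray
            println("speech at sample $start, length ${samples.size}")
            vad.pop()                      // JNI: Vad_pop
        }
    }
}
```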
diff --git a/sherpa-onnx/kotlin-api/AudioTagging.kt b/sherpa-onnx/kotlin-api/AudioTagging.kt
new file mode 100644
index 000000000..9b241ee56
--- /dev/null
+++ b/sherpa-onnx/kotlin-api/AudioTagging.kt
@@ -0,0 +1,186 @@
+package com.k2fsa.sherpa.onnx
+
+import android.content.res.AssetManager
+
+data class OfflineZipformerAudioTaggingModelConfig(
+    var model: String = "",
+)
+
+data class AudioTaggingModelConfig(
+    var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
+    var ced: String = "",
+    var numThreads: Int = 1,
+    var debug: Boolean = false,
+    var provider: String = "cpu",
+)
+
+data class AudioTaggingConfig(
+    var model: AudioTaggingModelConfig,
+    var labels: String,
+    var topK: Int = 5,
+)
+
+data class AudioEvent(
+    val name: String,
+    val index: Int,
+    val prob: Float,
+)
+
+class AudioTagging(
+    assetManager: AssetManager? = null,
+    config: AudioTaggingConfig,
+) {
+    private var ptr: Long
+
+    init {
+        ptr = if (assetManager != null) {
+            newFromAsset(assetManager, config)
+        } else {
+            newFromFile(config)
+        }
+    }
+
+    protected fun finalize() {
+        if (ptr != 0L) {
+            delete(ptr)
+            ptr = 0
+        }
+    }
+
+    fun release() = finalize()
+
+    fun createStream(): OfflineStream {
+        val p = createStream(ptr)
+        return OfflineStream(p)
+    }
+
+    @Suppress("UNCHECKED_CAST")
+    fun compute(stream: OfflineStream, topK: Int = -1): ArrayList<AudioEvent> {
+        val events: Array<Any> = compute(ptr, stream.ptr, topK)
+        val ans = ArrayList<AudioEvent>()
+
+        for (e in events) {
+            val p: Array<Any> = e as Array<Any>
+            ans.add(
+                AudioEvent(
+                    name = p[0] as String,
+                    index = p[1] as Int,
+                    prob = p[2] as Float,
+                )
+            )
+        }
+
+        return ans
+    }
+
+    private external fun newFromAsset(
+        assetManager: AssetManager,
+        config: AudioTaggingConfig,
+    ): Long
+
+    private external fun newFromFile(
+        config: AudioTaggingConfig,
+    ): Long
+
+    private external fun delete(ptr: Long)
+
+    private external fun createStream(ptr: Long): Long
+
+    private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
+
+    companion object {
+        init {
+            System.loadLibrary("sherpa-onnx-jni")
+        }
+    }
+}
+
+// please refer to
+// https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
+// to download more models
+//
+// See also
+// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
+fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
+    when (type) {
+        0 -> {
+            val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        1 -> {
+            val modelDir = "sherpa-onnx-zipformer-audio-tagging-2024-04-09"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        2 -> {
+            val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        3 -> {
+            val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        4 -> {
+            val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+
+        5 -> {
+            val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
+            return AudioTaggingConfig(
+                model = AudioTaggingModelConfig(
+                    ced = "$modelDir/model.int8.onnx",
+                    numThreads = numThreads,
+                    debug = true,
+                ),
+                labels = "$modelDir/class_labels_indices.csv",
+                topK = 3,
+            )
+        }
+    }
+
+    return null
+}
diff --git a/sherpa-onnx/kotlin-api/FeatureConfig.kt b/sherpa-onnx/kotlin-api/FeatureConfig.kt
new file mode 100644
index 000000000..ed55e9fcc
---
/dev/null +++ b/sherpa-onnx/kotlin-api/FeatureConfig.kt @@ -0,0 +1,10 @@ +package com.k2fsa.sherpa.onnx + +data class FeatureConfig( + var sampleRate: Int = 16000, + var featureDim: Int = 80, +) + +fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { + return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim) +} diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/sherpa-onnx/kotlin-api/KeywordSpotter.kt similarity index 66% rename from android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt rename to sherpa-onnx/kotlin-api/KeywordSpotter.kt index d40692665..803762e51 100644 --- a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/sherpa-onnx/kotlin-api/KeywordSpotter.kt @@ -3,26 +3,6 @@ package com.k2fsa.sherpa.onnx import android.content.res.AssetManager -data class OnlineTransducerModelConfig( - var encoder: String = "", - var decoder: String = "", - var joiner: String = "", -) - -data class OnlineModelConfig( - var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), - var tokens: String, - var numThreads: Int = 1, - var debug: Boolean = false, - var provider: String = "cpu", - var modelType: String = "", -) - -data class FeatureConfig( - var sampleRate: Int = 16000, - var featureDim: Int = 80, -) - data class KeywordSpotterConfig( var featConfig: FeatureConfig = FeatureConfig(), var modelConfig: OnlineModelConfig, @@ -33,17 +13,24 @@ data class KeywordSpotterConfig( var numTrailingBlanks: Int = 2, ) -class SherpaOnnxKws( +data class KeywordSpotterResult( + val keyword: String, + val tokens: Array, + val timestamps: FloatArray, + // TODO(fangjun): Add more fields +) + +class KeywordSpotter( assetManager: AssetManager? 
= null, - var config: KeywordSpotterConfig, + val config: KeywordSpotterConfig, ) { private val ptr: Long init { - if (assetManager != null) { - ptr = new(assetManager, config) + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) } else { - ptr = newFromFile(config) + newFromFile(config) } } @@ -51,20 +38,28 @@ class SherpaOnnxKws( delete(ptr) } - fun acceptWaveform(samples: FloatArray, sampleRate: Int) = - acceptWaveform(ptr, samples, sampleRate) + fun release() = finalize() - fun inputFinished() = inputFinished(ptr) - fun decode() = decode(ptr) - fun isReady(): Boolean = isReady(ptr) - fun reset(keywords: String): Boolean = reset(ptr, keywords) + fun createStream(keywords: String = ""): OnlineStream { + val p = createStream(ptr, keywords) + return OnlineStream(p) + } + + fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) + fun getResult(stream: OnlineStream): KeywordSpotterResult { + val objArray = getResult(ptr, stream.ptr) - val keyword: String - get() = getKeyword(ptr) + val keyword = objArray[0] as String + val tokens = objArray[1] as Array + val timestamps = objArray[2] as FloatArray + + return KeywordSpotterResult(keyword = keyword, tokens = tokens, timestamps = timestamps) + } private external fun delete(ptr: Long) - private external fun new( + private external fun newFromAsset( assetManager: AssetManager, config: KeywordSpotterConfig, ): Long @@ -73,12 +68,10 @@ class SherpaOnnxKws( config: KeywordSpotterConfig, ): Long - private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) - private external fun inputFinished(ptr: Long) - private external fun getKeyword(ptr: Long): String - private external fun reset(ptr: Long, keywords: String): Boolean - private external fun decode(ptr: Long) - private external fun isReady(ptr: Long): Boolean + private external fun createStream(ptr: Long, keywords: String): Long + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + private external fun decode(ptr: Long, streamPtr: Long) + private external fun getResult(ptr: Long, streamPtr: Long): Array companion object { init { @@ -87,10 +80,6 @@ class SherpaOnnxKws( } } -fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { - return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim) -} - /* Please see https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html @@ -108,7 +97,7 @@ by following the code) https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary */ -fun getModelConfig(type: Int): OnlineModelConfig? { +fun getKwsModelConfig(type: Int): OnlineModelConfig? { when (type) { 0 -> { val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01" @@ -137,15 +126,15 @@ fun getModelConfig(type: Int): OnlineModelConfig? { } } - return null; + return null } /* * Get the default keywords for each model. * Caution: The types and modelDir should be the same as those in getModelConfig * function above. 
- */ -fun getKeywordsFile(type: Int) : String { + */ +fun getKeywordsFile(type: Int): String { when (type) { 0 -> { val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01" @@ -158,5 +147,5 @@ fun getKeywordsFile(type: Int) : String { } } - return ""; + return "" } diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt new file mode 100644 index 000000000..a559e662b --- /dev/null +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -0,0 +1,221 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager + +data class OfflineRecognizerResult( + val text: String, + val tokens: Array, + val timestamps: FloatArray, +) + +data class OfflineTransducerModelConfig( + var encoder: String = "", + var decoder: String = "", + var joiner: String = "", +) + +data class OfflineParaformerModelConfig( + var model: String = "", +) + +data class OfflineWhisperModelConfig( + var encoder: String = "", + var decoder: String = "", + var language: String = "en", // Used with multilingual model + var task: String = "transcribe", // transcribe or translate + var tailPaddings: Int = 1000, // Padding added at the end of the samples +) + +data class OfflineModelConfig( + var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(), + var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(), + var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(), + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", + var modelType: String = "", + var tokens: String, +) + +data class OfflineRecognizerConfig( + var featConfig: FeatureConfig = FeatureConfig(), + var modelConfig: OfflineModelConfig, + // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it + var decodingMethod: String = "greedy_search", + var maxActivePaths: Int = 4, + var hotwordsFile: String = "", + var hotwordsScore: Float = 1.5f, +) + +class OfflineRecognizer( + assetManager: AssetManager? = null, + config: OfflineRecognizerConfig, +) { + private val ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + delete(ptr) + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + fun getResult(stream: OfflineStream): OfflineRecognizerResult { + val objArray = getResult(stream.ptr) + + val text = objArray[0] as String + val tokens = objArray[1] as Array + val timestamps = objArray[2] as FloatArray + return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps) + } + + fun decode(stream: OfflineStream) = decode(ptr, stream.ptr) + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflineRecognizerConfig, + ): Long + + private external fun newFromFile( + config: OfflineRecognizerConfig, + ): Long + + private external fun decode(ptr: Long, streamPtr: Long) + + private external fun getResult(streamPtr: Long): Array + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own. 
(It should be straightforward to add a new model +by following the code) + +@param type + +0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese + int8 + +1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english + encoder int8, decoder/joiner float32 + +2 - sherpa-onnx-whisper-tiny.en + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en + encoder int8, decoder int8 + +3 - sherpa-onnx-whisper-base.en + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en + encoder int8, decoder int8 + +4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese + encoder/joiner int8, decoder fp32 + + */ +fun getOfflineModelConfig(type: Int): OfflineModelConfig? { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28" + return OfflineModelConfig( + paraformer = OfflineParaformerModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "paraformer", + ) + } + + 1 -> { + val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx", + decoder = "$modelDir/decoder-epoch-30-avg-4.onnx", + joiner = "$modelDir/joiner-epoch-30-avg-4.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 2 -> { + val modelDir = "sherpa-onnx-whisper-tiny.en" + return OfflineModelConfig( + whisper = OfflineWhisperModelConfig( + encoder = "$modelDir/tiny.en-encoder.int8.onnx", + decoder = "$modelDir/tiny.en-decoder.int8.onnx", + ), + tokens = "$modelDir/tiny.en-tokens.txt", + modelType = "whisper", + ) + } + + 3 -> { + val modelDir = "sherpa-onnx-whisper-base.en" + return OfflineModelConfig( + whisper = OfflineWhisperModelConfig( + encoder = "$modelDir/base.en-encoder.int8.onnx", + decoder = "$modelDir/base.en-decoder.int8.onnx", + ), + tokens = "$modelDir/base.en-tokens.txt", + modelType = "whisper", + ) + } + + + 4 -> { + val modelDir = "icefall-asr-zipformer-wenetspeech-20230615" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx", + decoder = "$modelDir/decoder-epoch-12-avg-4.onnx", + joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 5 -> { + val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-20-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer2", + ) + } + + } + return null +} diff --git a/sherpa-onnx/kotlin-api/OfflineStream.kt b/sherpa-onnx/kotlin-api/OfflineStream.kt new file mode 100644 index 000000000..49652e72d --- /dev/null +++ 
b/sherpa-onnx/kotlin-api/OfflineStream.kt @@ -0,0 +1,24 @@ +package com.k2fsa.sherpa.onnx + +class OfflineStream(var ptr: Long) { + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + private external fun delete(ptr: Long) + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt similarity index 78% rename from android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt rename to sherpa-onnx/kotlin-api/OnlineRecognizer.kt index dfd8a4d80..cd2629c97 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -1,4 +1,3 @@ -// Copyright (c) 2023 Xiaomi Corporation package com.k2fsa.sherpa.onnx import android.content.res.AssetManager @@ -46,15 +45,11 @@ data class OnlineLMConfig( var scale: Float = 0.5f, ) -data class FeatureConfig( - var sampleRate: Int = 16000, - var featureDim: Int = 80, -) data class OnlineRecognizerConfig( var featConfig: FeatureConfig = FeatureConfig(), var modelConfig: OnlineModelConfig, - var lmConfig: OnlineLMConfig, + var lmConfig: OnlineLMConfig = OnlineLMConfig(), var endpointConfig: EndpointConfig = EndpointConfig(), var enableEndpoint: Boolean = true, var decodingMethod: String = "greedy_search", @@ -63,17 +58,24 @@ data class OnlineRecognizerConfig( var hotwordsScore: Float = 1.5f, ) -class SherpaOnnx( +data class OnlineRecognizerResult( + val text: String, + val tokens: Array, + val timestamps: FloatArray, + // TODO(fangjun): Add more fields +) + +class OnlineRecognizer( assetManager: AssetManager? 
= null, - var config: OnlineRecognizerConfig, + val config: OnlineRecognizerConfig, ) { private val ptr: Long init { - if (assetManager != null) { - ptr = new(assetManager, config) + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) } else { - ptr = newFromFile(config) + newFromFile(config) } } @@ -81,24 +83,30 @@ class SherpaOnnx( delete(ptr) } - fun acceptWaveform(samples: FloatArray, sampleRate: Int) = - acceptWaveform(ptr, samples, sampleRate) + fun release() = finalize() + + fun createStream(hotwords: String = ""): OnlineStream { + val p = createStream(ptr, hotwords) + return OnlineStream(p) + } - fun inputFinished() = inputFinished(ptr) - fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords) - fun decode() = decode(ptr) - fun isEndpoint(): Boolean = isEndpoint(ptr) - fun isReady(): Boolean = isReady(ptr) + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr) + fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) + fun isEndpoint(stream: OnlineStream) = isEndpoint(ptr, stream.ptr) + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) + fun getResult(stream: OnlineStream): OnlineRecognizerResult { + val objArray = getResult(ptr, stream.ptr) - val text: String - get() = getText(ptr) + val text = objArray[0] as String + val tokens = objArray[1] as Array + val timestamps = objArray[2] as FloatArray - val tokens: Array - get() = getTokens(ptr) + return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps) + } private external fun delete(ptr: Long) - private external fun new( + private external fun newFromAsset( assetManager: AssetManager, config: OnlineRecognizerConfig, ): Long @@ -107,14 +115,12 @@ class SherpaOnnx( config: OnlineRecognizerConfig, ): Long - private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) - private external fun inputFinished(ptr: Long) - private external fun getText(ptr: Long): String - private external fun reset(ptr: Long, recreate: Boolean, hotwords: String) - private external fun decode(ptr: Long) - private external fun isEndpoint(ptr: Long): Boolean - private external fun isReady(ptr: Long): Boolean - private external fun getTokens(ptr: Long): Array + private external fun createStream(ptr: Long, hotwords: String): Long + private external fun reset(ptr: Long, streamPtr: Long) + private external fun decode(ptr: Long, streamPtr: Long) + private external fun isEndpoint(ptr: Long, streamPtr: Long): Boolean + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + private external fun getResult(ptr: Long, streamPtr: Long): Array companion object { init { @@ -123,9 +129,6 @@ class SherpaOnnx( } } -fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { - return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim) -} /* Please see @@ -277,14 +280,40 @@ fun getModelConfig(type: Int): OnlineModelConfig? 
{ transducer = OnlineTransducerModelConfig( encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", - joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 9 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 10 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", ), tokens = "$modelDir/tokens.txt", modelType = "zipformer", ) } } - return null; + return null } /* @@ -310,7 +339,7 @@ fun getOnlineLMConfig(type: Int): OnlineLMConfig { ) } } - return OnlineLMConfig(); + return OnlineLMConfig() } fun getEndpointConfig(): EndpointConfig { @@ -320,3 +349,4 @@ fun getEndpointConfig(): EndpointConfig { rule3 = EndpointRule(false, 0.0f, 20.0f) ) } + diff --git a/sherpa-onnx/kotlin-api/OnlineStream.kt b/sherpa-onnx/kotlin-api/OnlineStream.kt new file mode 100644 index 000000000..6057fabd0 --- /dev/null +++ b/sherpa-onnx/kotlin-api/OnlineStream.kt @@ -0,0 +1,27 @@ +package com.k2fsa.sherpa.onnx + +class OnlineStream(var ptr: Long = 0) { + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) + + fun inputFinished() = inputFinished(ptr) + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + private external fun inputFinished(ptr: Long) + private external fun delete(ptr: Long) + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} diff --git a/sherpa-onnx/kotlin-api/Speaker.kt b/sherpa-onnx/kotlin-api/Speaker.kt new file mode 100644 index 000000000..93b1b9e4e --- /dev/null +++ b/sherpa-onnx/kotlin-api/Speaker.kt @@ -0,0 +1,164 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager +import android.util.Log + +data class SpeakerEmbeddingExtractorConfig( + val model: String, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +class SpeakerEmbeddingExtractor( + assetManager: AssetManager? 
= null, + config: SpeakerEmbeddingExtractorConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OnlineStream { + val p = createStream(ptr) + return OnlineStream(p) + } + + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) + fun compute(stream: OnlineStream) = compute(ptr, stream.ptr) + fun dim() = dim(ptr) + + private external fun newFromAsset( + assetManager: AssetManager, + config: SpeakerEmbeddingExtractorConfig, + ): Long + + private external fun newFromFile( + config: SpeakerEmbeddingExtractorConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + + private external fun compute(ptr: Long, streamPtr: Long): FloatArray + + private external fun dim(ptr: Long): Int + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +class SpeakerEmbeddingManager(val dim: Int) { + private var ptr: Long + + init { + ptr = create(dim) + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + fun add(name: String, embedding: FloatArray) = add(ptr, name, embedding) + fun add(name: String, embedding: Array) = addList(ptr, name, embedding) + fun remove(name: String) = remove(ptr, name) + fun search(embedding: FloatArray, threshold: Float) = search(ptr, embedding, threshold) + fun verify(name: String, embedding: FloatArray, threshold: Float) = + verify(ptr, name, embedding, threshold) + + fun contains(name: String) = contains(ptr, name) + fun numSpeakers() = numSpeakers(ptr) + + fun allSpeakerNames() = allSpeakerNames(ptr) + + private external fun create(dim: Int): Long + private external fun delete(ptr: Long): Unit + private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean + private external fun addList(ptr: Long, name: String, embedding: Array): Boolean + private external fun remove(ptr: Long, name: String): Boolean + private external fun search(ptr: Long, embedding: FloatArray, threshold: Float): String + private external fun verify( + ptr: Long, + name: String, + embedding: FloatArray, + threshold: Float + ): Boolean + + private external fun contains(ptr: Long, name: String): Boolean + private external fun numSpeakers(ptr: Long): Int + + private external fun allSpeakerNames(ptr: Long): Array + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +// Please download the model file from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +// and put it inside the assets directory. +// +// Please don't put it in a subdirectory of assets +private val modelName = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + +object SpeakerRecognition { + var _extractor: SpeakerEmbeddingExtractor? = null + var _manager: SpeakerEmbeddingManager? = null + + val extractor: SpeakerEmbeddingExtractor + get() { + return _extractor!! + } + + val manager: SpeakerEmbeddingManager + get() { + return _manager!! + } + + fun initExtractor(assetManager: AssetManager? 
= null) { + synchronized(this) { + if (_extractor != null) { + return + } + Log.i("sherpa-onnx", "Initializing speaker embedding extractor") + + _extractor = SpeakerEmbeddingExtractor( + assetManager = assetManager, + config = SpeakerEmbeddingExtractorConfig( + model = modelName, + numThreads = 2, + debug = false, + provider = "cpu", + ) + ) + + _manager = SpeakerEmbeddingManager(dim = _extractor!!.dim()) + } + } +} diff --git a/sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt b/sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt new file mode 100644 index 000000000..00caca281 --- /dev/null +++ b/sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt @@ -0,0 +1,103 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager + +data class SpokenLanguageIdentificationWhisperConfig( + var encoder: String, + var decoder: String, + var tailPaddings: Int = -1, +) + +data class SpokenLanguageIdentificationConfig( + var whisper: SpokenLanguageIdentificationWhisperConfig, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +class SpokenLanguageIdentification( + assetManager: AssetManager? = null, + config: SpokenLanguageIdentificationConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + fun compute(stream: OfflineStream) = compute(ptr, stream.ptr) + + private external fun newFromAsset( + assetManager: AssetManager, + config: SpokenLanguageIdentificationConfig, + ): Long + + private external fun newFromFile( + config: SpokenLanguageIdentificationConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun compute(ptr: Long, streamPtr: Long): String + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +// please refer to +// https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper +// to download more models +fun getSpokenLanguageIdentificationConfig( + type: Int, + numThreads: Int = 1 +): SpokenLanguageIdentificationConfig? 
{
+    when (type) {
+        0 -> {
+            val modelDir = "sherpa-onnx-whisper-tiny"
+            return SpokenLanguageIdentificationConfig(
+                whisper = SpokenLanguageIdentificationWhisperConfig(
+                    encoder = "$modelDir/tiny-encoder.int8.onnx",
+                    decoder = "$modelDir/tiny-decoder.int8.onnx",
+                ),
+                numThreads = numThreads,
+                debug = true,
+            )
+        }
+
+        1 -> {
+            val modelDir = "sherpa-onnx-whisper-base"
+            return SpokenLanguageIdentificationConfig(
+                whisper = SpokenLanguageIdentificationWhisperConfig(
+                    encoder = "$modelDir/base-encoder.int8.onnx",
+                    decoder = "$modelDir/base-decoder.int8.onnx",
+                ),
+                numThreads = numThreads,
+                debug = true,
+            )
+        }
+    }
+    return null
+}
diff --git a/sherpa-onnx/kotlin-api/Vad.kt b/sherpa-onnx/kotlin-api/Vad.kt
new file mode 100644
index 000000000..7791166c9
--- /dev/null
+++ b/sherpa-onnx/kotlin-api/Vad.kt
@@ -0,0 +1,104 @@
+// Copyright (c) 2023 Xiaomi Corporation
+package com.k2fsa.sherpa.onnx
+
+import android.content.res.AssetManager
+
+data class SileroVadModelConfig(
+    var model: String,
+    var threshold: Float = 0.5F,
+    var minSilenceDuration: Float = 0.25F,
+    var minSpeechDuration: Float = 0.25F,
+    var windowSize: Int = 512,
+)
+
+data class VadModelConfig(
+    var sileroVadModelConfig: SileroVadModelConfig,
+    var sampleRate: Int = 16000,
+    var numThreads: Int = 1,
+    var provider: String = "cpu",
+    var debug: Boolean = false,
+)
+
+class Vad(
+    assetManager: AssetManager? = null,
+    var config: VadModelConfig,
+) {
+    private val ptr: Long
+
+    init {
+        if (assetManager != null) {
+            ptr = newFromAsset(assetManager, config)
+        } else {
+            ptr = newFromFile(config)
+        }
+    }
+
+    protected fun finalize() {
+        delete(ptr)
+    }
+
+    fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)
+
+    fun empty(): Boolean = empty(ptr)
+    fun pop() = pop(ptr)
+
+    // return an array containing
+    // [start: Int, samples: FloatArray]
+    fun front() = front(ptr)
+
+    fun clear() = clear(ptr)
+
+    fun isSpeechDetected(): Boolean = isSpeechDetected(ptr)
+
+    fun reset() = reset(ptr)
+
+    private external fun delete(ptr: Long)
+
+    private external fun newFromAsset(
+        assetManager: AssetManager,
+        config: VadModelConfig,
+    ): Long
+
+    private external fun newFromFile(
+        config: VadModelConfig,
+    ): Long
+
+    private external fun acceptWaveform(ptr: Long, samples: FloatArray)
+    private external fun empty(ptr: Long): Boolean
+    private external fun pop(ptr: Long)
+    private external fun clear(ptr: Long)
+    private external fun front(ptr: Long): Array<Any>
+    private external fun isSpeechDetected(ptr: Long): Boolean
+    private external fun reset(ptr: Long)
+
+    companion object {
+        init {
+            System.loadLibrary("sherpa-onnx-jni")
+        }
+    }
+}
+
+// Please visit
+// https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
+// to download silero_vad.onnx
+// and put it inside the assets/
+// directory
+fun getVadModelConfig(type: Int): VadModelConfig? {
+    when (type) {
+        0 -> {
+            return VadModelConfig(
+                sileroVadModelConfig = SileroVadModelConfig(
+                    model = "silero_vad.onnx",
+                    threshold = 0.5F,
+                    minSilenceDuration = 0.25F,
+                    minSpeechDuration = 0.25F,
+                    windowSize = 512,
+                ),
+                sampleRate = 16000,
+                numThreads = 1,
+                provider = "cpu",
+            )
+        }
+    }
+    return null
+}
diff --git a/android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt b/sherpa-onnx/kotlin-api/WaveReader.kt
similarity index 100%
rename from android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
rename to sherpa-onnx/kotlin-api/WaveReader.kt
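To close out the Kotlin API files above, a sketch (not in the diff) of the SpokenLanguageIdentification wrapper in use. It assumes `samples` holds 16 kHz mono audio, that the whisper-tiny model files are in the APK assets, and that compute() returns a language code such as "en" or "zh":

```kotlin
import com.k2fsa.sherpa.onnx.*

fun detectLanguage(assets: android.content.res.AssetManager, samples: FloatArray): String {
    val slid = SpokenLanguageIdentification(
        assetManager = assets,
        config = getSpokenLanguageIdentificationConfig(type = 0)!!,
    )
    val stream = slid.createStream()        // an OfflineStream, freed via release()
    stream.acceptWaveform(samples, sampleRate = 16000)
    val lang = slid.compute(stream)         // assumed to return a code such as "en"
    stream.release()
    slid.release()
    return lang
}
```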