diff --git a/.idea/deploymentTargetDropDown.xml b/.idea/deploymentTargetDropDown.xml
new file mode 100644
index 0000000..cdf10b6
--- /dev/null
+++ b/.idea/deploymentTargetDropDown.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/app/build.gradle b/app/build.gradle
index 6df36b7..ab0277e 100644
--- a/app/build.gradle
+++ b/app/build.gradle
@@ -77,11 +77,15 @@ dependencies {
implementation 'com.google.android.material:material:1.7.0'
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
implementation 'org.tensorflow:tensorflow-lite:2.11.0'
+ implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.11.0'
implementation 'org.tensorflow:tensorflow-lite-api:2.11.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.11.0'
implementation 'org.tensorflow:tensorflow-lite-gpu-api:2.11.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.4.3'
implementation 'org.tensorflow:tensorflow-lite-support-api:0.4.3'
+
+ // This dependency adds the necessary TF op support.
+
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
diff --git a/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite b/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite
new file mode 100644
index 0000000..5fbd8f9
Binary files /dev/null and b/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite differ
diff --git a/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite b/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite
new file mode 100644
index 0000000..b43becb
Binary files /dev/null and b/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite differ
diff --git a/app/src/main/assets/whisper.tflite b/app/src/main/assets/nnmodel/mine/whisper.tflite
similarity index 100%
rename from app/src/main/assets/whisper.tflite
rename to app/src/main/assets/nnmodel/mine/whisper.tflite
diff --git a/app/src/main/assets/nnmodel/mine/whisper_f16.tflite b/app/src/main/assets/nnmodel/mine/whisper_f16.tflite
new file mode 100644
index 0000000..7d4a912
Binary files /dev/null and b/app/src/main/assets/nnmodel/mine/whisper_f16.tflite differ
diff --git a/app/src/main/assets/nnmodel/mobilenet_v2_quantized_1x3x224x224.tflite b/app/src/main/assets/nnmodel/mobilenet_v2_quantized_1x3x224x224.tflite
new file mode 100644
index 0000000..d5c59d1
Binary files /dev/null and b/app/src/main/assets/nnmodel/mobilenet_v2_quantized_1x3x224x224.tflite differ
diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java
index 6e96712..54c1935 100644
--- a/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java
+++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java
@@ -29,6 +29,9 @@
import com.example.WhisperVoiceKeyboard.R;
import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.flex.FlexDelegate;
+import org.tensorflow.lite.gpu.GpuDelegate;
+import org.tensorflow.lite.nnapi.NnApiDelegate;
import java.io.FileInputStream;
import java.io.IOException;
@@ -40,10 +43,13 @@
public class VoiceKeyboardInputMethodService extends InputMethodService {
- private Interpreter _whisperInterpreter;
+ private Interpreter _nnapiEncoder;
+ private Interpreter _nnapiDecoder;
private Dictionary _dictionary;
- private static final String WHISPER_TFLITE = "whisper.tflite";
+
+ private static final String WHISPER_ENCODER = "nnmodel/nyadia/whisper-encoder.tflite";
+ private static final String WHISPER_DECODER_LANGUAGE = "nnmodel/nyadia/whisper-decoder_language.tflite";
private static final boolean LOG_AND_DRAW = false;
@@ -54,25 +60,40 @@ public void onCreate() {
super.onCreate();
- try {
- Vocab vocab = ExtractVocab.extractVocab(getAssets().open("filters_vocab_gen.bin"));
- HashMap phraseMappings = new HashMap<>();
+ Interpreter.Options nnapiOptions = new Interpreter.Options();
+ NnApiDelegate nnapiDelegate = new NnApiDelegate();
+ FlexDelegate flexDelegate = new FlexDelegate();
+ GpuDelegate gpuDelegate = new GpuDelegate();
- _dictionary = new Dictionary(vocab, phraseMappings);
- MappedByteBuffer model = loadWhisperModel(getAssets());
+ nnapiOptions.addDelegate(flexDelegate);
+ nnapiOptions.addDelegate(gpuDelegate);
+ nnapiOptions.addDelegate(nnapiDelegate);
+
- Interpreter.Options options = new Interpreter.Options();
+ nnapiOptions.setNumThreads(0);
+ nnapiOptions.setUseXNNPACK(true);
+ nnapiOptions.setUseNNAPI(true);
- options.setUseXNNPACK(true);
- options.setNumThreads(8);
+ try {
+
+
+ MappedByteBuffer whisper_encoder = loadWhisperModel(getAssets(), WHISPER_ENCODER);
+ MappedByteBuffer whisper_decoder_language = loadWhisperModel(getAssets(), WHISPER_DECODER_LANGUAGE);
- _whisperInterpreter = new Interpreter(model, options);
+ _nnapiEncoder = new Interpreter(whisper_encoder, nnapiOptions);
+ _nnapiDecoder = new Interpreter(whisper_decoder_language, nnapiOptions);
+ Vocab vocab = ExtractVocab.extractVocab(getAssets().open("filters_vocab_gen.bin"));
+ HashMap phraseMappings = new HashMap<>();
+ _dictionary = new Dictionary(vocab, phraseMappings);
- } catch (IOException e) {
+ } catch (Exception e) {
e.printStackTrace();
+ System.exit(-1);
}
+
+
RustLib.init(getAssets());
}
@@ -188,16 +209,51 @@ private void sendDelete() {
private String transcribeAudio(float[] byteBuffer) {
int[] inputShape = {1, 80, 3000};
- float[][][] reshapedFloats = reshapeInput(byteBuffer, inputShape);
- int[][] output = new int[1][224];
+ Map inputsEncoder = new HashMap<>();
+ Map outputsEncoder = new HashMap<>();
+ Map inputsDecoder = new HashMap<>();
+ Map outputsDecoder = new HashMap<>();
+
+ String signatureKey = "serving_default";
+ String[] nnapiEncoderSignatureInputs = _nnapiEncoder.getSignatureInputs(signatureKey);
+ String[] nnapiEncoderSignatureOutputs = _nnapiEncoder.getSignatureOutputs(signatureKey);
+ String[] nnapiDecoderSignatureInputs = _nnapiDecoder.getSignatureInputs(signatureKey);
+ String[] nnapiDecoderSignatureOutputs = _nnapiDecoder.getSignatureOutputs(signatureKey);
+
+ String encoderInputKey0 = nnapiEncoderSignatureInputs[0];
+ String encoderOutputKey0 = nnapiEncoderSignatureOutputs[0];
+ String decoderInputKey0 = nnapiDecoderSignatureInputs[0];
+ String decoderInputKey1 = nnapiDecoderSignatureInputs[1];
+ String decoderOutputKey0 = nnapiDecoderSignatureOutputs[0];
+
+ inputsEncoder.put(encoderInputKey0, reshapeInput(byteBuffer, inputShape));
+ float[][][] encoder_output = new float[1][1500][384];
+ outputsEncoder.put(encoderOutputKey0, encoder_output);
+
+
+ _nnapiEncoder.runSignature(inputsEncoder, outputsEncoder, signatureKey);
+
+
+ long[][][] encoder_output_int = new long[1][1500][384];
+
+ inputsDecoder.put(decoderInputKey0, encoder_output_int);
+
+ float[][] decoder_ids = new float[1][384];
+ decoder_ids[0][0] = 50258;
+ decoder_ids[0][1] = 50266;
+ decoder_ids[0][2] = 50358;
+ decoder_ids[0][3] = 50363;
+ inputsDecoder.put(decoderInputKey1, decoder_ids);
- Map inputs = new HashMap<>();
- inputs.put("input_features", reshapedFloats);
- Map outputs = new HashMap<>();
- outputs.put("sequences", output);
+ int[] shape = new int[2];
+ shape[0] = 1;
+ shape[1] = 4;
+ _nnapiDecoder.resizeInput(1, shape);
+ float[][] output = new float[1][224];
+ outputsDecoder.put(decoderOutputKey0, output);
- _whisperInterpreter.runSignature(inputs, outputs, "serving_default");
- String whisperOutput = _dictionary.tokensToString(output);
+ _nnapiDecoder.runSignature(inputsDecoder, outputsDecoder, signatureKey);
+ String whisperOutput = _dictionary.tokensToString(new int[1][224]);
return _dictionary.injectTokens(whisperOutput);
}
@@ -217,9 +273,9 @@ private float[][][] reshapeInput(float[] byteBuffer, int[] inputShape) {
return reshapedFloats;
}
- private static MappedByteBuffer loadWhisperModel(AssetManager assets)
+ private static MappedByteBuffer loadWhisperModel(AssetManager assets, String modelName)
throws IOException {
- AssetFileDescriptor fileDescriptor = assets.openFd(WHISPER_TFLITE);
+ AssetFileDescriptor fileDescriptor = assets.openFd(modelName);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
diff --git a/app/src/main/jniLibs/arm64-v8a/librust.so b/app/src/main/jniLibs/arm64-v8a/librust.so
index ef0d1a1..cf0f0b7 100755
Binary files a/app/src/main/jniLibs/arm64-v8a/librust.so and b/app/src/main/jniLibs/arm64-v8a/librust.so differ
diff --git a/app/src/main/jniLibs/armeabi-v7a/librust.so b/app/src/main/jniLibs/armeabi-v7a/librust.so
index f869b12..fdbcf98 100755
Binary files a/app/src/main/jniLibs/armeabi-v7a/librust.so and b/app/src/main/jniLibs/armeabi-v7a/librust.so differ
diff --git a/app/src/main/jniLibs/x86/librust.so b/app/src/main/jniLibs/x86/librust.so
index 0ae8e61..4eb0d32 100755
Binary files a/app/src/main/jniLibs/x86/librust.so and b/app/src/main/jniLibs/x86/librust.so differ
diff --git a/app/src/main/jniLibs/x86_64/librust.so b/app/src/main/jniLibs/x86_64/librust.so
index 906d0a0..0efa4cd 100755
Binary files a/app/src/main/jniLibs/x86_64/librust.so and b/app/src/main/jniLibs/x86_64/librust.so differ