diff --git a/.idea/deploymentTargetDropDown.xml b/.idea/deploymentTargetDropDown.xml new file mode 100644 index 0000000..cdf10b6 --- /dev/null +++ b/.idea/deploymentTargetDropDown.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/app/build.gradle b/app/build.gradle index 6df36b7..ab0277e 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -77,11 +77,15 @@ dependencies { implementation 'com.google.android.material:material:1.7.0' implementation 'androidx.constraintlayout:constraintlayout:2.1.4' implementation 'org.tensorflow:tensorflow-lite:2.11.0' + implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.11.0' implementation 'org.tensorflow:tensorflow-lite-api:2.11.0' implementation 'org.tensorflow:tensorflow-lite-gpu:2.11.0' implementation 'org.tensorflow:tensorflow-lite-gpu-api:2.11.0' implementation 'org.tensorflow:tensorflow-lite-support:0.4.3' implementation 'org.tensorflow:tensorflow-lite-support-api:0.4.3' + + // This dependency adds the necessary TF op support. + testImplementation 'junit:junit:4.13.2' androidTestImplementation 'androidx.test.ext:junit:1.1.4' androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0' diff --git a/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite b/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite new file mode 100644 index 0000000..5fbd8f9 Binary files /dev/null and b/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite differ diff --git a/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite b/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite new file mode 100644 index 0000000..b43becb Binary files /dev/null and b/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite differ diff --git a/app/src/main/assets/whisper.tflite b/app/src/main/assets/nnmodel/mine/whisper.tflite similarity index 100% rename from app/src/main/assets/whisper.tflite rename to app/src/main/assets/nnmodel/mine/whisper.tflite diff --git a/app/src/main/assets/nnmodel/mine/whisper_f16.tflite b/app/src/main/assets/nnmodel/mine/whisper_f16.tflite new file mode 100644 index 0000000..7d4a912 Binary files /dev/null and b/app/src/main/assets/nnmodel/mine/whisper_f16.tflite differ diff --git a/app/src/main/assets/nnmodel/mobilenet_v2_quantized_1x3x224x224.tflite b/app/src/main/assets/nnmodel/mobilenet_v2_quantized_1x3x224x224.tflite new file mode 100644 index 0000000..d5c59d1 Binary files /dev/null and b/app/src/main/assets/nnmodel/mobilenet_v2_quantized_1x3x224x224.tflite differ diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java index 6e96712..54c1935 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java @@ -29,6 +29,9 @@ import com.example.WhisperVoiceKeyboard.R; import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.flex.FlexDelegate; +import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.nnapi.NnApiDelegate; import java.io.FileInputStream; import java.io.IOException; @@ -40,10 +43,13 @@ public class VoiceKeyboardInputMethodService extends InputMethodService { - private Interpreter _whisperInterpreter; + private Interpreter _nnapiEncoder; + private Interpreter _nnapiDecoder; private Dictionary _dictionary; - private static final String WHISPER_TFLITE = "whisper.tflite"; + + private static final String WHISPER_ENCODER = "nnmodel/nyadia/whisper-encoder.tflite"; + private static final String WHISPER_DECODER_LANGUAGE = "nnmodel/nyadia/whisper-decoder_language.tflite"; private static final boolean LOG_AND_DRAW = false; @@ -54,25 +60,40 @@ public void onCreate() { super.onCreate(); - try { - Vocab vocab = ExtractVocab.extractVocab(getAssets().open("filters_vocab_gen.bin")); - HashMap phraseMappings = new HashMap<>(); + Interpreter.Options nnapiOptions = new Interpreter.Options(); + NnApiDelegate nnapiDelegate = new NnApiDelegate(); + FlexDelegate flexDelegate = new FlexDelegate(); + GpuDelegate gpuDelegate = new GpuDelegate(); - _dictionary = new Dictionary(vocab, phraseMappings); - MappedByteBuffer model = loadWhisperModel(getAssets()); + nnapiOptions.addDelegate(flexDelegate); + nnapiOptions.addDelegate(gpuDelegate); + nnapiOptions.addDelegate(nnapiDelegate); + - Interpreter.Options options = new Interpreter.Options(); + nnapiOptions.setNumThreads(0); + nnapiOptions.setUseXNNPACK(true); + nnapiOptions.setUseNNAPI(true); - options.setUseXNNPACK(true); - options.setNumThreads(8); + try { + + + MappedByteBuffer whisper_encoder = loadWhisperModel(getAssets(), WHISPER_ENCODER); + MappedByteBuffer whisper_decoder_language = loadWhisperModel(getAssets(), WHISPER_DECODER_LANGUAGE); - _whisperInterpreter = new Interpreter(model, options); + _nnapiEncoder = new Interpreter(whisper_encoder, nnapiOptions); + _nnapiDecoder = new Interpreter(whisper_decoder_language, nnapiOptions); + Vocab vocab = ExtractVocab.extractVocab(getAssets().open("filters_vocab_gen.bin")); + HashMap phraseMappings = new HashMap<>(); + _dictionary = new Dictionary(vocab, phraseMappings); - } catch (IOException e) { + } catch (Exception e) { e.printStackTrace(); + System.exit(-1); } + + RustLib.init(getAssets()); } @@ -188,16 +209,51 @@ private void sendDelete() { private String transcribeAudio(float[] byteBuffer) { int[] inputShape = {1, 80, 3000}; - float[][][] reshapedFloats = reshapeInput(byteBuffer, inputShape); - int[][] output = new int[1][224]; + Map inputsEncoder = new HashMap<>(); + Map outputsEncoder = new HashMap<>(); + Map inputsDecoder = new HashMap<>(); + Map outputsDecoder = new HashMap<>(); + + String signatureKey = "serving_default"; + String[] nnapiEncoderSignatureInputs = _nnapiEncoder.getSignatureInputs(signatureKey); + String[] nnapiEncoderSignatureOutputs = _nnapiEncoder.getSignatureOutputs(signatureKey); + String[] nnapiDecoderSignatureInputs = _nnapiDecoder.getSignatureInputs(signatureKey); + String[] nnapiDecoderSignatureOutputs = _nnapiDecoder.getSignatureOutputs(signatureKey); + + String encoderInputKey0 = nnapiEncoderSignatureInputs[0]; + String encoderOutputKey0 = nnapiEncoderSignatureOutputs[0]; + String decoderInputKey0 = nnapiDecoderSignatureInputs[0]; + String decoderInputKey1 = nnapiDecoderSignatureInputs[1]; + String decoderOutputKey0 = nnapiDecoderSignatureOutputs[0]; + + inputsEncoder.put(encoderInputKey0, reshapeInput(byteBuffer, inputShape)); + float[][][] encoder_output = new float[1][1500][384]; + outputsEncoder.put(encoderOutputKey0, encoder_output); + + + _nnapiEncoder.runSignature(inputsEncoder, outputsEncoder, signatureKey); + + + long[][][] encoder_output_int = new long[1][1500][384]; + + inputsDecoder.put(decoderInputKey0, encoder_output_int); + + float[][] decoder_ids = new float[1][384]; + decoder_ids[0][0] = 50258; + decoder_ids[0][1] = 50266; + decoder_ids[0][2] = 50358; + decoder_ids[0][3] = 50363; + inputsDecoder.put(decoderInputKey1, decoder_ids); - Map inputs = new HashMap<>(); - inputs.put("input_features", reshapedFloats); - Map outputs = new HashMap<>(); - outputs.put("sequences", output); + int[] shape = new int[2]; + shape[0] = 1; + shape[1] = 4; + _nnapiDecoder.resizeInput(1, shape); + float[][] output = new float[1][224]; + outputsDecoder.put(decoderOutputKey0, output); - _whisperInterpreter.runSignature(inputs, outputs, "serving_default"); - String whisperOutput = _dictionary.tokensToString(output); + _nnapiDecoder.runSignature(inputsDecoder, outputsDecoder, signatureKey); + String whisperOutput = _dictionary.tokensToString(new int[1][224]); return _dictionary.injectTokens(whisperOutput); } @@ -217,9 +273,9 @@ private float[][][] reshapeInput(float[] byteBuffer, int[] inputShape) { return reshapedFloats; } - private static MappedByteBuffer loadWhisperModel(AssetManager assets) + private static MappedByteBuffer loadWhisperModel(AssetManager assets, String modelName) throws IOException { - AssetFileDescriptor fileDescriptor = assets.openFd(WHISPER_TFLITE); + AssetFileDescriptor fileDescriptor = assets.openFd(modelName); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); diff --git a/app/src/main/jniLibs/arm64-v8a/librust.so b/app/src/main/jniLibs/arm64-v8a/librust.so index ef0d1a1..cf0f0b7 100755 Binary files a/app/src/main/jniLibs/arm64-v8a/librust.so and b/app/src/main/jniLibs/arm64-v8a/librust.so differ diff --git a/app/src/main/jniLibs/armeabi-v7a/librust.so b/app/src/main/jniLibs/armeabi-v7a/librust.so index f869b12..fdbcf98 100755 Binary files a/app/src/main/jniLibs/armeabi-v7a/librust.so and b/app/src/main/jniLibs/armeabi-v7a/librust.so differ diff --git a/app/src/main/jniLibs/x86/librust.so b/app/src/main/jniLibs/x86/librust.so index 0ae8e61..4eb0d32 100755 Binary files a/app/src/main/jniLibs/x86/librust.so and b/app/src/main/jniLibs/x86/librust.so differ diff --git a/app/src/main/jniLibs/x86_64/librust.so b/app/src/main/jniLibs/x86_64/librust.so index 906d0a0..0efa4cd 100755 Binary files a/app/src/main/jniLibs/x86_64/librust.so and b/app/src/main/jniLibs/x86_64/librust.so differ