diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index a963ade..a60d479 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -2,6 +2,7 @@ + @@ -12,6 +13,15 @@ android:roundIcon="@drawable/mic" android:supportsRtl="true" android:theme="@style/Theme.WhisperVoiceKeyboard"> + + + + + diff --git a/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite b/app/src/main/assets/nnmodel/mine/whisper-decoder_language.tflite similarity index 99% rename from app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite rename to app/src/main/assets/nnmodel/mine/whisper-decoder_language.tflite index 5fbd8f9..42b4ed5 100644 Binary files a/app/src/main/assets/nnmodel/mine/whisper-decoder-language-hybrid.tflite and b/app/src/main/assets/nnmodel/mine/whisper-decoder_language.tflite differ diff --git a/app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite b/app/src/main/assets/nnmodel/mine/whisper-encoder.tflite similarity index 100% rename from app/src/main/assets/nnmodel/mine/whisper-encoder-hybrid.tflite rename to app/src/main/assets/nnmodel/mine/whisper-encoder.tflite diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java index 78c4e7f..33ee7d8 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java @@ -23,14 +23,14 @@ public Dictionary(Vocab tokenMappings, Map phraseMappings) { * @return String composed of words of the tokens in the array */ @NonNull - public String tokensToString(int[][] output) { + public String tokensToString(long[] output) { StringBuilder sb = new StringBuilder(); - for (int token : output[0]) { - if (token == _vocab.token_eot) { + for (long token : output) { + if (token == _vocab.tokenEndOfTranscript) { break; } - if (token != 50257 && token != 50362) { - String word = _vocab.id_to_token.get(token); + if (token != _vocab.tokenStartOfTranscript && token != _vocab.tokenNoTimeStamps) { + String word = _vocab.id_to_token.get((int) token); Log.i("tokenization", "token: " + token + " word " + word); sb.append(word); } @@ -39,6 +39,13 @@ public String tokensToString(int[][] output) { } + public void logAllTokens() { + for (int token = 0; token <= 51865; token += 1) { + String word = _vocab.id_to_token.get(token); + Log.i("tokenization", "token: " + token + " word " + word); + } + } + /** * This method takes a string as an argument and replaces key phrases with special tokens. * @@ -56,4 +63,15 @@ public String injectTokens(String text) { return result; } + public int getNotTimeStamps() { + return _vocab.tokenNoTimeStamps; + } + + public int getStartOfTranscript() { + return _vocab.tokenStartOfTranscript; + } + + public int getEndOfTranscript() { + return _vocab.tokenEndOfTranscript; + } } diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/ExtractVocab.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/ExtractVocab.java index b6f1fcc..90e6133 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/ExtractVocab.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/ExtractVocab.java @@ -1,5 +1,7 @@ package com.mjm.whisperVoiceRecognition; +import android.util.Log; + import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; @@ -10,34 +12,37 @@ public class ExtractVocab { - public static Vocab extractVocab(InputStream filters_vocab_gen_bin) throws IOException { - if (readI32(filters_vocab_gen_bin) == 0x5553454e) { + public static Vocab extractVocab(InputStream filtersVocabBin) throws IOException { + if (readI32(filtersVocabBin) == 0x5553454e) { - readI32(filters_vocab_gen_bin); - readI32(filters_vocab_gen_bin); - readVecF32(filters_vocab_gen_bin, 80 * 201); + readI32(filtersVocabBin); + readI32(filtersVocabBin); + readVecF32(filtersVocabBin, 80 * 201); Vocab vocab = new Vocab(); - int word_count = readI32(filters_vocab_gen_bin); + int word_count = readI32(filtersVocabBin); assert (50257 == word_count); HashMap words = new HashMap<>(word_count); for (int i = 0; i < word_count; i++) { - int nextWordLen = readU32(filters_vocab_gen_bin); - String word = readString(filters_vocab_gen_bin, nextWordLen); + int nextWordLen = readU32(filtersVocabBin); + String word = readString(filtersVocabBin, nextWordLen); words.put(i, word); } vocab.n_vocab = word_count; vocab.id_to_token = words; - if (vocab.isMultilingual()) { - vocab.token_eot += 1; - vocab.token_sot += 1; + if (true) { +// if (vocab.isMultilingual()) { + + Log.i("extractVocab", "Using Multilingual"); + vocab.tokenEndOfTranscript += 1; + vocab.tokenStartOfTranscript += 1; vocab.token_prev += 1; vocab.token_solm += 1; - vocab.token_not += 1; + vocab.tokenNoTimeStamps += 1; vocab.token_beg += 1; } - for (int i = word_count; i < vocab.n_vocab; i++) { +/* for (int i = word_count; i < vocab.n_vocab; i++) { String word; if (i > vocab.token_beg) { word = "[_TT_" + (i - vocab.token_beg) + "]"; @@ -55,7 +60,7 @@ public static Vocab extractVocab(InputStream filters_vocab_gen_bin) throws IOExc word = "[_extra_token_" + i + "]"; } vocab.id_to_token.put(i, word); - } + }*/ System.out.println("Succeeded in Loading Vocab! " + vocab.n_vocab + " (" + vocab.id_to_token.size() + ") Words."); return vocab; } else throw new IOException("bad magic"); diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java new file mode 100644 index 0000000..f38e425 --- /dev/null +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java @@ -0,0 +1,164 @@ +package com.mjm.whisperVoiceRecognition; + +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.util.Log; + +import androidx.annotation.NonNull; + +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.flex.FlexDelegate; +import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.nnapi.NnApiDelegate; + +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +public class Transcriber { + public static final String SIGNATURE_KEY = "serving_default"; + public static final int[] ENCODER_INPUT_SHAPE = new int[]{1, 80, 3000}; + private static final String WHISPER_ENCODER = "nnmodel/mine/whisper-encoder.tflite"; + private static final String WHISPER_DECODER_LANGUAGE = "nnmodel/mine/whisper-decoder_language.tflite"; + private Interpreter _encoder; + private Interpreter _decoder; + private Dictionary _dictionary; + + public Transcriber(AssetManager assetManager) { + + + Interpreter.Options nnapiOptions = new Interpreter.Options(); + NnApiDelegate nnapiDelegate = new NnApiDelegate(); + FlexDelegate flexDelegate = new FlexDelegate(); + GpuDelegate gpuDelegate = new GpuDelegate(); + + +// nnapiOptions.addDelegate(flexDelegate); +// nnapiOptions.addDelegate(gpuDelegate); +// nnapiOptions.addDelegate(nnapiDelegate); + + + nnapiOptions.setNumThreads(8); + nnapiOptions.setUseXNNPACK(true); + nnapiOptions.setUseNNAPI(false); + + try { + + + MappedByteBuffer whisper_encoder = loadWhisperModel(assetManager, WHISPER_ENCODER); + MappedByteBuffer whisper_decoder_language = loadWhisperModel(assetManager, WHISPER_DECODER_LANGUAGE); + + _encoder = new Interpreter(whisper_encoder, nnapiOptions); + _decoder = new Interpreter(whisper_decoder_language, nnapiOptions); + Vocab vocab = ExtractVocab.extractVocab(assetManager.open("filters_vocab_multilingual.bin")); + HashMap phraseMappings = new HashMap<>(); + _dictionary = new Dictionary(vocab, phraseMappings); + + } catch (Exception e) { + e.printStackTrace(); + } + + } + + @NonNull + String transcribeAudio(float[] byteBuffer) { + + float[][][] encoderOutputBuffer = new float[1][1500][384]; + float[][][] decoderOutputBuffer = new float[1][384][51865]; + + int noTimestamps = _dictionary.getNotTimeStamps(); + int startOfTranscript = _dictionary.getStartOfTranscript(); + long[][] decoder_ids = new long[1][384]; + decoder_ids[0][0] = startOfTranscript; + decoder_ids[0][1] = startOfTranscript + 1; //+ lang; + decoder_ids[0][2] = Vocab.TOKEN_SPEECH_TO_TEXT; + decoder_ids[0][3] = noTimestamps; + int prefixLen = 4; + + + Map encoderInputsMap = new HashMap(); + String[] encoderInputs = _encoder.getSignatureInputs(SIGNATURE_KEY); + encoderInputsMap.put(encoderInputs[0], reshape(byteBuffer, ENCODER_INPUT_SHAPE)); + + Map encoderOutputsMap = new HashMap(); + String[] encoderOutputs = _encoder.getSignatureOutputs(SIGNATURE_KEY); + encoderOutputsMap.put(encoderOutputs[0], encoderOutputBuffer); + + _encoder.runSignature(encoderInputsMap, encoderOutputsMap, SIGNATURE_KEY); + + + Map decoderInputsMap = new HashMap(); + String[] decoderInputs = _decoder.getSignatureInputs(SIGNATURE_KEY); + decoderInputsMap.put(decoderInputs[0], encoderOutputBuffer); + decoderInputsMap.put(decoderInputs[1], decoder_ids); + + Map decoderOutputsMap = new HashMap(); + String[] decoderOutputs = _decoder.getSignatureOutputs(SIGNATURE_KEY); + decoderOutputsMap.put(decoderOutputs[0], decoderOutputBuffer); + + int nextToken = -1; + while (nextToken != _dictionary.getEndOfTranscript()) { + _decoder.resizeInput(1, new int[]{1, prefixLen}); + _decoder.runSignature(decoderInputsMap, decoderOutputsMap, SIGNATURE_KEY); + nextToken = maxTokenIndex(decoderOutputBuffer, prefixLen); + + decoder_ids[0][prefixLen] = nextToken; + Log.i("transcribeAudio", "token: " + nextToken); + Log.i("transcribeAudio", "token: " + Arrays.toString(decoder_ids[0])); + prefixLen += 1; + + } + + long[] output = new long[prefixLen]; + System.arraycopy(decoder_ids[0], 0, output, 0, prefixLen); + +// _dictionary.logAllTokens(); + + String whisperOutput = _dictionary.tokensToString(output); + return _dictionary.injectTokens(whisperOutput); + } + + private int maxTokenIndex(float[][][] decoderOutputBuffer, int index) { + float[] sentence = decoderOutputBuffer[0][index]; + + + int lastTokenIndex = 0; + float maxValue = Float.MIN_VALUE; + for (int i = 0; i < sentence.length; i++) { + if (sentence[i] > maxValue) { + maxValue = sentence[i]; + lastTokenIndex = i; + } + } + return lastTokenIndex; + } + + private static MappedByteBuffer loadWhisperModel(AssetManager assets, String modelName) + throws IOException { + AssetFileDescriptor fileDescriptor = assets.openFd(modelName); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + + @NonNull + private float[][][] reshape(float[] byteBuffer, int[] inputShape) { + float[][][] reshapedFloats = new float[inputShape[0]][inputShape[1]][inputShape[2]]; + int index = 0; + for (int k = 0; k < inputShape[2]; k++) { + for (int j = 0; j < inputShape[1]; j++) { + for (int i = 0; i < inputShape[0]; i++) { + reshapedFloats[i][j][k] = byteBuffer[index]; + index++; + } + } + } + return reshapedFloats; + } +} \ No newline at end of file diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/Vocab.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/Vocab.java index 0c5f139..c15dd92 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/Vocab.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/Vocab.java @@ -5,22 +5,25 @@ public class Vocab { public int n_vocab; - public int token_eot; - public int token_sot; + public int tokenEndOfTranscript; + public int tokenStartOfTranscript; public int token_prev; public int token_solm; - public int token_not; + public int tokenNoTimeStamps; public int token_beg; public HashMap id_to_token; + public static final int TOKEN_SPEECH_TO_TEXT = 50358; + public static final int TOKEN_UNKNOWN_PURPOSE = 50359; + public Vocab() { // Magic Numbers evidently derived from https://github.com/ggerganov/whisper.cpp this.n_vocab = 51864; - this.token_eot = 50256; - this.token_sot = 50257; + this.tokenEndOfTranscript = 50256; + this.tokenStartOfTranscript = 50257; this.token_prev = 50360; this.token_solm = 50361; - this.token_not = 50362; + this.tokenNoTimeStamps = 50362; this.token_beg = 50363; this.id_to_token = new HashMap(); } diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java index 54c1935..f44e758 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/VoiceKeyboardInputMethodService.java @@ -6,7 +6,6 @@ import android.content.Context; import android.content.Intent; import android.content.pm.PackageManager; -import android.content.res.AssetFileDescriptor; import android.content.res.AssetManager; import android.inputmethodservice.InputMethodService; import android.media.AudioManager; @@ -23,78 +22,29 @@ import android.widget.Toast; import android.widget.ToggleButton; -import androidx.annotation.NonNull; import androidx.constraintlayout.widget.ConstraintLayout; import com.example.WhisperVoiceKeyboard.R; -import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.flex.FlexDelegate; -import org.tensorflow.lite.gpu.GpuDelegate; -import org.tensorflow.lite.nnapi.NnApiDelegate; - -import java.io.FileInputStream; -import java.io.IOException; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.util.HashMap; -import java.util.Map; import java.util.Optional; public class VoiceKeyboardInputMethodService extends InputMethodService { - private Interpreter _nnapiEncoder; - private Interpreter _nnapiDecoder; - private Dictionary _dictionary; + private Transcriber _transcriber; - private static final String WHISPER_ENCODER = "nnmodel/nyadia/whisper-encoder.tflite"; - private static final String WHISPER_DECODER_LANGUAGE = "nnmodel/nyadia/whisper-decoder_language.tflite"; - private static final boolean LOG_AND_DRAW = false; @Override public void onCreate() { - super.onCreate(); + AssetManager assetManager = getAssets(); - Interpreter.Options nnapiOptions = new Interpreter.Options(); - NnApiDelegate nnapiDelegate = new NnApiDelegate(); - FlexDelegate flexDelegate = new FlexDelegate(); - GpuDelegate gpuDelegate = new GpuDelegate(); - - - nnapiOptions.addDelegate(flexDelegate); - nnapiOptions.addDelegate(gpuDelegate); - nnapiOptions.addDelegate(nnapiDelegate); - - - nnapiOptions.setNumThreads(0); - nnapiOptions.setUseXNNPACK(true); - nnapiOptions.setUseNNAPI(true); - - try { - - - MappedByteBuffer whisper_encoder = loadWhisperModel(getAssets(), WHISPER_ENCODER); - MappedByteBuffer whisper_decoder_language = loadWhisperModel(getAssets(), WHISPER_DECODER_LANGUAGE); + RustLib.init(assetManager); - _nnapiEncoder = new Interpreter(whisper_encoder, nnapiOptions); - _nnapiDecoder = new Interpreter(whisper_decoder_language, nnapiOptions); - - Vocab vocab = ExtractVocab.extractVocab(getAssets().open("filters_vocab_gen.bin")); - HashMap phraseMappings = new HashMap<>(); - _dictionary = new Dictionary(vocab, phraseMappings); - - } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); - } - - - RustLib.init(getAssets()); + _transcriber = new Transcriber(assetManager); } @@ -185,7 +135,7 @@ public boolean onTouch(View view, MotionEvent motionEvent) { cancelButton.setVisibility(View.GONE); Optional byteBuffer = RustLib.endRec(); if (byteBuffer.isPresent()) { - String transcribeAudio = transcribeAudio(byteBuffer.get()); + String transcribeAudio = _transcriber.transcribeAudio(byteBuffer.get()); String transcribed = transcribeAudio.trim() + " "; getCurrentInputConnection().commitText(transcribed, 1); if (LOG_AND_DRAW) { @@ -205,82 +155,4 @@ private void sendDelete() { } - @NonNull - private String transcribeAudio(float[] byteBuffer) { - int[] inputShape = {1, 80, 3000}; - - Map inputsEncoder = new HashMap<>(); - Map outputsEncoder = new HashMap<>(); - Map inputsDecoder = new HashMap<>(); - Map outputsDecoder = new HashMap<>(); - - String signatureKey = "serving_default"; - String[] nnapiEncoderSignatureInputs = _nnapiEncoder.getSignatureInputs(signatureKey); - String[] nnapiEncoderSignatureOutputs = _nnapiEncoder.getSignatureOutputs(signatureKey); - String[] nnapiDecoderSignatureInputs = _nnapiDecoder.getSignatureInputs(signatureKey); - String[] nnapiDecoderSignatureOutputs = _nnapiDecoder.getSignatureOutputs(signatureKey); - - String encoderInputKey0 = nnapiEncoderSignatureInputs[0]; - String encoderOutputKey0 = nnapiEncoderSignatureOutputs[0]; - String decoderInputKey0 = nnapiDecoderSignatureInputs[0]; - String decoderInputKey1 = nnapiDecoderSignatureInputs[1]; - String decoderOutputKey0 = nnapiDecoderSignatureOutputs[0]; - - inputsEncoder.put(encoderInputKey0, reshapeInput(byteBuffer, inputShape)); - float[][][] encoder_output = new float[1][1500][384]; - outputsEncoder.put(encoderOutputKey0, encoder_output); - - - _nnapiEncoder.runSignature(inputsEncoder, outputsEncoder, signatureKey); - - - long[][][] encoder_output_int = new long[1][1500][384]; - - inputsDecoder.put(decoderInputKey0, encoder_output_int); - - float[][] decoder_ids = new float[1][384]; - decoder_ids[0][0] = 50258; - decoder_ids[0][1] = 50266; - decoder_ids[0][2] = 50358; - decoder_ids[0][3] = 50363; - inputsDecoder.put(decoderInputKey1, decoder_ids); - - int[] shape = new int[2]; - shape[0] = 1; - shape[1] = 4; - _nnapiDecoder.resizeInput(1, shape); - float[][] output = new float[1][224]; - outputsDecoder.put(decoderOutputKey0, output); - - _nnapiDecoder.runSignature(inputsDecoder, outputsDecoder, signatureKey); - String whisperOutput = _dictionary.tokensToString(new int[1][224]); - return _dictionary.injectTokens(whisperOutput); - } - - - @NonNull - private float[][][] reshapeInput(float[] byteBuffer, int[] inputShape) { - float[][][] reshapedFloats = new float[inputShape[0]][inputShape[1]][inputShape[2]]; - int index = 0; - for (int k = 0; k < inputShape[2]; k++) { - for (int j = 0; j < inputShape[1]; j++) { - for (int i = 0; i < inputShape[0]; i++) { - reshapedFloats[i][j][k] = byteBuffer[index]; - index++; - } - } - } - return reshapedFloats; - } - - private static MappedByteBuffer loadWhisperModel(AssetManager assets, String modelName) - throws IOException { - AssetFileDescriptor fileDescriptor = assets.openFd(modelName); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - } \ No newline at end of file