diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java index 33ee7d8..df0ece1 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java @@ -4,6 +4,7 @@ import androidx.annotation.NonNull; +import java.util.Collection; import java.util.Map; public class Dictionary { @@ -23,7 +24,7 @@ public Dictionary(Vocab tokenMappings, Map phraseMappings) { * @return String composed of words of the tokens in the array */ @NonNull - public String tokensToString(long[] output) { + public String tokensToString(Collection output) { StringBuilder sb = new StringBuilder(); for (long token : output) { if (token == _vocab.tokenEndOfTranscript) { diff --git a/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java b/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java index 7ec4682..99e45fb 100644 --- a/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java +++ b/app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.Vector; public class Transcriber { public static final String SIGNATURE_KEY = "serving_default"; @@ -37,7 +38,7 @@ public Transcriber(AssetManager assetManager) { GpuDelegate gpuDelegate = new GpuDelegate(); -// nnapiOptions.addDelegate(flexDelegate); + nnapiOptions.addDelegate(flexDelegate); // nnapiOptions.addDelegate(gpuDelegate); // nnapiOptions.addDelegate(nnapiDelegate); @@ -70,14 +71,15 @@ String transcribeAudio(float[] byteBuffer) { float[][][] encoderOutputBuffer = new float[1][1500][384]; float[][][] decoderOutputBuffer = new float[1][384][51865]; - int noTimestamps = _dictionary.getNotTimeStamps(); - int startOfTranscript = _dictionary.getStartOfTranscript(); - long[][] decoder_ids = new long[1][384]; - decoder_ids[0][0] = startOfTranscript; - decoder_ids[0][1] = 50259; //+ lang; - decoder_ids[0][2] = Vocab.TRANSCRIBE; - decoder_ids[0][3] = noTimestamps; - int prefixLen = 4; + long[][] decoderInputIds = new long[1][384]; + long[] prefix = {_dictionary.getStartOfTranscript(), 50259, Vocab.TRANSLATE, _dictionary.getNotTimeStamps()}; + int prefixLen = prefix.length; + System.arraycopy(prefix, 0, decoderInputIds[0], 0, prefixLen); + + Vector tokenStream = new Vector<>(4); + for (int p = 0; p < prefixLen; p++) { + tokenStream.add(prefix[p]); + } Map encoderInputsMap = new HashMap(); @@ -94,7 +96,7 @@ String transcribeAudio(float[] byteBuffer) { Map decoderInputsMap = new HashMap(); String[] decoderInputs = _decoder.getSignatureInputs(SIGNATURE_KEY); decoderInputsMap.put(decoderInputs[0], encoderOutputBuffer); - decoderInputsMap.put(decoderInputs[1], decoder_ids); + decoderInputsMap.put(decoderInputs[1], decoderInputIds); Map decoderOutputsMap = new HashMap(); String[] decoderOutputs = _decoder.getSignatureOutputs(SIGNATURE_KEY); @@ -104,37 +106,41 @@ String transcribeAudio(float[] byteBuffer) { while (nextToken != _dictionary.getEndOfTranscript()) { _decoder.resizeInput(1, new int[]{1, prefixLen}); _decoder.runSignature(decoderInputsMap, decoderOutputsMap, SIGNATURE_KEY); - nextToken = maxTokenIndex(decoderOutputBuffer, prefixLen); + int[] cleaned = argmax(decoderOutputBuffer[0]); + + Log.i("transcribeAudio", "index: " + prefixLen); + Log.i("transcribeAudio", "cleaned: " + Arrays.toString(cleaned)); + nextToken = cleaned[prefixLen - 1]; - decoder_ids[0][prefixLen] = nextToken; - Log.i("transcribeAudio", "token: " + nextToken); - Log.i("transcribeAudio", "token: " + Arrays.toString(decoder_ids[0])); + tokenStream.add((long) nextToken); + decoderInputIds[0][prefixLen] = nextToken; + + Log.i("transcribeAudio", "token: " + Arrays.toString(decoderInputIds[0])); prefixLen += 1; } - long[] output = new long[prefixLen]; - System.arraycopy(decoder_ids[0], 0, output, 0, prefixLen); // _dictionary.logAllTokens(); - String whisperOutput = _dictionary.tokensToString(output); + String whisperOutput = _dictionary.tokensToString(tokenStream); return _dictionary.injectTokens(whisperOutput); } - private int maxTokenIndex(float[][][] decoderOutputBuffer, int index) { - float[] sentence = decoderOutputBuffer[0][index]; - - - int lastTokenIndex = 0; - float maxValue = Float.MIN_VALUE; - for (int i = 0; i < sentence.length; i++) { - if (sentence[i] > maxValue) { - maxValue = sentence[i]; - lastTokenIndex = i; + private int[] argmax(float[][] decoderOutputBuffer) { + int[] result = new int[decoderOutputBuffer.length]; + for (int i = 0; i < result.length; i++) { + int maxIndex = 0; + for (int j = 0; j < decoderOutputBuffer[i].length; j++) { + if (decoderOutputBuffer[i][j] > decoderOutputBuffer[i][maxIndex]) { + maxIndex = j; + } } + + result[i] = maxIndex; } - return lastTokenIndex; + + return result; } private static MappedByteBuffer loadWhisperModel(AssetManager assets, String modelName)