Skip to content

Commit

Permalink
encoder decoder code. does work
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael McCulloch committed Mar 11, 2023
1 parent deaa592 commit 2b1e53b
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 158 deletions.
12 changes: 12 additions & 0 deletions app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">


<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE" />

Expand All @@ -12,6 +13,15 @@
android:roundIcon="@drawable/mic"
android:supportsRtl="true"
android:theme="@style/Theme.WhisperVoiceKeyboard">
<uses-native-library
android:name="libOpenCL-pixel.so"
android:required="false" />
<uses-native-library
android:name="libOpenCL.so"
android:required="false" />
<uses-native-library
android:name="libGLES_mali.so"
android:required="false" />

<activity
android:name="com.mjm.whisperVoiceRecognition.Wizard"
Expand All @@ -37,6 +47,8 @@
<meta-data
android:name="android.view.im"
android:resource="@xml/method" />


</service>

</application>
Expand Down
Binary file not shown.
28 changes: 23 additions & 5 deletions app/src/main/java/com/mjm/whisperVoiceRecognition/Dictionary.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ public Dictionary(Vocab tokenMappings, Map<String, String> phraseMappings) {
* @return String composed of words of the tokens in the array
*/
@NonNull
public String tokensToString(int[][] output) {
public String tokensToString(long[] output) {
StringBuilder sb = new StringBuilder();
for (int token : output[0]) {
if (token == _vocab.token_eot) {
for (long token : output) {
if (token == _vocab.tokenEndOfTranscript) {
break;
}
if (token != 50257 && token != 50362) {
String word = _vocab.id_to_token.get(token);
if (token != _vocab.tokenStartOfTranscript && token != _vocab.tokenNoTimeStamps) {
String word = _vocab.id_to_token.get((int) token);
Log.i("tokenization", "token: " + token + " word " + word);
sb.append(word);
}
Expand All @@ -39,6 +39,13 @@ public String tokensToString(int[][] output) {
}


/**
 * Debug helper that writes one "tokenization" info log line for every token id from 0 up to
 * and including 51865, showing the id and whatever {@code _vocab.id_to_token} returns for it
 * (the string "null" is logged for ids that have no mapping).
 */
public void logAllTokens() {
    final int lastTokenId = 51865;
    int token = 0;
    while (token <= lastTokenId) {
        String word = _vocab.id_to_token.get(token);
        Log.i("tokenization", "token: " + token + " word " + word);
        token++;
    }
}

/**
* This method takes a string as an argument and replaces key phrases with special tokens.
*
Expand All @@ -56,4 +63,15 @@ public String injectTokens(String text) {
return result;
}

/**
 * Exposes the "no timestamps" token id held by the wrapped vocabulary.
 *
 * @return the current value of {@code _vocab.tokenNoTimeStamps}
 */
public int getNotTimeStamps() {
    final int noTimeStampsToken = _vocab.tokenNoTimeStamps;
    return noTimeStampsToken;
}

/**
 * Exposes the start-of-transcript token id held by the wrapped vocabulary.
 *
 * @return the current value of {@code _vocab.tokenStartOfTranscript}
 */
public int getStartOfTranscript() {
    final int startOfTranscriptToken = _vocab.tokenStartOfTranscript;
    return startOfTranscriptToken;
}

/**
 * Exposes the end-of-transcript token id held by the wrapped vocabulary.
 *
 * @return the current value of {@code _vocab.tokenEndOfTranscript}
 */
public int getEndOfTranscript() {
    final int endOfTranscriptToken = _vocab.tokenEndOfTranscript;
    return endOfTranscriptToken;
}
}
33 changes: 19 additions & 14 deletions app/src/main/java/com/mjm/whisperVoiceRecognition/ExtractVocab.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.mjm.whisperVoiceRecognition;

import android.util.Log;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand All @@ -10,34 +12,37 @@

public class ExtractVocab {

public static Vocab extractVocab(InputStream filters_vocab_gen_bin) throws IOException {
if (readI32(filters_vocab_gen_bin) == 0x5553454e) {
public static Vocab extractVocab(InputStream filtersVocabBin) throws IOException {
if (readI32(filtersVocabBin) == 0x5553454e) {


readI32(filters_vocab_gen_bin);
readI32(filters_vocab_gen_bin);
readVecF32(filters_vocab_gen_bin, 80 * 201);
readI32(filtersVocabBin);
readI32(filtersVocabBin);
readVecF32(filtersVocabBin, 80 * 201);

Vocab vocab = new Vocab();
int word_count = readI32(filters_vocab_gen_bin);
int word_count = readI32(filtersVocabBin);
assert (50257 == word_count);
HashMap<Integer, String> words = new HashMap<>(word_count);
for (int i = 0; i < word_count; i++) {
int nextWordLen = readU32(filters_vocab_gen_bin);
String word = readString(filters_vocab_gen_bin, nextWordLen);
int nextWordLen = readU32(filtersVocabBin);
String word = readString(filtersVocabBin, nextWordLen);
words.put(i, word);
}
vocab.n_vocab = word_count;
vocab.id_to_token = words;
if (vocab.isMultilingual()) {
vocab.token_eot += 1;
vocab.token_sot += 1;
if (true) {
// if (vocab.isMultilingual()) {

Log.i("extractVocab", "Using Multilingual");
vocab.tokenEndOfTranscript += 1;
vocab.tokenStartOfTranscript += 1;
vocab.token_prev += 1;
vocab.token_solm += 1;
vocab.token_not += 1;
vocab.tokenNoTimeStamps += 1;
vocab.token_beg += 1;
}
for (int i = word_count; i < vocab.n_vocab; i++) {
/* for (int i = word_count; i < vocab.n_vocab; i++) {
String word;
if (i > vocab.token_beg) {
word = "[_TT_" + (i - vocab.token_beg) + "]";
Expand All @@ -55,7 +60,7 @@ public static Vocab extractVocab(InputStream filters_vocab_gen_bin) throws IOExc
word = "[_extra_token_" + i + "]";
}
vocab.id_to_token.put(i, word);
}
}*/
System.out.println("Succeeded in Loading Vocab! " + vocab.n_vocab + " (" + vocab.id_to_token.size() + ") Words.");
return vocab;
} else throw new IOException("bad magic");
Expand Down
164 changes: 164 additions & 0 deletions app/src/main/java/com/mjm/whisperVoiceRecognition/Transcriber.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package com.mjm.whisperVoiceRecognition;

import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.util.Log;

import androidx.annotation.NonNull;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.flex.FlexDelegate;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.nnapi.NnApiDelegate;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Speech-to-text front end built on two TensorFlow Lite models shipped as app assets: an
 * encoder that is run once per utterance and a decoder that is run repeatedly, each time
 * appending the highest-scoring token to the prompt (greedy decoding) until the
 * end-of-transcript token is produced or the token buffer is full.
 */
public class Transcriber {
    public static final String SIGNATURE_KEY = "serving_default";
    public static final int[] ENCODER_INPUT_SHAPE = new int[]{1, 80, 3000};
    private static final String WHISPER_ENCODER = "nnmodel/mine/whisper-encoder.tflite";
    private static final String WHISPER_DECODER_LANGUAGE = "nnmodel/mine/whisper-decoder_language.tflite";
    private static final String VOCAB_ASSET = "filters_vocab_multilingual.bin";
    private static final String TAG = "Transcriber";

    private Interpreter _encoder;
    private Interpreter _decoder;
    private Dictionary _dictionary;

    /**
     * Loads both models and the vocabulary from the given assets.
     *
     * <p>Loading is best-effort: if anything fails the error is logged and the affected fields
     * stay {@code null}, so a later {@link #transcribeAudio(float[])} call will fail.
     *
     * @param assetManager asset source for the model files and the vocabulary file
     */
    public Transcriber(AssetManager assetManager) {
        // CPU-only configuration. No NNAPI/GPU/Flex delegate objects are created here: the
        // previous code constructed them without attaching or closing them, which only leaked
        // their native handles.
        Interpreter.Options options = new Interpreter.Options();
        options.setNumThreads(8);
        options.setUseXNNPACK(true);
        options.setUseNNAPI(false);

        try {
            MappedByteBuffer encoderModel = loadWhisperModel(assetManager, WHISPER_ENCODER);
            MappedByteBuffer decoderModel = loadWhisperModel(assetManager, WHISPER_DECODER_LANGUAGE);

            _encoder = new Interpreter(encoderModel, options);
            _decoder = new Interpreter(decoderModel, options);

            // Close the vocabulary stream once it has been parsed.
            try (InputStream vocabStream = assetManager.open(VOCAB_ASSET)) {
                Vocab vocab = ExtractVocab.extractVocab(vocabStream);
                HashMap<String, String> phraseMappings = new HashMap<>();
                _dictionary = new Dictionary(vocab, phraseMappings);
            }
        } catch (Exception e) {
            Log.e(TAG, "Failed to initialise Whisper models or vocabulary", e);
        }
    }

    /**
     * Transcribes one utterance.
     *
     * @param byteBuffer flat float input for the encoder; it is rearranged into
     *                   {@link #ENCODER_INPUT_SHAPE} by {@code reshape}, so it must hold at
     *                   least 1 * 80 * 3000 values
     * @return the decoded text after {@code Dictionary.injectTokens} has been applied
     */
    @NonNull
    String transcribeAudio(float[] byteBuffer) {
        // Buffer sizes are tied to the bundled models — TODO confirm if the models change.
        float[][][] encoderOutputBuffer = new float[1][1500][384];
        float[][][] decoderOutputBuffer = new float[1][384][51865];

        int noTimestamps = _dictionary.getNotTimeStamps();
        int startOfTranscript = _dictionary.getStartOfTranscript();
        int endOfTranscript = _dictionary.getEndOfTranscript();

        // Fixed four-token prompt that every decode starts from.
        long[][] decoder_ids = new long[1][384];
        decoder_ids[0][0] = startOfTranscript;
        decoder_ids[0][1] = startOfTranscript + 1; // presumably the language token (+ lang) — confirm
        // NOTE(review): in the reference multilingual Whisper tokenizer 50358 appears to be
        // <|translate|> and 50359 <|transcribe|> — confirm which task is intended here.
        decoder_ids[0][2] = Vocab.TOKEN_SPEECH_TO_TEXT;
        decoder_ids[0][3] = noTimestamps;
        int prefixLen = 4;

        // Encoder pass.
        Map<String, Object> encoderInputsMap = new HashMap<>();
        String[] encoderInputs = _encoder.getSignatureInputs(SIGNATURE_KEY);
        encoderInputsMap.put(encoderInputs[0], reshape(byteBuffer, ENCODER_INPUT_SHAPE));

        Map<String, Object> encoderOutputsMap = new HashMap<>();
        String[] encoderOutputs = _encoder.getSignatureOutputs(SIGNATURE_KEY);
        encoderOutputsMap.put(encoderOutputs[0], encoderOutputBuffer);

        _encoder.runSignature(encoderInputsMap, encoderOutputsMap, SIGNATURE_KEY);

        // Decoder wiring: the maps hold references, so updates to decoder_ids are seen by
        // every subsequent run without re-populating the maps.
        Map<String, Object> decoderInputsMap = new HashMap<>();
        String[] decoderInputs = _decoder.getSignatureInputs(SIGNATURE_KEY);
        decoderInputsMap.put(decoderInputs[0], encoderOutputBuffer);
        decoderInputsMap.put(decoderInputs[1], decoder_ids);

        Map<String, Object> decoderOutputsMap = new HashMap<>();
        String[] decoderOutputs = _decoder.getSignatureOutputs(SIGNATURE_KEY);
        decoderOutputsMap.put(decoderOutputs[0], decoderOutputBuffer);

        // Greedy decode. The capacity check prevents indexing past the end of decoder_ids
        // (and of the logits buffer) when the model never emits end-of-transcript.
        final int maxTokens = decoder_ids[0].length;
        int nextToken = -1;
        while (nextToken != endOfTranscript && prefixLen < maxTokens) {
            _decoder.resizeInput(1, new int[]{1, prefixLen});
            _decoder.runSignature(decoderInputsMap, decoderOutputsMap, SIGNATURE_KEY);
            // NOTE(review): logits are read at row prefixLen; for a causal decoder the row
            // predicting the next token is usually prefixLen - 1 — confirm against the model.
            nextToken = maxTokenIndex(decoderOutputBuffer, prefixLen);

            decoder_ids[0][prefixLen] = nextToken;
            Log.i("transcribeAudio", "token: " + nextToken);
            Log.i("transcribeAudio", "token: " + Arrays.toString(decoder_ids[0]));
            prefixLen += 1;
        }
        if (nextToken != endOfTranscript) {
            Log.w("transcribeAudio", "token buffer full before end of transcript");
        }

        long[] output = new long[prefixLen];
        System.arraycopy(decoder_ids[0], 0, output, 0, prefixLen);

        String whisperOutput = _dictionary.tokensToString(output);
        return _dictionary.injectTokens(whisperOutput);
    }

    /**
     * Finds the position of the largest value in one row of the decoder output.
     *
     * @param decoderOutputBuffer decoder output; only {@code decoderOutputBuffer[0][index]} is read
     * @param index               row to scan
     * @return index of the first occurrence of the maximum value in that row, or 0 if no
     *         element compares greater than negative infinity
     */
    private int maxTokenIndex(float[][][] decoderOutputBuffer, int index) {
        float[] scores = decoderOutputBuffer[0][index];

        int bestIndex = 0;
        // Float.MIN_VALUE is the smallest POSITIVE float, so seeding with it made every
        // all-negative row resolve to index 0. Negative infinity is the correct identity.
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > bestScore) {
                bestScore = scores[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Memory-maps a model stored in the app's assets.
     *
     * <p>The descriptor, stream and channel are closed before returning; an established
     * mapping does not depend on the channel that created it.
     *
     * @param assets    asset manager to read from
     * @param modelName asset path of the model file
     * @return read-only mapping covering exactly the asset's bytes
     * @throws IOException if the asset cannot be opened or mapped
     */
    private static MappedByteBuffer loadWhisperModel(AssetManager assets, String modelName)
            throws IOException {
        try (AssetFileDescriptor fileDescriptor = assets.openFd(modelName);
             FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
             FileChannel fileChannel = inputStream.getChannel()) {
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }

    /**
     * Copies a flat float array into a new 3-D array of the given shape.
     *
     * <p>The flat array is consumed sequentially while the LAST axis is the outermost loop and
     * the first axis the innermost, i.e. consecutive source values step through axis 0, then
     * axis 1, then axis 2.
     *
     * @param byteBuffer flat source values; must contain at least
     *                   {@code inputShape[0] * inputShape[1] * inputShape[2]} elements
     * @param inputShape three-element target shape
     * @return newly allocated array of shape {@code inputShape}
     */
    @NonNull
    private float[][][] reshape(float[] byteBuffer, int[] inputShape) {
        float[][][] reshapedFloats = new float[inputShape[0]][inputShape[1]][inputShape[2]];
        int index = 0;
        for (int k = 0; k < inputShape[2]; k++) {
            for (int j = 0; j < inputShape[1]; j++) {
                for (int i = 0; i < inputShape[0]; i++) {
                    reshapedFloats[i][j][k] = byteBuffer[index];
                    index++;
                }
            }
        }
        return reshapedFloats;
    }
}
15 changes: 9 additions & 6 deletions app/src/main/java/com/mjm/whisperVoiceRecognition/Vocab.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@
public class Vocab {

public int n_vocab;
public int token_eot;
public int token_sot;
public int tokenEndOfTranscript;
public int tokenStartOfTranscript;
public int token_prev;
public int token_solm;
public int token_not;
public int tokenNoTimeStamps;
public int token_beg;
public HashMap<Integer, String> id_to_token;

public static final int TOKEN_SPEECH_TO_TEXT = 50358;
public static final int TOKEN_UNKNOWN_PURPOSE = 50359;

public Vocab() {
// Magic Numbers evidently derived from https://github.com/ggerganov/whisper.cpp
this.n_vocab = 51864;
this.token_eot = 50256;
this.token_sot = 50257;
this.tokenEndOfTranscript = 50256;
this.tokenStartOfTranscript = 50257;
this.token_prev = 50360;
this.token_solm = 50361;
this.token_not = 50362;
this.tokenNoTimeStamps = 50362;
this.token_beg = 50363;
this.id_to_token = new HashMap<Integer, String>();
}
Expand Down
Loading

0 comments on commit 2b1e53b

Please sign in to comment.