Skip to content
This repository has been archived by the owner on Apr 4, 2023. It is now read-only.

Commit

Permalink
Add custom (TensorFlow Lite) models support to the ML Kit feature #702
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyVerbruggen committed Jan 17, 2019
1 parent a72d653 commit df29885
Show file tree
Hide file tree
Showing 18 changed files with 360 additions and 99 deletions.
3 changes: 3 additions & 0 deletions demo-ng/app/custommodel/nutella/nutella_labels.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
cereals
juice
nutella
Binary file not shown.
13 changes: 8 additions & 5 deletions demo-ng/app/tabs/mlkit/custommodel/custommodel.component.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
<MLKitCustomModel
width="100%"
height="100%"
confidenceThreshold="0.6"
localModelFile="~/custommodel/inception/inception_v3_quant.tflite"
labelsFile="~/custommodel/inception/inception_labels.txt"
modelInputShape="1, 299, 299, 3"
modelInputType="QUANT"
processEveryNthFrame="30"
maxResults="5"
(scanResult)="onCustomModelResult($event)">
</MLKitCustomModel>

Expand All @@ -33,16 +38,14 @@
<Label height="1" marginBottom="1" borderBottomWidth="1" borderColor="rgba(81, 184, 237, 1)"></Label>
</StackLayout>
</GridLayout>
<Label [text]="result" row="0" rowSpan="3" col="0"></Label>

<!--ListView separatorColor="transparent" row="0" rowSpan="3" col="0" colSpan="3" [items]="result" class="m-t-20" backgroundColor="transparent">
<ListView separatorColor="transparent" row="0" rowSpan="3" col="0" colSpan="3" [items]="labels" class="m-t-20" backgroundColor="transparent">
<ng-template let-item="item">
<GridLayout columns="3*, 2*">
<Label col="0" class="mlkit-result" textWrap="true" [text]="item.text"></Label>
<Label col="1" class="mlkit-result" textWrap="true" [text]="item.confidence | number"></Label>
</GridLayout>
</ng-template>
</ListView-->
</ListView>
</GridLayout>

<GridLayout rows="auto" columns="auto, auto" horizontalAlignment="right" class="m-t-4 m-r-8">
Expand Down
7 changes: 5 additions & 2 deletions demo-ng/app/tabs/mlkit/custommodel/custommodel.component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import { AbstractMLKitViewComponent } from "~/tabs/mlkit/abstract.mlkitview.comp
templateUrl: "./custommodel.component.html",
})
export class CustomModelComponent extends AbstractMLKitViewComponent {
result: any;
labels: Array<{
text: string;
confidence: number;
}>;

onCustomModelResult(scanResult: any): void {
const value: MLKitCustomModelResult = scanResult.value;
this.result = value.result;
this.labels = value.result;
}
}
12 changes: 8 additions & 4 deletions demo-ng/app/tabs/mlkit/mlkit.component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,15 @@ export class MLKitComponent {
// cloudModelName: "~/mobilenet_quant_v2_1_0_299",
// cloudModelName: "~/inception_v3_quant",

localModelFile: "~/custommodel/mobilenet/mobilenet_quant_v2_1.0_299.tflite",
labelsFile: "~/custommodel/mobilenet/mobilenet_labels.txt",
// note that there's an issue with this model (making the app crash): "ValueError: Model provided has model identifier 'Mobi', should be 'TFL3'" (reported by https://github.com/EddyVerbruggen/ns-mlkit-tflite-curated/blob/master/scripts/get_model_details.py)
// localModelFile: "~/custommodel/nutella/retrained_quantized_model.tflite",
// labelsFile: "~/custommodel/nutella/nutella_labels.txt",

// localModelFile: "~/custommodel/inception/inception_v3_quant.tflite",
// labelsFile: "~/custommodel/inception/inception_labels.txt",
// localModelFile: "~/custommodel/mobilenet/mobilenet_quant_v2_1.0_299.tflite",
// labelsFile: "~/custommodel/mobilenet/mobilenet_labels.txt",

localModelFile: "~/custommodel/inception/inception_v3_quant.tflite",
labelsFile: "~/custommodel/inception/inception_labels.txt",

maxResults: 5,
modelInput: [{
Expand Down
4 changes: 3 additions & 1 deletion demo-ng/app/tabs/tabs-routing.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import { TextRecognitionComponent } from "~/tabs/mlkit/textrecognition/textrecog
import { BarcodeScanningComponent } from "~/tabs/mlkit/barcodescanning/barcodescanning.component";
import { FaceDetectionComponent } from "~/tabs/mlkit/facedetection/facedetection.component";
import { ImageLabelingComponent } from "~/tabs/mlkit/imagelabeling/imagelabeling.component";
import { CustomModelComponent } from "~/tabs/mlkit/custommodel/custommodel.component";

const routes: Routes = [
{ path: "", component: TabsComponent },
{ path: "mlkit/textrecognition", component: TextRecognitionComponent },
{ path: "mlkit/barcodescanning", component: BarcodeScanningComponent },
{ path: "mlkit/facedetection", component: FaceDetectionComponent },
{ path: "mlkit/imagelabeling", component: ImageLabelingComponent }
{ path: "mlkit/imagelabeling", component: ImageLabelingComponent },
{ path: "mlkit/custommodel", component: CustomModelComponent }
];

@NgModule({
Expand Down
5 changes: 4 additions & 1 deletion demo-ng/app/tabs/tabs.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import { TextRecognitionComponent } from "~/tabs/mlkit/textrecognition/textrecog
import { BarcodeScanningComponent } from "~/tabs/mlkit/barcodescanning/barcodescanning.component";
import { FaceDetectionComponent } from "~/tabs/mlkit/facedetection/facedetection.component";
import { ImageLabelingComponent } from "~/tabs/mlkit/imagelabeling/imagelabeling.component";
import { CustomModelComponent } from "~/tabs/mlkit/custommodel/custommodel.component";

import { registerElement } from "nativescript-angular/element-registry";
registerElement("MLKitBarcodeScanner", () => require("nativescript-plugin-firebase/mlkit/barcodescanning").MLKitBarcodeScanner);
registerElement("MLKitFaceDetection", () => require("nativescript-plugin-firebase/mlkit/facedetection").MLKitFaceDetection);
registerElement("MLKitTextRecognition", () => require("nativescript-plugin-firebase/mlkit/textrecognition").MLKitTextRecognition);
registerElement("MLKitImageLabeling", () => require("nativescript-plugin-firebase/mlkit/imagelabeling").MLKitImageLabeling);
registerElement("MLKitCustomModel", () => require("nativescript-plugin-firebase/mlkit/custommodel").MLKitCustomModel);

@NgModule({
imports: [
Expand All @@ -29,7 +31,8 @@ registerElement("MLKitImageLabeling", () => require("nativescript-plugin-firebas
ImageLabelingComponent,
MLKitComponent,
TabsComponent,
TextRecognitionComponent
TextRecognitionComponent,
CustomModelComponent
],
schemas: [
NO_ERRORS_SCHEMA
Expand Down
73 changes: 73 additions & 0 deletions src/mlkit/custommodel/custommodel-common.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,83 @@
import * as fs from "tns-core-modules/file-system";
import { Property } from "tns-core-modules/ui/core/properties";
import { MLKitCameraView } from "../mlkit-cameraview";
import { MLKitCustomModelType } from "./index";

// View property: path to the on-device TensorFlow Lite model file.
export const localModelFileProperty = new Property<MLKitCustomModel, string>({ name: "localModelFile", defaultValue: null });

// View property: path to a plain-text file holding one label per line.
export const labelsFileProperty = new Property<MLKitCustomModel, string>({ name: "labelsFile", defaultValue: null });

// View property: comma-separated input tensor shape, e.g. "1, 299, 299, 3".
export const modelInputShapeProperty = new Property<MLKitCustomModel, string>({ name: "modelInputShape", defaultValue: null });

// View property: the model's input data type (a MLKitCustomModelType string).
export const modelInputTypeProperty = new Property<MLKitCustomModel, string>({ name: "modelInputType", defaultValue: null });

// View property: cap on the number of results reported per scan.
// TODO could combine this with 'confidenceThreshold'
export const maxResultsProperty = new Property<MLKitCustomModel, number>({ name: "maxResults", defaultValue: 5 });

/**
 * Platform-independent base for the custom (TensorFlow Lite) model camera view.
 * Stores the values set via the view properties above and eagerly parses the
 * labels file so platform implementations can size their output tensors.
 */
export abstract class MLKitCustomModel extends MLKitCameraView {
  static scanResultEvent: string = "scanResult";
  protected localModelFile: string;
  protected labelsFile: string;
  protected maxResults: number;
  protected modelInputShape: Array<number>;
  protected modelInputType: MLKitCustomModelType;

  protected onSuccessListener;
  protected detectorBusy: boolean;

  // Labels parsed from 'labelsFile', one entry per line.
  protected labels: Array<string>;

  [localModelFileProperty.setNative](value: string) {
    this.localModelFile = value;
  }

  [labelsFileProperty.setNative](value: string) {
    this.labelsFile = value;
    if (value.indexOf("~/") === 0) {
      this.labels = getLabelsFromAppFolder(value);
    } else {
      // no dice loading from assets yet, let's advise users to use ~/ for now
      console.log("For the 'labelsFile' property, use the ~/ prefix for now..");
    }
  }

  [maxResultsProperty.setNative](value: any) {
    // The value arrives as a string when set from markup, but may already be a
    // number when set programmatically; always pass an explicit radix so a
    // leading-zero string is not misparsed.
    this.maxResults = typeof value === "number" ? value : parseInt(value, 10);
  }

  [modelInputShapeProperty.setNative](value: string) {
    if ((typeof value) === "string") {
      // "1, 299, 299, 3" -> [1, 299, 299, 3]
      this.modelInputShape = value.split(",").map(v => parseInt(v.trim(), 10));
    }
  }

  [modelInputTypeProperty.setNative](value: MLKitCustomModelType) {
    this.modelInputType = value;
  }
}

// Register each property on the view class (required once per Property so the
// declarative markup attributes above are wired to the setNative handlers).
localModelFileProperty.register(MLKitCustomModel);
labelsFileProperty.register(MLKitCustomModel);
maxResultsProperty.register(MLKitCustomModel);
modelInputShapeProperty.register(MLKitCustomModel);
modelInputTypeProperty.register(MLKitCustomModel);

export function getLabelsFromAppFolder(labelsFile: string): Array<string> {
const labelsPath = fs.knownFolders.currentApp().path + labelsFile.substring(1);
return getLabelsFromFile(labelsPath);
Expand Down
117 changes: 81 additions & 36 deletions src/mlkit/custommodel/index.android.ts
Original file line number Diff line number Diff line change
@@ -1,71 +1,122 @@
import * as fs from "tns-core-modules/file-system";
import { ImageSource } from "tns-core-modules/image-source";
import { MLKitOptions } from "../";
import { MLKitCustomModelOptions, MLKitCustomModelResult, MLKitCustomModelResultValue } from "./";
import { getLabelsFromAppFolder, MLKitCustomModel as MLKitCustomModelBase } from "./custommodel-common";
import * as fs from "tns-core-modules/file-system";

declare const com: any;
declare const org: any; // TODO remove after regenerating typings

export class MLKitCustomModel extends MLKitCustomModelBase {
private detector;
private onFailureListener;
private inputOutputOptions;

protected createDetector(): any {
return getInterpreter(null); // TODO
this.detector = getInterpreter(this.localModelFile);
return this.detector;
}

protected runDetector(imageByteBuffer, previewWidth, previewHeight): void {
if (this.detectorBusy) {
return;
}

this.detectorBusy = true;

if (!this.onFailureListener) {
this.onFailureListener = new com.google.android.gms.tasks.OnFailureListener({
onFailure: exception => {
console.log(exception.getMessage());
this.detectorBusy = false;
}
});
}

const modelExpectsWidth = this.modelInputShape[1];
const modelExpectsHeight = this.modelInputShape[2];
const isQuantized = this.modelInputType !== "FLOAT32";

if (!this.inputOutputOptions) {
let intArrayIn = Array.create("int", 4);
intArrayIn[0] = this.modelInputShape[0];
intArrayIn[1] = modelExpectsWidth;
intArrayIn[2] = modelExpectsHeight;
intArrayIn[3] = this.modelInputShape[3];

const inputType = isQuantized ? com.google.firebase.ml.custom.FirebaseModelDataType.BYTE : com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;

let intArrayOut = Array.create("int", 2);
intArrayOut[0] = 1;
intArrayOut[1] = this.labels.length;

this.inputOutputOptions = new com.google.firebase.ml.custom.FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, inputType, intArrayIn)
.setOutputFormat(0, inputType, intArrayOut)
.build();
}

const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.byteBufferToByteBuffer(imageByteBuffer, previewWidth, previewHeight, modelExpectsWidth, modelExpectsHeight, isQuantized);
const inputs = new com.google.firebase.ml.custom.FirebaseModelInputs.Builder()
.add(input) // add as many input arrays as your model requires
.build();

this.detector
.run(inputs, this.inputOutputOptions)
.addOnSuccessListener(this.onSuccessListener)
.addOnFailureListener(this.onFailureListener);
}

protected createSuccessListener(): any {
return new com.google.android.gms.tasks.OnSuccessListener({
onSuccess: labels => {
this.onSuccessListener = new com.google.android.gms.tasks.OnSuccessListener({
onSuccess: output => {
const probabilities: Array<number> = output.getOutput(0)[0];

if (labels.size() === 0) return;
if (this.labels.length !== probabilities.length) {
console.log(`The number of labels (${this.labels.length}) is not equal to the interpretation result (${probabilities.length})!`);
return;
}

const result = <MLKitCustomModelResult>{
result: []
result: getSortedResult(this.labels, probabilities, this.maxResults)
};

// see https://github.com/firebase/quickstart-android/blob/0f4c86877fc5f771cac95797dffa8bd026dd9dc7/mlkit/app/src/main/java/com/google/firebase/samples/apps/mlkit/textrecognition/TextRecognitionProcessor.java#L62
for (let i = 0; i < labels.size(); i++) {
const label = labels.get(i);
result.result.push({
text: label.getLabel(),
confidence: label.getConfidence()
});
}

this.notify({
eventName: MLKitCustomModel.scanResultEvent,
object: this,
value: result
});

this.detectorBusy = false;
}
});

return this.onSuccessListener;
}
}

// TODO should probably cache this
function getInterpreter(options: MLKitCustomModelOptions): any {
function getInterpreter(localModelFile?: string): any {
const firModelOptionsBuilder = new com.google.firebase.ml.custom.FirebaseModelOptions.Builder();

let localModelRegistrationSuccess = false;
let cloudModelRegistrationSuccess = false;
let localModelName;

if (options.localModelFile) {
localModelName = options.localModelFile.lastIndexOf("/") === -1 ? options.localModelFile : options.localModelFile.substring(options.localModelFile.lastIndexOf("/") + 1);
if (localModelFile) {
localModelName = localModelFile.lastIndexOf("/") === -1 ? localModelFile : localModelFile.substring(localModelFile.lastIndexOf("/") + 1);

if (com.google.firebase.ml.custom.FirebaseModelManager.getInstance().getLocalModelSource(localModelName)) {
localModelRegistrationSuccess = true;
firModelOptionsBuilder.setLocalModelName(localModelName)
} else {
console.log("model not yet loaded: " + options.localModelFile);
console.log("model not yet loaded: " + localModelFile);

const firModelLocalSourceBuilder = new com.google.firebase.ml.custom.model.FirebaseLocalModelSource.Builder(localModelName);

if (options.localModelFile.indexOf("~/") === 0) {
firModelLocalSourceBuilder.setFilePath(fs.knownFolders.currentApp().path + options.localModelFile.substring(1));
if (localModelFile.indexOf("~/") === 0) {
firModelLocalSourceBuilder.setFilePath(fs.knownFolders.currentApp().path + localModelFile.substring(1));
} else {
// note that this doesn't seem to work, let's advise users to use ~/ for now
firModelLocalSourceBuilder.setAssetFilePath(options.localModelFile);
firModelLocalSourceBuilder.setAssetFilePath(localModelFile);
}

localModelRegistrationSuccess = com.google.firebase.ml.custom.FirebaseModelManager.getInstance().registerLocalModelSource(firModelLocalSourceBuilder.build());
Expand All @@ -91,7 +142,7 @@ function getInterpreter(options: MLKitCustomModelOptions): any {
export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitCustomModelResult> {
return new Promise((resolve, reject) => {
try {
const interpreter = getInterpreter(options);
const interpreter = getInterpreter(options.localModelFile);

let labels: Array<string>;
if (options.labelsFile.indexOf("~/") === 0) {
Expand Down Expand Up @@ -130,7 +181,8 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
intArrayIn[2] = options.modelInput[0].shape[2];
intArrayIn[3] = options.modelInput[0].shape[3];

const inputType = options.modelInput[0].type === "FLOAT32" ? com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32 : com.google.firebase.ml.custom.FirebaseModelDataType.BYTE;
const isQuantized = options.modelInput[0].type !== "FLOAT32";
const inputType = isQuantized ? com.google.firebase.ml.custom.FirebaseModelDataType.BYTE : com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;

let intArrayOut = Array.create("int", 2);
intArrayOut[0] = 1;
Expand All @@ -142,9 +194,7 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
.build();

const image: android.graphics.Bitmap = options.image instanceof ImageSource ? options.image.android : options.image.imageSource.android;

const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.bitmapToByteBuffer(image, options.modelInput[0].shape[1], options.modelInput[0].shape[2]);

const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.bitmapToByteBuffer(image, options.modelInput[0].shape[1], options.modelInput[0].shape[2], isQuantized);
const inputs = new com.google.firebase.ml.custom.FirebaseModelInputs.Builder()
.add(input) // add as many input arrays as your model requires
.build();
Expand All @@ -161,16 +211,11 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
});
}

function getImage(options: MLKitOptions): any /* com.google.firebase.ml.vision.common.FirebaseVisionImage */ {
const image: android.graphics.Bitmap = options.image instanceof ImageSource ? options.image.android : options.image.imageSource.android;
return com.google.firebase.ml.vision.common.FirebaseVisionImage.fromBitmap(image);
}

function getSortedResult(labels: Array<string>, probabilities: Array<number>, maxResults?: number): Array<MLKitCustomModelResultValue> {
function getSortedResult(labels: Array<string>, probabilities: Array<number>, maxResults = 5): Array<MLKitCustomModelResultValue> {
const result: Array<MLKitCustomModelResultValue> = [];
labels.forEach((text, i) => result.push({text, confidence: probabilities[i]}));
result.sort((a, b) => a.confidence < b.confidence ? 1 : (a.confidence === b.confidence ? 0 : -1));
if (maxResults && result.length > maxResults) {
if (result.length > maxResults) {
result.splice(maxResults);
}
result.map(r => r.confidence = (r.confidence & 0xff) / 255.0);
Expand Down
Loading

0 comments on commit df29885

Please sign in to comment.