diff --git a/android-sample/app/src/main/java/com/specpal/mobileai/MainActivity.java b/android-sample/app/src/main/java/com/specpal/mobileai/MainActivity.java index 97ad1ba..739e097 100644 --- a/android-sample/app/src/main/java/com/specpal/mobileai/MainActivity.java +++ b/android-sample/app/src/main/java/com/specpal/mobileai/MainActivity.java @@ -24,15 +24,18 @@ public class MainActivity extends AppCompatActivity { + //Load the tensorflow inference library static { System.loadLibrary("tensorflow_inference"); } + //PATH TO OUR MODEL FILE AND NAMES OF THE INPUT AND OUTPUT NODES private String MODEL_PATH = "file:///android_asset/squeezenet.pb"; private String INPUT_NAME = "input_1"; private String OUTPUT_NAME = "output_1"; private TensorFlowInferenceInterface tf; + //ARRAY TO HOLD THE PREDICTIONS AND FLOAT VALUES TO HOLD THE IMAGE DATA float[] PREDICTIONS = new float[1000]; private float[] floatValues; private int[] INPUT_SIZE = {224,224,3}; @@ -46,12 +49,14 @@ protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); - tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH); - Toolbar toolbar = (Toolbar) findViewById(R.id.toolbar); setSupportActionBar(toolbar); + + //initialize tensorflow with the AssetManager and the Model + tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH); + imageView = (ImageView) findViewById(R.id.imageview); resultView = (TextView) findViewById(R.id.results); @@ -66,6 +71,7 @@ public void onClick(View view) { try{ + //READ THE IMAGE FROM ASSETS FOLDER InputStream imageStream = getAssets().open("testimage.jpg"); Bitmap bitmap = BitmapFactory.decodeStream(imageStream); @@ -84,6 +90,7 @@ public void onClick(View view) { }); } + //FUNCTION TO COMPUTE THE MAXIMUM PREDICTION AND ITS CONFIDENCE public Object[] argmax(float[] array){ @@ -110,20 +117,29 @@ public Object[] argmax(float[] array){ public void predict(final Bitmap bitmap){ + //Runs inference in background thread new AsyncTask(){ @Override protected Integer doInBackground(Integer ...params){ + //Resize the image into 224 x 224 Bitmap resized_image = ImageUtils.processBitmap(bitmap,224); + + //Normalize the pixels floatValues = ImageUtils.normalizeBitmap(resized_image,224,127.5f,1.0f); + //Pass input into the tensorflow tf.feed(INPUT_NAME,floatValues,1,224,224,3); + + //compute predictions tf.run(new String[]{OUTPUT_NAME}); + //copy the output into the PREDICTIONS array tf.fetch(OUTPUT_NAME,PREDICTIONS); + //Obtained highest prediction Object[] results = argmax(PREDICTIONS); @@ -135,10 +151,12 @@ protected Integer doInBackground(Integer ...params){ final String conf = String.valueOf(confidence * 100).substring(0,5); + //Convert predicted class index into actual label name final String label = ImageUtils.getLabel(getAssets().open("labels.json"),class_index); + //Display result on UI runOnUiThread(new Runnable() { @Override public void run() {