Skip to content

Commit

Permalink
Update MainActivity.java
Browse files Browse the repository at this point in the history
  • Loading branch information
johnolafenwa authored Jun 2, 2018
1 parent 5f1dc48 commit 683c690
Showing 1 changed file with 20 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,18 @@

public class MainActivity extends AppCompatActivity {

//Load the tensorflow inference library
static {
System.loadLibrary("tensorflow_inference");
}

//PATH TO OUR MODEL FILE AND NAMES OF THE INPUT AND OUTPUT NODES
private String MODEL_PATH = "file:///android_asset/squeezenet.pb";
private String INPUT_NAME = "input_1";
private String OUTPUT_NAME = "output_1";
private TensorFlowInferenceInterface tf;

//ARRAY TO HOLD THE PREDICTIONS AND FLOAT VALUES TO HOLD THE IMAGE DATA
float[] PREDICTIONS = new float[1000];
private float[] floatValues;
private int[] INPUT_SIZE = {224,224,3};
Expand All @@ -46,12 +49,14 @@ protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);

tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH);


Toolbar toolbar = (Toolbar) findViewById(R.id.toolbar);
setSupportActionBar(toolbar);


//initialize tensorflow with the AssetManager and the Model
tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH);

imageView = (ImageView) findViewById(R.id.imageview);
resultView = (TextView) findViewById(R.id.results);

Expand All @@ -66,6 +71,7 @@ public void onClick(View view) {

try{

//READ THE IMAGE FROM ASSETS FOLDER
InputStream imageStream = getAssets().open("testimage.jpg");

Bitmap bitmap = BitmapFactory.decodeStream(imageStream);
Expand All @@ -84,6 +90,7 @@ public void onClick(View view) {
});
}

//FUNCTION TO COMPUTE THE MAXIMUM PREDICTION AND ITS CONFIDENCE
public Object[] argmax(float[] array){


Expand All @@ -110,20 +117,29 @@ public Object[] argmax(float[] array){
public void predict(final Bitmap bitmap){


//Runs inference in background thread
new AsyncTask<Integer,Integer,Integer>(){

@Override

protected Integer doInBackground(Integer ...params){

//Resize the image into 224 x 224
Bitmap resized_image = ImageUtils.processBitmap(bitmap,224);

//Normalize the pixels
floatValues = ImageUtils.normalizeBitmap(resized_image,224,127.5f,1.0f);

//Pass input into the tensorflow
tf.feed(INPUT_NAME,floatValues,1,224,224,3);

//compute predictions
tf.run(new String[]{OUTPUT_NAME});

//copy the output into the PREDICTIONS array
tf.fetch(OUTPUT_NAME,PREDICTIONS);

//Obtained highest prediction
Object[] results = argmax(PREDICTIONS);


Expand All @@ -135,10 +151,12 @@ protected Integer doInBackground(Integer ...params){

final String conf = String.valueOf(confidence * 100).substring(0,5);

//Convert predicted class index into actual label name
final String label = ImageUtils.getLabel(getAssets().open("labels.json"),class_index);



//Display result on UI
runOnUiThread(new Runnable() {
@Override
public void run() {
Expand Down

0 comments on commit 683c690

Please sign in to comment.