Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
robmarkcole committed Jun 11, 2020
1 parent 80e6397 commit f8cc459
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
23 changes: 23 additions & 0 deletions helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Helper utilities.
"""
import numpy as np


def read_labels(file_path):
Expand All @@ -14,3 +15,25 @@ def read_labels(file_path):
pair = line.strip().split(maxsplit=1)
ret[int(pair[0])] = pair[1].strip()
return ret


def set_input_tensor(interpreter, image):
tensor_index = interpreter.get_input_details()[0]["index"]
input_tensor = interpreter.tensor(tensor_index)()[0]
input_tensor[:, :] = image


def classify_image(interpreter, image, top_k=1):
"""Returns a sorted array of classification results."""
set_input_tensor(interpreter, image)
interpreter.invoke()
output_details = interpreter.get_output_details()[0]
output = np.squeeze(interpreter.get_tensor(output_details["index"]))

# If the model is quantized (uint8 data), then dequantize the results
if output_details["dtype"] == np.uint8:
scale, zero_point = output_details["quantization"]
output = scale * (output - zero_point)

ordered = np.argpartition(-output, top_k)
return [(i, output[i]) for i in ordered[:top_k]]
44 changes: 33 additions & 11 deletions tflite-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tflite_runtime.interpreter as tflite
from PIL import Image

from helpers import read_labels
from helpers import read_labels, set_input_tensor, classify_image

app = flask.Flask(__name__)

Expand All @@ -34,14 +34,15 @@

SCENE_URL = "/v1/vision/scene"
SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite"
SCENE_LABEL = "models/classification/dogs-vs-cats/labels.txt"
SCENE_LABELS = "models/classification/dogs-vs-cats/labels.txt"


@app.route("/")
def info():
return f"""
Object detection model: {OBJ_MODEL.split("/")[-2]} \n
Face detection model: {FACE_MODEL.split("/")[-2]} \n
Scene model: {SCENE_MODEL.split("/")[-2]} \n
""".replace(
"\n", "<br>"
)
Expand All @@ -50,10 +51,6 @@ def info():
@app.route(FACE_DETECTION_URL, methods=["POST"])
def predict_face():
data = {"success": False}
print(
f"Received request from {flask.request.remote_addr} on {FACE_DETECTION_URL}",
file=sys.stderr,
)
if not flask.request.method == "POST":
return

Expand Down Expand Up @@ -101,10 +98,6 @@ def predict_face():
@app.route(OBJ_DETECTION_URL, methods=["POST"])
def predict_object():
data = {"success": False}
print(
f"Received request from {flask.request.remote_addr} on {OBJ_DETECTION_URL}",
file=sys.stderr,
)
if not flask.request.method == "POST":
return

Expand Down Expand Up @@ -146,6 +139,34 @@ def predict_object():
return flask.jsonify(data)


@app.route(SCENE_URL, methods=["POST"])
def predict_scene():
data = {"success": False}
if not flask.request.method == "POST":
return

if flask.request.files.get("image"):
# Open image and get bytes and size
image_file = flask.request.files["image"]
image_bytes = image_file.read()
image = Image.open(io.BytesIO(image_bytes)) # A PIL image
# Format data and send to interpreter
resized_image = image.resize(
(scene_input_width, scene_input_height), Image.ANTIALIAS
)
results = classify_image(scene_interpreter, image=resized_image)

print(
f"results[0]: {results[0]}", file=sys.stderr,
)
label_id, prob = results[0]

data["label"] = scene_labels[label_id]
data["confidence"] = prob
data["success"] = True
return flask.jsonify(data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Flask app exposing tflite models"
Expand All @@ -171,11 +192,12 @@ def predict_object():
face_input_width = face_input_details[0]["shape"][2] # 320

# Setup face detection
scene_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
scene_interpreter.allocate_tensors()
scene_input_details = scene_interpreter.get_input_details()
scene_output_details = scene_interpreter.get_output_details()
scene_input_height = scene_input_details[0]["shape"][1]
scene_input_width = scene_input_details[0]["shape"][2]
scene_labels = read_labels(SCENE_LABELS)

app.run(host="0.0.0.0", debug=True, port=args.port)

0 comments on commit f8cc459

Please sign in to comment.