Skip to content

Commit

Permalink
Merge pull request #14 from robmarkcole/add-scene
Browse files Browse the repository at this point in the history
Add scene
  • Loading branch information
robmarkcole authored Jun 11, 2020
2 parents 3295a6f + f8cc459 commit 8f7d6de
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 54 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# tensorflow-lite-rest-server
Expose tensorflow-lite models via a rest API, and currently object detection is supported. Can be hosted on any of the common platforms including RPi, linux desktop, Mac and Windows.
Expose tensorflow-lite models via a rest API, and currently object, face & scene detection is supported. Can be hosted on any of the common platforms including RPi, linux desktop, Mac and Windows.

## Setup
In this process we create a virtual environment (venv), then install tensorflow-lite [as per these instructions](https://www.tensorflow.org/lite/guide/python) which is platform specific, and finally install the remaining requirements. Note on an RPi (only) it is necessary to manually install pip3, numpy, pillow.
Expand Down Expand Up @@ -58,6 +58,13 @@ To detect faces:
curl -X POST -F image=@tests/faces.jpg 'http://localhost:5000/v1/vision/face'
```

To run the scene:
```
curl -X POST -F image=@tests/cat.jpg 'http://localhost:5000/v1/vision/scene'
or
curl -X POST -F image=@tests/dog.jpg 'http://localhost:5000/v1/vision/scene'
```

## Deepstack, Home Assistant & UI
This API can be used as a drop in replacement for [deepstack object detection](https://github.com/robmarkcole/HASS-Deepstack-object) and [deepstack face detection](https://github.com/robmarkcole/HASS-Deepstack-face) (configuring `detect_only: True`) in Home Assistant. I also created a UI for viewing the predictions of the object detection model [here](https://github.com/robmarkcole/deepstack-ui).

Expand Down
66 changes: 27 additions & 39 deletions helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from PIL import ImageDraw
from typing import Tuple
"""
Helper utilities.
"""
import numpy as np

def read_coco_labels(file_path):

def read_labels(file_path):
"""
Helper for loading coco_labels.txt
Helper for loading labels.txt
"""
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
Expand All @@ -13,39 +16,24 @@ def read_coco_labels(file_path):
ret[int(pair[0])] = pair[1].strip()
return ret

def draw_box(
draw: ImageDraw,
box: Tuple[float, float, float, float],
img_width: int,
img_height: int,
text: str = "",
color: Tuple[int, int, int] = (255, 255, 0),
) -> None:
"""
Draw a bounding box on and image.
The bounding box is defined by the tuple (y_min, x_min, y_max, x_max)
where the coordinates are floats in the range [0.0, 1.0] and
relative to the width and height of the image.
For example, if an image is 100 x 200 pixels (height x width) and the bounding
box is `(0.1, 0.2, 0.5, 0.9)`, the upper-left and bottom-right coordinates of
the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
"""

line_width = 3
font_height = 8
y_min, x_min, y_max, x_max = box
(left, right, top, bottom) = (
x_min * img_width,
x_max * img_width,
y_min * img_height,
y_max * img_height,
)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
width=line_width,
fill=color,
)
if text:
draw.text(
(left + line_width, abs(top - line_width - font_height)), text, fill=color
)
def set_input_tensor(interpreter, image):
tensor_index = interpreter.get_input_details()[0]["index"]
input_tensor = interpreter.tensor(tensor_index)()[0]
input_tensor[:, :] = image


def classify_image(interpreter, image, top_k=1):
"""Returns a sorted array of classification results."""
set_input_tensor(interpreter, image)
interpreter.invoke()
output_details = interpreter.get_output_details()[0]
output = np.squeeze(interpreter.get_tensor(output_details["index"]))

# If the model is quantized (uint8 data), then dequantize the results
if output_details["dtype"] == np.uint8:
scale, zero_point = output_details["quantization"]
output = scale * (output - zero_point)

ordered = np.argpartition(-output, top_k)
return [(i, output[i]) for i in ordered[:top_k]]
5 changes: 5 additions & 0 deletions models/classification/dogs-vs-cats/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## dogs vs cats
* Custom model trained on [teachable machine image classification](https://teachablemachine.withgoogle.com/train/image)
* Dataset: 30 cat and 30 dog images from [kaggle dogs vs cats](https://www.kaggle.com/c/dogs-vs-cats)
* Input size: 224x224
* Type: tensorflow lite quantized
2 changes: 2 additions & 0 deletions models/classification/dogs-vs-cats/labels.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
0 dog
1 cat
Binary file added models/classification/dogs-vs-cats/model.tflite
Binary file not shown.
Binary file added tests/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
66 changes: 52 additions & 14 deletions tflite-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tflite_runtime.interpreter as tflite
from PIL import Image

from helpers import read_coco_labels
from helpers import read_labels, set_input_tensor, classify_image

app = flask.Flask(__name__)

Expand All @@ -32,12 +32,17 @@
OBJ_MODEL = "models/object_detection/mobilenet_ssd_v2_coco/mobilenet_ssd_v2_coco_quant_postprocess.tflite"
OBJ_LABELS = "models/object_detection/mobilenet_ssd_v2_coco/coco_labels.txt"

SCENE_URL = "/v1/vision/scene"
SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite"
SCENE_LABELS = "models/classification/dogs-vs-cats/labels.txt"


@app.route("/")
def info():
return f"""
Object detection model: {OBJ_MODEL.split("/")[-2]} \n
Face detection model: {FACE_MODEL.split("/")[-2]} \n
Scene model: {SCENE_MODEL.split("/")[-2]} \n
""".replace(
"\n", "<br>"
)
Expand All @@ -46,10 +51,6 @@ def info():
@app.route(FACE_DETECTION_URL, methods=["POST"])
def predict_face():
data = {"success": False}
print(
f"Received request from {flask.request.remote_addr} on {FACE_DETECTION_URL}",
file=sys.stderr,
)
if not flask.request.method == "POST":
return

Expand All @@ -69,7 +70,9 @@ def predict_face():
# Process image and get predictions
face_interpreter.invoke()
boxes = face_interpreter.get_tensor(face_output_details[0]["index"])[0]
classes = face_interpreter.get_tensor(face_output_details[1]["index"])[0]
classes = face_interpreter.get_tensor(face_output_details[1]["index"])[
0
]
scores = face_interpreter.get_tensor(face_output_details[2]["index"])[0]

faces = []
Expand All @@ -95,10 +98,6 @@ def predict_face():
@app.route(OBJ_DETECTION_URL, methods=["POST"])
def predict_object():
data = {"success": False}
print(
f"Received request from {flask.request.remote_addr} on {OBJ_DETECTION_URL}",
file=sys.stderr,
)
if not flask.request.method == "POST":
return

Expand Down Expand Up @@ -140,8 +139,38 @@ def predict_object():
return flask.jsonify(data)


@app.route(SCENE_URL, methods=["POST"])
def predict_scene():
data = {"success": False}
if not flask.request.method == "POST":
return

if flask.request.files.get("image"):
# Open image and get bytes and size
image_file = flask.request.files["image"]
image_bytes = image_file.read()
image = Image.open(io.BytesIO(image_bytes)) # A PIL image
# Format data and send to interpreter
resized_image = image.resize(
(scene_input_width, scene_input_height), Image.ANTIALIAS
)
results = classify_image(scene_interpreter, image=resized_image)

print(
f"results[0]: {results[0]}", file=sys.stderr,
)
label_id, prob = results[0]

data["label"] = scene_labels[label_id]
data["confidence"] = prob
data["success"] = True
return flask.jsonify(data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flask app exposing tflite models")
parser = argparse.ArgumentParser(
description="Flask app exposing tflite models"
)
parser.add_argument("--port", default=5000, type=int, help="port number")
args = parser.parse_args()

Expand All @@ -152,14 +181,23 @@ def predict_object():
obj_output_details = obj_interpreter.get_output_details()
obj_input_height = obj_input_details[0]["shape"][1]
obj_input_width = obj_input_details[0]["shape"][2]
obj_labels = read_coco_labels(OBJ_LABELS)
obj_labels = read_labels(OBJ_LABELS)

# Setup face detection
face_interpreter = tflite.Interpreter(model_path=FACE_MODEL)
face_interpreter.allocate_tensors()
face_input_details = face_interpreter.get_input_details()
face_output_details = face_interpreter.get_output_details()
face_input_height = 320
face_input_width = 320
face_input_height = face_input_details[0]["shape"][1] # 320
face_input_width = face_input_details[0]["shape"][2] # 320

# Setup face detection
scene_interpreter = tflite.Interpreter(model_path=SCENE_MODEL)
scene_interpreter.allocate_tensors()
scene_input_details = scene_interpreter.get_input_details()
scene_output_details = scene_interpreter.get_output_details()
scene_input_height = scene_input_details[0]["shape"][1]
scene_input_width = scene_input_details[0]["shape"][2]
scene_labels = read_labels(SCENE_LABELS)

app.run(host="0.0.0.0", debug=True, port=args.port)

0 comments on commit 8f7d6de

Please sign in to comment.