diff --git a/README.md b/README.md index a45e69d..7a65e1f 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,24 @@ Which should return: ``` An example request using the python requests package is in `tests/live-test.py` +## Additional models + +If you would like to serve additional models to the 3 that are shipped out-of-the-box with this project, you can do it adding an `additional` folder to the `models` one. + +You can then ask for predictions to the additonal models using this for `detection`: + +``` +curl -X POST "http://localhost:5000/v1/detection/{model_name}" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "image=........;type=image/jpeg" +``` + +replacing `{model_name}` with the folder name where are store the `model.tflite` and optionally the `labels.txt` files. + +If you would like instead ask for a prediction to a `classification` model, the `curl` request template is: +``` +curl -X POST "http://localhost:5000/v1/classification/{model_name}" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "image=........;type=image/jpeg" +``` + + ## Add tflite-server as a service You can run tflite-server as a [service](https://www.raspberrypi.org/documentation/linux/usage/systemd.md), which means tflite-server will automatically start on RPi boot, and can be easily started & stopped. Create the service file in the appropriate location on the RPi using: ```sudo nano /etc/systemd/system/tflite-server.service``` diff --git a/tflite-server.py b/tflite-server.py index 4271498..bc22aa5 100644 --- a/tflite-server.py +++ b/tflite-server.py @@ -11,6 +11,8 @@ from helpers import classify_image, read_labels, set_input_tensor +from os.path import exists + app = FastAPI() # Settings @@ -25,6 +27,9 @@ OBJ_LABELS = "models/object_detection/mobilenet_ssd_v2_coco/coco_labels.txt" SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite" SCENE_LABELS = "models/classification/dogs-vs-cats/labels.txt" +ADDITIONAL_PREFIX = "models/additional/" +ADDITIONAL_MODEL = "$$MODEL_NAME$$/model.tflite" +ADDITIONAL_LABELS = "$$MODEL_NAME$$/labels.txt" # Setup object detection obj_interpreter = tflite.Interpreter(model_path=OBJ_MODEL) @@ -53,6 +58,25 @@ scene_labels = read_labels(SCENE_LABELS) +def build_interpreter(model_name): + model_path = ADDITIONAL_MODEL.replace("$$MODEL_NAME$$", model_name) + model_labels = ADDITIONAL_LABELS.replace("$$MODEL_NAME$$", model_name) + return inner_interpreter_builder(ADDITIONAL_PREFIX+model_path, ADDITIONAL_PREFIX+model_labels) + +def inner_interpreter_builder(model_path, model_labels): + interpreter = tflite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + input_height = input_details[0]["shape"][1] + input_width = input_details[0]["shape"][2] + file_exists = exists(model_labels) + if file_exists: + labels = read_labels(model_labels) + else: + labels = None + return interpreter, input_details, output_details, input_height, input_width, labels + @app.get("/") async def info(): return """tflite-server docs at ip:port/docs""" @@ -159,3 +183,67 @@ async def predict_scene(image: UploadFile = File(...)): except: e = sys.exc_info()[1] raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/v1/vision/detection/{model_name}") +async def predict_additional_vision_detection(model_name: str, image: UploadFile = File(...)): + try: + interpreter, input_details, output_details, input_height, input_width, labels = build_interpreter(model_name) + contents = await image.read() + image = Image.open(io.BytesIO(contents)) + image_width = image.size[0] + image_height = image.size[1] + + # Format data and send to interpreter + resized_image = image.resize((input_width, input_height), Image.ANTIALIAS) + input_data = np.expand_dims(resized_image, axis=0) + interpreter.set_tensor(input_details[0]["index"], input_data) + + # Process image and get predictions + interpreter.invoke() + boxes = interpreter.get_tensor(output_details[0]["index"])[0] + classes = interpreter.get_tensor(output_details[1]["index"])[0] + scores = interpreter.get_tensor(output_details[2]["index"])[0] + + data = {} + items = [] + for i in range(len(scores)): + if not classes[i] == 0: # Item + continue + single_item = {} + single_item["userid"] = "unknown" + if labels is not None: + single_item["label"] = labels[int(classes[i])] + single_item["confidence"] = float(scores[i]) + single_item["y_min"] = int(float(boxes[i][0]) * image_height) + single_item["x_min"] = int(float(boxes[i][1]) * image_width) + single_item["y_max"] = int(float(boxes[i][2]) * image_height) + single_item["x_max"] = int(float(boxes[i][3]) * image_width) + if single_item["confidence"] < MIN_CONFIDENCE: + continue + items.append(single_item) + + data["predictions"] = items + data["success"] = True + return data + except: + e = sys.exc_info()[1] + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/v1/vision/classification/{model_name}") +async def predict_additional_vision_classification(model_name: str, image: UploadFile = File(...)): + try: + interpreter, input_details, output_details, input_height, input_width, labels = build_interpreter(model_name) + contents = await image.read() + image = Image.open(io.BytesIO(contents)) + resized_image = image.resize((input_width, input_height), Image.ANTIALIAS) + results = classify_image(interpreter, image=resized_image) + label_id, prob = results[0] + data = {} + data["label"] = labels[label_id] + data["confidence"] = prob + data["success"] = True + return data + except: + e = sys.exc_info()[1] + raise HTTPException(status_code=500, detail=str(e)) +