Merge pull request #53 from davidesalerno/additionalModels
Add feature to serve additional models without any need to rebuild th…
robmarkcole authored Nov 28, 2023
2 parents 7b42156 + 7cc4fa0 commit e1f3587
Showing 2 changed files with 106 additions and 0 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -65,6 +65,24 @@ Which should return:
```
An example request using the python requests package is in `tests/live-test.py`

## Additional models

If you would like to serve additional models beyond the 3 that ship out-of-the-box with this project, you can do so by adding an `additional` folder inside the `models` folder.

You can then request predictions from the additional models, using the following for `detection`:

```
curl -X POST "http://localhost:5000/v1/detection/{model_name}" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "image=........;type=image/jpeg"
```

replacing `{model_name}` with the name of the folder where the `model.tflite` and, optionally, the `labels.txt` files are stored.
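
For example, given a hypothetical additional model folder named `my_model`, the layout implied by the paths in `tflite-server.py` would be:

```
models/
  additional/
    my_model/
      model.tflite
      labels.txt   # optional
```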

If you would instead like to request a prediction from a `classification` model, the `curl` request template is:
```
curl -X POST "http://localhost:5000/v1/classification/{model_name}" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "image=........;type=image/jpeg"
```
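
Based on the classification handler in `tflite-server.py`, a successful response has this shape (the field values below are illustrative only):

```
{
  "label": "cat",
  "confidence": 0.92,
  "success": true
}
```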


## Add tflite-server as a service
You can run tflite-server as a [service](https://www.raspberrypi.org/documentation/linux/usage/systemd.md), which means tflite-server will automatically start on RPi boot, and can be easily started & stopped. Create the service file in the appropriate location on the RPi using: ```sudo nano /etc/systemd/system/tflite-server.service```
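A minimal sketch of such a unit file follows; the install path, user, and launch command are assumptions (where the repo lives and how you start the FastAPI app may differ), so adapt them to your setup:

```
[Unit]
Description=tflite-server
After=network.target

[Service]
# Assumed install location and user; adjust for your system
WorkingDirectory=/home/pi/tensorflow-lite-rest-server
# Assumes the app is served with uvicorn
ExecStart=/usr/bin/python3 -m uvicorn tflite-server:app --host 0.0.0.0 --port 5000
Restart=always
User=pi

[Install]
WantedBy=multi-user.target
```

After creating the file, enable and start it with `sudo systemctl enable tflite-server.service` and `sudo systemctl start tflite-server.service`.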

88 changes: 88 additions & 0 deletions tflite-server.py
@@ -11,6 +11,8 @@

from helpers import classify_image, read_labels, set_input_tensor

from os.path import exists

app = FastAPI()

# Settings
@@ -25,6 +27,9 @@
OBJ_LABELS = "models/object_detection/mobilenet_ssd_v2_coco/coco_labels.txt"
SCENE_MODEL = "models/classification/dogs-vs-cats/model.tflite"
SCENE_LABELS = "models/classification/dogs-vs-cats/labels.txt"
ADDITIONAL_PREFIX = "models/additional/"
ADDITIONAL_MODEL = "$$MODEL_NAME$$/model.tflite"
ADDITIONAL_LABELS = "$$MODEL_NAME$$/labels.txt"

# Setup object detection
obj_interpreter = tflite.Interpreter(model_path=OBJ_MODEL)
@@ -53,6 +58,25 @@
scene_labels = read_labels(SCENE_LABELS)


def build_interpreter(model_name):
    # Resolve the paths of an additional model under models/additional/{model_name}/
    model_path = ADDITIONAL_MODEL.replace("$$MODEL_NAME$$", model_name)
    model_labels = ADDITIONAL_LABELS.replace("$$MODEL_NAME$$", model_name)
    return inner_interpreter_builder(ADDITIONAL_PREFIX + model_path, ADDITIONAL_PREFIX + model_labels)


def inner_interpreter_builder(model_path, model_labels):
    # Load the tflite model, allocate tensors and read the expected input shape
    interpreter = tflite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_height = input_details[0]["shape"][1]
    input_width = input_details[0]["shape"][2]
    # The labels file is optional for additional models
    if exists(model_labels):
        labels = read_labels(model_labels)
    else:
        labels = None
    return interpreter, input_details, output_details, input_height, input_width, labels


@app.get("/")
async def info():
return """tflite-server docs at ip:port/docs"""
@@ -159,3 +183,67 @@ async def predict_scene(image: UploadFile = File(...)):
    except:
        e = sys.exc_info()[1]
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/vision/detection/{model_name}")
async def predict_additional_vision_detection(model_name: str, image: UploadFile = File(...)):
try:
interpreter, input_details, output_details, input_height, input_width, labels = build_interpreter(model_name)
contents = await image.read()
image = Image.open(io.BytesIO(contents))
image_width = image.size[0]
image_height = image.size[1]

# Format data and send to interpreter
resized_image = image.resize((input_width, input_height), Image.ANTIALIAS)
input_data = np.expand_dims(resized_image, axis=0)
interpreter.set_tensor(input_details[0]["index"], input_data)

# Process image and get predictions
interpreter.invoke()
boxes = interpreter.get_tensor(output_details[0]["index"])[0]
classes = interpreter.get_tensor(output_details[1]["index"])[0]
scores = interpreter.get_tensor(output_details[2]["index"])[0]

data = {}
items = []
for i in range(len(scores)):
if not classes[i] == 0: # Item
continue
single_item = {}
single_item["userid"] = "unknown"
if labels is not None:
single_item["label"] = labels[int(classes[i])]
single_item["confidence"] = float(scores[i])
single_item["y_min"] = int(float(boxes[i][0]) * image_height)
single_item["x_min"] = int(float(boxes[i][1]) * image_width)
single_item["y_max"] = int(float(boxes[i][2]) * image_height)
single_item["x_max"] = int(float(boxes[i][3]) * image_width)
if single_item["confidence"] < MIN_CONFIDENCE:
continue
items.append(single_item)

data["predictions"] = items
data["success"] = True
return data
except:
e = sys.exc_info()[1]
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/vision/classification/{model_name}")
async def predict_additional_vision_classification(model_name: str, image: UploadFile = File(...)):
try:
interpreter, input_details, output_details, input_height, input_width, labels = build_interpreter(model_name)
contents = await image.read()
image = Image.open(io.BytesIO(contents))
resized_image = image.resize((input_width, input_height), Image.ANTIALIAS)
results = classify_image(interpreter, image=resized_image)
label_id, prob = results[0]
data = {}
data["label"] = labels[label_id]
data["confidence"] = prob
data["success"] = True
return data
except:
e = sys.exc_info()[1]
raise HTTPException(status_code=500, detail=str(e))
