feat(prediction): enhance model loading and prediction functions for YOLO, Keras, and TFLite
kshitijrajsharma committed Nov 13, 2024
1 parent 784da21 commit af94324
Showing 3 changed files with 125 additions and 71 deletions.
188 changes: 122 additions & 66 deletions predictor/prediction.py
@@ -1,5 +1,4 @@
# Standard library imports
import concurrent.futures
import os
import time
import uuid
@@ -9,17 +8,6 @@
# Third party imports
import numpy as np

try:
import tflite_runtime.interpreter as tflite
except ImportError:
print("TFlite_runtime is not installed.")
try:
from tensorflow import keras, lite

except ImportError:
print("Tensorflow is not installed , Predictions with .h5 or .tf won't work")


from .georeferencer import georeference
from .utils import open_images_keras, open_images_pillow, remove_files, save_mask

@@ -28,6 +16,107 @@
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def get_model_type(path):
if path.endswith(".pt"):
return "yolo"
elif path.endswith(".tflite"):
return "tflite"
elif path.endswith(".h5") or path.endswith(".tf"):
return "keras"
else:
raise RuntimeError("Model type not supported")


def initialize_model(path):
"""Loads either keras, tflite, or yolo model."""
model_type = get_model_type(path)

if model_type == "yolo":
try:
import torch
from ultralytics import YOLO
except ImportError: # YOLO is not installed
raise ImportError("YOLO & torch is not installed.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YOLO(path).to(device)
elif model_type == "tflite":
try:
import tflite_runtime.interpreter as tflite
except ImportError:
raise ImportError("TFlite_runtime is not installed.")
model = load_tflite_model(path)
elif model_type == "keras":
try:
from tensorflow import keras
except ImportError:
raise ImportError(
"Tensorflow is not installed, Predictions with .h5 or .tf won't work"
)
model = keras.models.load_model(path)
else:
return path
return model
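For reference, a minimal usage sketch of the new loader (not part of the commit); the checkpoint filenames below are placeholders, and it assumes the predictor package is importable:

# Illustrative only: these checkpoint paths are placeholders, not files shipped with the repo.
from predictor.prediction import get_model_type, initialize_model

for ckpt in ["ramp_model.tf", "model.tflite", "yolov8_seg.pt"]:
    kind = get_model_type(ckpt)     # "keras", "tflite" or "yolo", based on the file extension
    model = initialize_model(ckpt)  # lazily imports the matching framework and loads the model
    print(kind, type(model).__name__)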


def load_tflite_model(checkpoint_path):
    # Import locally: prefer tflite_runtime, fall back to tensorflow.lite.
    try:
        import tflite_runtime.interpreter as tflite
        interpreter = tflite.Interpreter(model_path=checkpoint_path)
    except Exception:
        from tensorflow import lite
        interpreter = lite.Interpreter(model_path=checkpoint_path)
    interpreter.allocate_tensors()
    return interpreter


def predict_tflite(interpreter, image_batch, confidence):
input_tensor_index = interpreter.get_input_details()[0]["index"]
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
images = open_images_pillow(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3).astype(np.float32)
interpreter.set_tensor(input_tensor_index, images)
interpreter.invoke()
preds = output()
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(preds > confidence, 1, 0)
return preds


def predict_keras(model, image_batch, confidence):
images = open_images_keras(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)
preds = model.predict(images)
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(preds > confidence, 1, 0)
return preds
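As a side note (dummy data, not from the commit), the argmax-then-threshold post-processing shared by predict_tflite and predict_keras can be exercised in isolation:

import numpy as np

IMAGE_SIZE = 256
dummy = np.random.rand(2, IMAGE_SIZE, IMAGE_SIZE, 2)  # fake two-class softmax output for two tiles

preds = np.argmax(dummy, axis=-1)       # per-pixel class index, 0 or 1
preds = np.expand_dims(preds, axis=-1)  # restore the channel dimension -> (2, 256, 256, 1)
preds = np.where(preds > 0.5, 1, 0)     # any confidence below 1 keeps exactly the class-1 pixels
print(preds.shape)                      # (2, 256, 256, 1)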


def predict_yolo(model, image_paths, prediction_path, confidence):
for idx in range(0, len(image_paths), BATCH_SIZE):
batch = image_paths[idx : idx + BATCH_SIZE]
for i, r in enumerate(
model.predict(batch, conf=confidence, imgsz=IMAGE_SIZE, verbose=False)
):
if hasattr(r, "masks") and r.masks is not None:
preds = (
r.masks.data.max(dim=0)[0].detach().cpu().numpy()
) # Combine masks and convert to numpy
else:
preds = np.zeros(
(
IMAGE_SIZE,
IMAGE_SIZE,
),
dtype=np.float32,
) # Default if no masks
save_mask(preds, str(f"{prediction_path}/{Path(batch[i]).stem}.png"))
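For clarity (illustration only, not in the commit), the mask-combining step above unions per-instance YOLO masks with a max over the instance dimension:

import torch

# Hypothetical stand-in for r.masks.data: three 256x256 instance masks.
masks = torch.zeros((3, 256, 256))
masks[0, :50, :50] = 1.0
masks[1, 100:150, 100:150] = 1.0

combined = masks.max(dim=0)[0].detach().cpu().numpy()  # pixel-wise union of the instance masks
print(combined.shape, combined.max())                   # (256, 256) 1.0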


def save_predictions(preds, image_batch, prediction_path):
for idx, path in enumerate(image_batch):
save_mask(preds[idx], str(f"{prediction_path}/{Path(path).stem}.png"))


def run_prediction(
checkpoint_path: str,
input_path: str,
@@ -58,84 +147,51 @@ def run_prediction(
)
"""
if prediction_path is None:
# Generate a temporary download path using a UUID
temp_dir = os.path.join("/tmp", "prediction", str(uuid.uuid4()))
os.makedirs(temp_dir, exist_ok=True)
prediction_path = temp_dir

start = time.time()
print(f"Using : {checkpoint_path}")
if checkpoint_path.endswith(".tflite"):
try:
interpreter = tflite.Interpreter(model_path=checkpoint_path)
except Exception as ex:
interpreter = lite.Interpreter(model_path=checkpoint_path)

interpreter.resize_tensor_input(
interpreter.get_input_details()[0]["index"], (BATCH_SIZE, 256, 256, 3)
)
interpreter.allocate_tensors()
input_tensor_index = interpreter.get_input_details()[0]["index"]
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
else:
model = keras.models.load_model(checkpoint_path)
model_type = get_model_type(checkpoint_path)
model = initialize_model(checkpoint_path)

print(f"It took {round(time.time()-start)} sec to load model")
start = time.time()

os.makedirs(prediction_path, exist_ok=True)
image_paths = glob(f"{input_path}/*.png")
if checkpoint_path.endswith(".tflite"):

if model_type == "tflite":
for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
if len(image_batch) != BATCH_SIZE:
interpreter.resize_tensor_input(
interpreter.get_input_details()[0]["index"],
model.resize_tensor_input(
model.get_input_details()[0]["index"],
(len(image_batch), 256, 256, 3),
)
interpreter.allocate_tensors()
input_tensor_index = interpreter.get_input_details()[0]["index"]
output = interpreter.tensor(
interpreter.get_output_details()[0]["index"]
)
images = open_images_pillow(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3).astype(np.float32)
interpreter.set_tensor(input_tensor_index, images)
interpreter.invoke()
preds = output()
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)
else:
model.allocate_tensors()
preds = predict_tflite(model, image_batch, confidence)
save_predictions(preds, image_batch, prediction_path)
elif model_type == "keras":
for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
images = open_images_keras(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)
preds = model.predict(images)
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)
preds = predict_keras(model, image_batch, confidence)
save_predictions(preds, image_batch, prediction_path)
elif model_type == "yolo":
predict_yolo(model, image_paths, prediction_path, confidence)
else:
raise RuntimeError("Loaded model is not supported")

print(
f"It took {round(time.time()-start)} sec to predict with {confidence} Confidence Threshold"
)
if not checkpoint_path.endswith(".tflite"):

    if model_type == "keras":
        from tensorflow import keras  # imported locally, as in initialize_model

        keras.backend.clear_session()
        del model

start = time.time()
georeference_path = os.path.join(prediction_path, "georeference")
georeference(
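A hedged end-to-end sketch of calling the refactored entry point (paths are placeholders; prediction_path and confidence are assumed to be keyword arguments, as the visible body suggests):

from predictor.prediction import run_prediction

# Any of .h5/.tf, .tflite or .pt checkpoints should be accepted; paths are illustrative.
run_prediction(
    "checkpoints/yolov8_seg.pt",
    "tiles/",
    prediction_path="/tmp/prediction/demo",
    confidence=0.5,
)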
5 changes: 0 additions & 5 deletions predictor/utils.py
@@ -11,11 +11,6 @@
from PIL import Image
from shapely.geometry import box

try:
from tensorflow import keras
except ImportError:
pass

IMAGE_SIZE = 256


3 changes: 3 additions & 0 deletions yolo-requirements.txt
@@ -0,0 +1,3 @@
-r requirements.txt
torch==2.5.1
ultralytics==8.3.29
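Presumably the YOLO extras are installed on top of the base requirements, along the lines of:

pip install -r yolo-requirements.txt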
