Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(JAQPOT-364): Implement DOA in inference #34

Merged
merged 18 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ fastapi==0.111.0
pydantic==2.7.1
uvicorn==0.29.0
starlette~=0.37.2
jaqpotpy==6.8.3
jaqpotpy==6.9.0

12 changes: 7 additions & 5 deletions src/handlers/predict_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from ..helpers.predict_methods import predict_onnx, predict_proba_onnx
import numpy as np

from jaqpotpy.doa import Leverage


def sklearn_post_handler(request: PredictionRequestPydantic):
model = model_decoder.decode(request.model["rawModel"])
data_entry_all, jaqpot_row_ids = json_to_predreq.decode(request)
prediction = predict_onnx(model, data_entry_all, request)
prediction, doa_predictions = predict_onnx(model, data_entry_all, request)
task = request.model["task"].lower()
if task == "binary_classification" or task == "multiclass_classification":
probabilities = predict_proba_onnx(model, data_entry_all, request)
Expand All @@ -22,7 +24,8 @@ def sklearn_post_handler(request: PredictionRequestPydantic):
results = {
feature["key"]: int(prediction[jaqpot_row_id, i])
if isinstance(
prediction[jaqpot_row_id, i], (np.int16, np.int32, np.int64, np.longlong)
prediction[jaqpot_row_id, i],
(np.int16, np.int32, np.int64, np.longlong),
)
else float(prediction[jaqpot_row_id, i])
if isinstance(
Expand All @@ -32,10 +35,9 @@ def sklearn_post_handler(request: PredictionRequestPydantic):
for i, feature in enumerate(request.model["dependentFeatures"])
}
results["jaqpotMetadata"] = {
"AD": None,
"doa": doa_predictions[jaqpot_row_id] if doa_predictions else None,
"probabilities": probabilities[jaqpot_row_id],
"jaqpotRowId": jaqpot_row_id
"jaqpotRowId": jaqpot_row_id,
}
final_all.append(results)

return {"predictions": final_all}
90 changes: 86 additions & 4 deletions src/helpers/predict_methods.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,78 @@
import numpy as np
import pandas as pd
import onnx
from onnxruntime import InferenceSession
from jaqpotpy.datasets import JaqpotpyDataset
from src.helpers.recreate_preprocessor import recreate_preprocessor
from jaqpotpy.doa import Leverage, BoundingBox, MeanVar


def calculate_doas(input_feed, request):
"""
Calculate the Domain of Applicability (DoA) for given input data using specified methods.
Args:
input_feed (dict): A dictionary containing the input data under the key "input".
request (object): An object containing the model information, specifically the DoA methods
and their corresponding data under the key "model".
Returns:
list: A list of dictionaries where each dictionary contains the DoA predictions for a single
data instance. The keys in the dictionary are the names of the DoA methods used, and
the values are the corresponding DoA predictions.
"""

doas_results = []
input_df = pd.DataFrame(input_feed["input"])
for _, data_instance in input_df.iterrows():
doa_instance_prediction = {}
for doa_data in request.model["doas"]:
if doa_data["method"] == "LEVERAGE":
doa_method = Leverage()
doa_method.h_star = doa_data["doaData"]["hStar"]
doa_method.doa_matrix = doa_data["doaData"]["doaMatrix"]
elif doa_data["method"] == "BOUNDING_BOX":
doa_method = BoundingBox()
doa_method.bounding_box = doa_data["doaData"]["boundingBox"]
elif doa_data["method"] == "MEAN_VAR":
doa_method = MeanVar()
doa_method.bounds = doa_data["doaData"]["bounds"]
doa_instance_prediction[doa_method.__name__] = doa_method.predict(
pd.DataFrame(data_instance.values.reshape(1, -1))
)[0]
# Majority voting
if len(request.model["doas"]) > 1:
in_doa_values = [
value["inDoa"] for value in doa_instance_prediction.values()
]
doa_instance_prediction["majorityVoting"] = (
in_doa_values.count(True) > (len(in_doa_values) / 2)
)
else:
doa_instance_prediction["majorityVoting"] = None
doas_results.append(doa_instance_prediction)
return doas_results


def predict_onnx(model, dataset: JaqpotpyDataset, request):
"""
Perform prediction using an ONNX model.
Parameters:
model (onnx.ModelProto): The ONNX model to be used for prediction.
dataset (JaqpotpyDataset): The dataset containing the input features.
request (dict): A dictionary containing additional configuration for the prediction,
including model-specific settings and preprocessors.
Returns:
tuple: A tuple containing the ONNX model predictions and DOA results (if applicable).
The function performs the following steps:
1. Initializes an ONNX InferenceSession with the serialized model.
2. Prepares the input feed by converting dataset features to the appropriate numpy data types.
3. If doas (Domain of Applicability) is requested, it calculates the DOAS results.
4. Runs the ONNX model to get predictions.
5. Applies any specified preprocessors in reverse order to the predictions.
6. Flattens the predictions if there is only one dependent feature.
Note:
- The function assumes that the dataset features and model inputs are aligned.
- The request dictionary should contain the necessary configuration for preprocessors and DOAS.
"""

sess = InferenceSession(model.SerializeToString())
input_feed = {}
for independent_feature in model.graph.input:
Expand All @@ -21,8 +88,12 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
.values.astype(np_dtype)
.reshape(-1, 1)
)
onnx_prediction = sess.run(None, input_feed)
onnx_prediction = onnx_prediction[0]
if request.model["doas"]:
doas_results = calculate_doas(input_feed, request)
else:
doas_results = None

onnx_prediction = sess.run(None, input_feed)[0]

if request.model["extraConfig"]["preprocessors"]:
for i in reversed(range(len(request.model["extraConfig"]["preprocessors"]))):
Expand All @@ -45,10 +116,21 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
if len(request.model["dependentFeatures"]) == 1:
onnx_prediction = onnx_prediction.flatten()

return onnx_prediction
return onnx_prediction, doas_results


def predict_proba_onnx(model, dataset: JaqpotpyDataset, request):
"""
Predict the probability estimates for a given dataset using an ONNX model.
Parameters:
model (onnx.ModelProto): The ONNX model used for prediction.
dataset (JaqpotpyDataset): The dataset containing the features for prediction.
request (dict): A dictionary containing additional request information, including model configuration.
Returns:
list: A list of dictionaries where each dictionary contains the predicted probabilities for each class,
with class labels as keys and rounded probability values as values.
"""

sess = InferenceSession(model.SerializeToString())
input_feed = {}
for independent_feature in model.graph.input:
Expand Down
Loading