feat(JAQPOT-391): support two onnx models #47

Merged · 8 commits · Nov 4, 2024
26 changes: 14 additions & 12 deletions src/handlers/predict_sklearn.py
@@ -7,30 +7,32 @@

def sklearn_post_handler(request: PredictionRequest) -> PredictionResponse:
model = model_decoder.decode(request.model.raw_model)
preprocessor = (
model_decoder.decode(request.model.raw_preprocessor)
if request.model.raw_preprocessor
else None
)
data_entry_all, jaqpot_row_ids = json_to_predreq.decode(request)
prediction, doa_predictions = predict_onnx(model, data_entry_all, request)
task = request.model.task.lower()
if task == "binary_classification" or task == "multiclass_classification":
probabilities = predict_proba_onnx(model, data_entry_all, request)
else:
probabilities = [None for _ in range(len(prediction))]
predicted_values, probabilities, doa_predictions = predict_onnx(
model, preprocessor, data_entry_all, request
)

predictions = []
for jaqpot_row_id in jaqpot_row_ids:
if len(request.model.dependent_features) == 1:
prediction = prediction.reshape(-1, 1)
predicted_values = predicted_values.reshape(-1, 1)
jaqpot_row_id = int(jaqpot_row_id)
results = {
feature.key: int(prediction[jaqpot_row_id, i])
feature.key: int(predicted_values[jaqpot_row_id, i])
if isinstance(
prediction[jaqpot_row_id, i],
predicted_values[jaqpot_row_id, i],
(np.int16, np.int32, np.int64, np.longlong),
)
else float(prediction[jaqpot_row_id, i])
else float(predicted_values[jaqpot_row_id, i])
if isinstance(
prediction[jaqpot_row_id, i], (np.float16, np.float32, np.float64)
predicted_values[jaqpot_row_id, i], (np.float16, np.float32, np.float64)
)
else prediction[jaqpot_row_id, i]
else predicted_values[jaqpot_row_id, i]
for i, feature in enumerate(request.model.dependent_features)
}
results["jaqpotMetadata"] = {
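Note for readers skimming the handler diff: the nested conditional inside `results = {...}` only coerces numpy scalar types into built-in Python `int`/`float` so the response serializes cleanly. A minimal standalone sketch of the same idea, with a hypothetical `to_native` helper and dummy data (not code from this PR):

```python
import numpy as np


def to_native(value):
    """Return a plain Python int/float for numpy scalars, otherwise pass through."""
    if isinstance(value, (np.int16, np.int32, np.int64, np.longlong)):
        return int(value)
    if isinstance(value, (np.float16, np.float32, np.float64)):
        return float(value)
    return value


# dummy predictions for two rows, one dependent feature
predicted_values = np.array([[1], [0]], dtype=np.int64)
row = {f"feature_{i}": to_native(predicted_values[0, i]) for i in range(predicted_values.shape[1])}
print(row)  # {'feature_0': 1}, a built-in int, safe to serialize
```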
109 changes: 53 additions & 56 deletions src/helpers/predict_methods.py
@@ -26,14 +26,14 @@ def calculate_doas(input_feed, request):
for doa_data in request.model.doas:
if doa_data.method == "LEVERAGE":
doa_method = Leverage()
doa_method.h_star = doa_data.data.hStar
doa_method.doa_matrix = doa_data.data.doaMatrix
doa_method.h_star = doa_data.data["hStar"]
doa_method.doa_matrix = doa_data.data["doaMatrix"]
elif doa_data.method == "BOUNDING_BOX":
doa_method = BoundingBox()
doa_method.bounding_box = doa_data.data.boundingBox
doa_method.bounding_box = doa_data.data["boundingBox"]
elif doa_data.method == "MEAN_VAR":
doa_method = MeanVar()
doa_method.bounds = doa_data.data.bounds
doa_method.bounds = doa_data.data["bounds"]
doa_instance_prediction[doa_method.__name__] = doa_method.predict(
pd.DataFrame(data_instance.values.reshape(1, -1))
)[0]
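The hunk above only switches the stored DOA payload from attribute access to dictionary access. For context, a rough sketch of what a leverage-based domain-of-applicability check computes, assuming the stored `doaMatrix` is the inverse Gram matrix of the training features and `hStar` the leverage cutoff (both assumptions, with invented values; this is not jaqpotpy code):

```python
import numpy as np

doa_data = {
    "doaMatrix": np.eye(3).tolist(),  # hypothetical 3-feature inverse Gram matrix
    "hStar": 0.5,                     # hypothetical leverage threshold
}

x = np.array([0.2, 0.1, 0.4])                 # one prediction instance
inv_gram = np.array(doa_data["doaMatrix"])
leverage = float(x @ inv_gram @ x)            # h_i = x_i (X^T X)^-1 x_i^T
in_domain = leverage <= doa_data["hStar"]
print(leverage, in_domain)                    # 0.21 True
```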
@@ -51,7 +51,7 @@ def calculate_doas(input_feed, request):
return doas_results


def predict_onnx(model, dataset: JaqpotpyDataset, request):
def predict_onnx(model, preprocessor, dataset: JaqpotpyDataset, request):
"""
Perform prediction using an ONNX model.
Parameters:
@@ -73,27 +73,44 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
- The request dictionary should contain the necessary configuration for preprocessors and DOAS.
"""

sess = InferenceSession(model.SerializeToString())
if preprocessor:
onnx_graph = preprocessor
else:
onnx_graph = model

# prepare initial types for preprocessing
input_feed = {}
for independent_feature in model.graph.input:
for independent_feature in onnx_graph.graph.input:
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
independent_feature.type.tensor_type.elem_type
)

if len(model.graph.input) == 1:
if len(onnx_graph.graph.input) == 1:
input_feed[independent_feature.name] = dataset.X.values.astype(np_dtype)
else:
input_feed[independent_feature.name] = (
dataset.X[independent_feature.name]
.values.astype(np_dtype)
.reshape(-1, 1)
)
if preprocessor:
preprocessor_session = InferenceSession(preprocessor.SerializeToString())
input_feed = {"input": preprocessor_session.run(None, input_feed)[0]}

if request.model.doas:
doas_results = calculate_doas(input_feed, request)
else:
doas_results = None

onnx_prediction = sess.run(None, input_feed)[0]
model_session = InferenceSession(model.SerializeToString())
for independent_feature in model.graph.input:
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
independent_feature.type.tensor_type.elem_type
)

input_feed = {
model_session.get_inputs()[0].name: input_feed["input"].astype(np_dtype)
}
onnx_prediction = model_session.run(None, input_feed)

if request.model.extra_config["preprocessors"]:
for i in reversed(range(len(request.model.extra_config["preprocessors"]))):
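This hunk is the core of the change: when the request carries a separate preprocessor ONNX graph, its input signature is used to build `input_feed`, the preprocessor session runs first, and its output is fed into the model session. A self-contained sketch of that two-session flow, with hypothetical file names, a generic float32 cast, and dummy data (the real handler builds its inputs from the request instead):

```python
import numpy as np
import onnx
from onnxruntime import InferenceSession

preprocessor = onnx.load("preprocessor.onnx")   # hypothetical path
model = onnx.load("model.onnx")                 # hypothetical path

X = np.random.rand(5, 4).astype(np.float32)     # dummy batch: 5 rows, 4 features

# 1) transform the raw features with the preprocessor graph
pre_sess = InferenceSession(preprocessor.SerializeToString())
pre_input_name = pre_sess.get_inputs()[0].name
transformed = pre_sess.run(None, {pre_input_name: X})[0]

# 2) feed the transformed array into the model graph
model_sess = InferenceSession(model.SerializeToString())
model_input_name = model_sess.get_inputs()[0].name
outputs = model_sess.run(None, {model_input_name: transformed.astype(np.float32)})

predictions = outputs[0]   # for sklearn-onnx classifiers, outputs[1] typically holds per-class probabilities
```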
@@ -108,56 +108,36 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
len(request.model.dependent_features) == 1
and preprocessor_name != "LabelEncoder"
):
onnx_prediction = preprocessor_recreated.inverse_transform(
onnx_prediction.reshape(-1, 1)
onnx_prediction[0] = preprocessor_recreated.inverse_transform(
onnx_prediction[0].reshape(-1, 1)
)
onnx_prediction = preprocessor_recreated.inverse_transform(onnx_prediction)
onnx_prediction[0] = preprocessor_recreated.inverse_transform(
onnx_prediction[0]
)

if len(request.model.dependent_features) == 1:
onnx_prediction = onnx_prediction.flatten()

return onnx_prediction, doas_results


def predict_proba_onnx(model, dataset: JaqpotpyDataset, request):
"""
Predict the probability estimates for a given dataset using an ONNX model.
Parameters:
model (onnx.ModelProto): The ONNX model used for prediction.
dataset (JaqpotpyDataset): The dataset containing the features for prediction.
request (dict): A dictionary containing additional request information, including model configuration.
Returns:
list: A list of dictionaries where each dictionary contains the predicted probabilities for each class,
with class labels as keys and rounded probability values as values.
"""
onnx_prediction[0] = onnx_prediction[0].flatten()

sess = InferenceSession(model.SerializeToString())
input_feed = {}
for independent_feature in model.graph.input:
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
independent_feature.type.tensor_type.elem_type
)
if len(model.graph.input) == 1:
input_feed[independent_feature.name] = dataset.X.values.astype(np_dtype)
else:
input_feed[independent_feature.name] = (
dataset.X[independent_feature.name]
.values.astype(np_dtype)
.reshape(-1, 1)
)
onnx_probs = sess.run(None, input_feed)
# Probabilities estimation
probs_list = []
for instance in onnx_probs[1]:
rounded_instance = {k: round(v, 3) for k, v in instance.items()}
if (
request.model.extra_config["preprocessors"]
and request.model.extra_config["preprocessors"][0]["name"] == "LabelEncoder"
):
labels = request.model.extra_config["preprocessors"][0]["config"][
"classes_"
]
rounded_instance = {labels[k]: v for k, v in rounded_instance.items()}
if request.model.task.lower() in [
"binary_classification",
"multiclass_classification",
]:
for instance in onnx_prediction[1]:
rounded_instance = {k: round(v, 3) for k, v in instance.items()}
if (
request.model.extra_config["preprocessors"]
and request.model.extra_config["preprocessors"][0]["name"]
== "LabelEncoder"
):
labels = request.model.extra_config["preprocessors"][0]["config"][
"classes_"
]
rounded_instance = {labels[k]: v for k, v in rounded_instance.items()}

probs_list.append(rounded_instance)
probs_list.append(rounded_instance)
else:
probs_list = [None for _ in range(len(onnx_prediction[0]))]

return probs_list
return onnx_prediction[0], probs_list, doas_results
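After the diff: for binary or multiclass tasks, the second ONNX output (a list of per-row {class: probability} dicts, as sklearn-onnx classifiers usually emit via ZipMap) is rounded and, when a LabelEncoder preprocessor was used, remapped to the original class labels; for other tasks `probs_list` is filled with `None`. A toy sketch of that post-processing with invented labels and values:

```python
# stand-in for onnx_prediction[1]: one {class index: probability} dict per row
onnx_probabilities = [{0: 0.87213, 1: 0.12787}, {0: 0.33333, 1: 0.66667}]
label_classes = ["inactive", "active"]          # hypothetical LabelEncoder classes_

probs_list = []
for instance in onnx_probabilities:
    rounded = {k: round(v, 3) for k, v in instance.items()}
    # map integer class indices back to the original labels
    rounded = {label_classes[k]: v for k, v in rounded.items()}
    probs_list.append(rounded)

print(probs_list)
# [{'inactive': 0.872, 'active': 0.128}, {'inactive': 0.333, 'active': 0.667}]
```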