Commit

feat(JAQPOT-391): support two onnx models (#47)
* feat: add preprocessor sklearn_post_handler in predict_onnx

* feat: support preprocessing through onnx

* feat: support doa and preprocessing

* fix: access doa data

* refactor: sklearn_post_handler

* refactor: predict_onnx
vassilismin authored Nov 4, 2024
1 parent fa2c538 commit 7cafd3a
Showing 2 changed files with 67 additions and 68 deletions.
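In broad terms, the commit replaces the single-session prediction path with a chain of two ONNX graphs: an optional preprocessor graph transforms the raw features first, and its output is fed to the model graph. The sketch below is a simplified reading of that flow, not the handler code itself; the names model_onnx, preprocessor_onnx, and X are illustrative stand-ins for the decoded request payload, and only single-input graphs are handled.

# Minimal sketch of the two-ONNX-graph flow introduced by this commit
# (assumptions: `model_onnx` and `preprocessor_onnx` are onnx.ModelProto
# objects, `X` is a 2-D NumPy feature matrix).
import numpy as np
import onnx
from onnxruntime import InferenceSession


def run_two_onnx_models(model_onnx, preprocessor_onnx, X):
    # Use whichever graph sees the raw features to derive the input dtype.
    first_graph = preprocessor_onnx if preprocessor_onnx is not None else model_onnx
    first_input = first_graph.graph.input[0]
    np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
        first_input.type.tensor_type.elem_type
    )
    feed = {first_input.name: X.astype(np_dtype)}

    if preprocessor_onnx is not None:
        # Session 1: run the ONNX preprocessor and keep its first output.
        pre_sess = InferenceSession(preprocessor_onnx.SerializeToString())
        features = pre_sess.run(None, feed)[0]
    else:
        features = feed[first_input.name]

    # Session 2: cast to the model's expected dtype and run the ONNX model.
    model_sess = InferenceSession(model_onnx.SerializeToString())
    model_dtype = onnx.helper.tensor_dtype_to_np_dtype(
        model_onnx.graph.input[0].type.tensor_type.elem_type
    )
    outputs = model_sess.run(
        None, {model_sess.get_inputs()[0].name: features.astype(model_dtype)}
    )
    return outputs  # outputs[0]: predictions; outputs[1]: probabilities, if any

The real handler additionally covers multi-input graphs, DOA calculation on the preprocessed features, and inverse-transforming predictions with recreated scikit-learn preprocessors, as the diffs below show.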
26 changes: 14 additions & 12 deletions src/handlers/predict_sklearn.py
@@ -7,30 +7,32 @@

 def sklearn_post_handler(request: PredictionRequest) -> PredictionResponse:
     model = model_decoder.decode(request.model.raw_model)
+    preprocessor = (
+        model_decoder.decode(request.model.raw_preprocessor)
+        if request.model.raw_preprocessor
+        else None
+    )
     data_entry_all, jaqpot_row_ids = json_to_predreq.decode(request)
-    prediction, doa_predictions = predict_onnx(model, data_entry_all, request)
-    task = request.model.task.lower()
-    if task == "binary_classification" or task == "multiclass_classification":
-        probabilities = predict_proba_onnx(model, data_entry_all, request)
-    else:
-        probabilities = [None for _ in range(len(prediction))]
+    predicted_values, probabilities, doa_predictions = predict_onnx(
+        model, preprocessor, data_entry_all, request
+    )

     predictions = []
     for jaqpot_row_id in jaqpot_row_ids:
         if len(request.model.dependent_features) == 1:
-            prediction = prediction.reshape(-1, 1)
+            predicted_values = predicted_values.reshape(-1, 1)
         jaqpot_row_id = int(jaqpot_row_id)
         results = {
-            feature.key: int(prediction[jaqpot_row_id, i])
+            feature.key: int(predicted_values[jaqpot_row_id, i])
             if isinstance(
-                prediction[jaqpot_row_id, i],
+                predicted_values[jaqpot_row_id, i],
                 (np.int16, np.int32, np.int64, np.longlong),
             )
-            else float(prediction[jaqpot_row_id, i])
+            else float(predicted_values[jaqpot_row_id, i])
             if isinstance(
-                prediction[jaqpot_row_id, i], (np.float16, np.float32, np.float64)
+                predicted_values[jaqpot_row_id, i], (np.float16, np.float32, np.float64)
             )
-            else prediction[jaqpot_row_id, i]
+            else predicted_values[jaqpot_row_id, i]
             for i, feature in enumerate(request.model.dependent_features)
         }
         results["jaqpotMetadata"] = {
109 changes: 53 additions & 56 deletions src/helpers/predict_methods.py
@@ -26,14 +26,14 @@ def calculate_doas(input_feed, request):
         for doa_data in request.model.doas:
             if doa_data.method == "LEVERAGE":
                 doa_method = Leverage()
-                doa_method.h_star = doa_data.data.hStar
-                doa_method.doa_matrix = doa_data.data.doaMatrix
+                doa_method.h_star = doa_data.data["hStar"]
+                doa_method.doa_matrix = doa_data.data["doaMatrix"]
             elif doa_data.method == "BOUNDING_BOX":
                 doa_method = BoundingBox()
-                doa_method.bounding_box = doa_data.data.boundingBox
+                doa_method.bounding_box = doa_data.data["boundingBox"]
             elif doa_data.method == "MEAN_VAR":
                 doa_method = MeanVar()
-                doa_method.bounds = doa_data.data.bounds
+                doa_method.bounds = doa_data.data["bounds"]
             doa_instance_prediction[doa_method.__name__] = doa_method.predict(
                 pd.DataFrame(data_instance.values.reshape(1, -1))
             )[0]
@@ -51,7 +51,7 @@ def calculate_doas(input_feed, request):
     return doas_results


-def predict_onnx(model, dataset: JaqpotpyDataset, request):
+def predict_onnx(model, preprocessor, dataset: JaqpotpyDataset, request):
     """
     Perform prediction using an ONNX model.
     Parameters:
@@ -73,27 +73,44 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
     - The request dictionary should contain the necessary configuration for preprocessors and DOAS.
     """

-    sess = InferenceSession(model.SerializeToString())
+    if preprocessor:
+        onnx_graph = preprocessor
+    else:
+        onnx_graph = model
+
+    # prepare initial types for preprocessing
     input_feed = {}
-    for independent_feature in model.graph.input:
+    for independent_feature in onnx_graph.graph.input:
         np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
             independent_feature.type.tensor_type.elem_type
         )

-        if len(model.graph.input) == 1:
+        if len(onnx_graph.graph.input) == 1:
             input_feed[independent_feature.name] = dataset.X.values.astype(np_dtype)
         else:
             input_feed[independent_feature.name] = (
                 dataset.X[independent_feature.name]
                 .values.astype(np_dtype)
                 .reshape(-1, 1)
             )
+    if preprocessor:
+        preprocessor_session = InferenceSession(preprocessor.SerializeToString())
+        input_feed = {"input": preprocessor_session.run(None, input_feed)[0]}
+
     if request.model.doas:
         doas_results = calculate_doas(input_feed, request)
     else:
         doas_results = None

-    onnx_prediction = sess.run(None, input_feed)[0]
+    model_session = InferenceSession(model.SerializeToString())
+    for independent_feature in model.graph.input:
+        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
+            independent_feature.type.tensor_type.elem_type
+        )
+
+    input_feed = {
+        model_session.get_inputs()[0].name: input_feed["input"].astype(np_dtype)
+    }
+    onnx_prediction = model_session.run(None, input_feed)

     if request.model.extra_config["preprocessors"]:
         for i in reversed(range(len(request.model.extra_config["preprocessors"]))):
@@ -108,56 +108,36 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
                 len(request.model.dependent_features) == 1
                 and preprocessor_name != "LabelEncoder"
             ):
-                onnx_prediction = preprocessor_recreated.inverse_transform(
-                    onnx_prediction.reshape(-1, 1)
+                onnx_prediction[0] = preprocessor_recreated.inverse_transform(
+                    onnx_prediction[0].reshape(-1, 1)
                 )
-            onnx_prediction = preprocessor_recreated.inverse_transform(onnx_prediction)
+            onnx_prediction[0] = preprocessor_recreated.inverse_transform(
+                onnx_prediction[0]
+            )

     if len(request.model.dependent_features) == 1:
-        onnx_prediction = onnx_prediction.flatten()
-
-    return onnx_prediction, doas_results
-
-
-def predict_proba_onnx(model, dataset: JaqpotpyDataset, request):
-    """
-    Predict the probability estimates for a given dataset using an ONNX model.
-    Parameters:
-    model (onnx.ModelProto): The ONNX model used for prediction.
-    dataset (JaqpotpyDataset): The dataset containing the features for prediction.
-    request (dict): A dictionary containing additional request information, including model configuration.
-    Returns:
-    list: A list of dictionaries where each dictionary contains the predicted probabilities for each class,
-    with class labels as keys and rounded probability values as values.
-    """
+        onnx_prediction[0] = onnx_prediction[0].flatten()

-    sess = InferenceSession(model.SerializeToString())
-    input_feed = {}
-    for independent_feature in model.graph.input:
-        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
-            independent_feature.type.tensor_type.elem_type
-        )
-        if len(model.graph.input) == 1:
-            input_feed[independent_feature.name] = dataset.X.values.astype(np_dtype)
-        else:
-            input_feed[independent_feature.name] = (
-                dataset.X[independent_feature.name]
-                .values.astype(np_dtype)
-                .reshape(-1, 1)
-            )
-    onnx_probs = sess.run(None, input_feed)
     # Probabilities estimation
     probs_list = []
-    for instance in onnx_probs[1]:
-        rounded_instance = {k: round(v, 3) for k, v in instance.items()}
-        if (
-            request.model.extra_config["preprocessors"]
-            and request.model.extra_config["preprocessors"][0]["name"] == "LabelEncoder"
-        ):
-            labels = request.model.extra_config["preprocessors"][0]["config"][
-                "classes_"
-            ]
-            rounded_instance = {labels[k]: v for k, v in rounded_instance.items()}
+    if request.model.task.lower() in [
+        "binary_classification",
+        "multiclass_classification",
+    ]:
+        for instance in onnx_prediction[1]:
+            rounded_instance = {k: round(v, 3) for k, v in instance.items()}
+            if (
+                request.model.extra_config["preprocessors"]
+                and request.model.extra_config["preprocessors"][0]["name"]
+                == "LabelEncoder"
+            ):
+                labels = request.model.extra_config["preprocessors"][0]["config"][
+                    "classes_"
+                ]
+                rounded_instance = {labels[k]: v for k, v in rounded_instance.items()}

-        probs_list.append(rounded_instance)
+            probs_list.append(rounded_instance)
+    else:
+        probs_list = [None for _ in range(len(onnx_prediction[0]))]

-    return probs_list
+    return onnx_prediction[0], probs_list, doas_results
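For classification tasks, the probability post-processing that used to live in predict_proba_onnx is now inlined after the model run. Each element of onnx_prediction[1] is iterated with .items(), i.e. a per-row mapping of class to probability (the ZipMap-style output converted scikit-learn classifiers typically expose). A standalone sketch with made-up values and a hypothetical LabelEncoder classes_ list:

# Sketch of the inlined probability post-processing (illustrative values only).
raw_probs = [{0: 0.1234, 1: 0.8766}, {0: 0.9001, 1: 0.0999}]  # stands in for onnx_prediction[1]
classes_ = ["inactive", "active"]  # hypothetical LabelEncoder classes from the preprocessor config

probs_list = []
for instance in raw_probs:
    # Round each class probability to three decimals.
    rounded_instance = {k: round(v, 3) for k, v in instance.items()}
    # Map encoded class indices back to the original labels when a LabelEncoder was used.
    rounded_instance = {classes_[k]: v for k, v in rounded_instance.items()}
    probs_list.append(rounded_instance)

print(probs_list)  # [{'inactive': 0.123, 'active': 0.877}, {'inactive': 0.9, 'active': 0.1}]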
