From 8cce80248e60fefde01bdcbc5c0e528680a8f1e8 Mon Sep 17 00:00:00 2001 From: Vassilis Minadakis <56068291+vassilismin@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:38:45 +0300 Subject: [PATCH] feat(JAQPOT-334): add jaqpot internal ID (#27) * chore: delete comment * feat: return all probabilities instead of the max * feat: return jaqpotInternalId with the dataset * refactor: predict_sklearn for jaqpotInternalId * chore: rename to jaqpotInternalMetadata * feat: return a predictions dict --- src/handlers/predict_sklearn.py | 41 +++++++++++++++++++++------------ src/helpers/json_to_predreq.py | 6 +++-- src/helpers/predict_methods.py | 2 +- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/handlers/predict_sklearn.py b/src/handlers/predict_sklearn.py index 426e085..6b3a5ec 100644 --- a/src/handlers/predict_sklearn.py +++ b/src/handlers/predict_sklearn.py @@ -1,11 +1,12 @@ from ..entities.prediction_request import PredictionRequestPydantic from ..helpers import model_decoder, json_to_predreq from ..helpers.predict_methods import predict_onnx, predict_proba_onnx +import numpy as np def sklearn_post_handler(request: PredictionRequestPydantic): model = model_decoder.decode(request.model["rawModel"]) - data_entry_all = json_to_predreq.decode(request) + data_entry_all, JaqpotInternalId = json_to_predreq.decode(request) prediction = predict_onnx(model, data_entry_all, request) task = request.model["task"].lower() if task == "binary_classification" or task == "multiclass_classification": @@ -13,18 +14,28 @@ def sklearn_post_handler(request: PredictionRequestPydantic): else: probabilities = [None for _ in range(len(prediction))] - results = {} - for i, feature in enumerate(request.model["dependentFeatures"]): - key = feature["key"] - values = [ - str(item[i]) if len(request.model["dependentFeatures"]) > 1 else str(item) - for item in prediction - ] - results[key] = values + final_all = [] + for jaqpot_id in JaqpotInternalId: + if len(request.model["dependentFeatures"]) == 1: + prediction = prediction.reshape(-1, 1) + jaqpot_id = int(jaqpot_id) + results = { + feature["key"]: int(prediction[jaqpot_id, i]) + if isinstance( + prediction[jaqpot_id, i], (np.int16, np.int32, np.int64, np.longlong) + ) + else float(prediction[jaqpot_id, i]) + if isinstance( + prediction[jaqpot_id, i], (np.float16, np.float32, np.float64) + ) + else prediction[jaqpot_id, i] + for i, feature in enumerate(request.model["dependentFeatures"]) + } + results["jaqpotInternalId"] = jaqpot_id + results["jaqpotInternalMetadata"] = { + "AD": None, + "Probabilities": probabilities[jaqpot_id], + } + final_all.append(results) - results["Probabilities"] = [str(prob) for prob in probabilities] - results["AD"] = [None for _ in range(len(prediction))] - - final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} - - return final_all + return {"predictions": final_all} diff --git a/src/helpers/json_to_predreq.py b/src/helpers/json_to_predreq.py index c2d94c2..16306ae 100644 --- a/src/helpers/json_to_predreq.py +++ b/src/helpers/json_to_predreq.py @@ -5,8 +5,10 @@ def decode(request): df = pd.DataFrame(request.dataset["input"]) + jaqpot_Internal_Id = [] + for i in range(len(df)): + jaqpot_Internal_Id.append(df.iloc[i]["jaqpotInternalId"]) independent_features = request.model["independentFeatures"] - # smiles_cols = [feature['key'] for feature in independent_features if feature['featureType'] == 'SMILES'] smiles_cols = [ feature["key"] for feature in independent_features @@ -34,4 +36,4 @@ def decode(request): task=request.model["task"].lower(), featurizer=featurizers, ) - return dataset + return dataset, jaqpot_Internal_Id diff --git a/src/helpers/predict_methods.py b/src/helpers/predict_methods.py index beaba52..71140e4 100644 --- a/src/helpers/predict_methods.py +++ b/src/helpers/predict_methods.py @@ -61,6 +61,6 @@ def predict_proba_onnx(model, dataset: JaqpotpyDataset, request): ) onnx_probs = sess.run(None, input_feed) onnx_probs_list = [ - max(onnx_probs[1][instance].values()) for instance in range(len(onnx_probs[1])) + onnx_probs[1][instance] for instance in range(len(onnx_probs[1])) ] return onnx_probs_list