Skip to content

Commit

Permalink
feat(JAQPOT-334): add jaqpot internal ID (#27)
Browse files Browse the repository at this point in the history
* chore: delete comment

* feat: return all probabilities instead of the max

* feat: return jaqpotInternalId with the dataset

* refactor: predict_sklearn for jaqpotInternalId

* chore: rename to jaqpotInternalMetadata

* feat: return a predictions dict
  • Loading branch information
vassilismin authored Oct 11, 2024
1 parent fa455e8 commit 8cce802
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 18 deletions.
41 changes: 26 additions & 15 deletions src/handlers/predict_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
from ..entities.prediction_request import PredictionRequestPydantic
from ..helpers import model_decoder, json_to_predreq
from ..helpers.predict_methods import predict_onnx, predict_proba_onnx
import numpy as np


def sklearn_post_handler(request: PredictionRequestPydantic):
model = model_decoder.decode(request.model["rawModel"])
data_entry_all = json_to_predreq.decode(request)
data_entry_all, JaqpotInternalId = json_to_predreq.decode(request)
prediction = predict_onnx(model, data_entry_all, request)
task = request.model["task"].lower()
if task == "binary_classification" or task == "multiclass_classification":
probabilities = predict_proba_onnx(model, data_entry_all, request)
else:
probabilities = [None for _ in range(len(prediction))]

results = {}
for i, feature in enumerate(request.model["dependentFeatures"]):
key = feature["key"]
values = [
str(item[i]) if len(request.model["dependentFeatures"]) > 1 else str(item)
for item in prediction
]
results[key] = values
final_all = []
for jaqpot_id in JaqpotInternalId:
if len(request.model["dependentFeatures"]) == 1:
prediction = prediction.reshape(-1, 1)
jaqpot_id = int(jaqpot_id)
results = {
feature["key"]: int(prediction[jaqpot_id, i])
if isinstance(
prediction[jaqpot_id, i], (np.int16, np.int32, np.int64, np.longlong)
)
else float(prediction[jaqpot_id, i])
if isinstance(
prediction[jaqpot_id, i], (np.float16, np.float32, np.float64)
)
else prediction[jaqpot_id, i]
for i, feature in enumerate(request.model["dependentFeatures"])
}
results["jaqpotInternalId"] = jaqpot_id
results["jaqpotInternalMetadata"] = {
"AD": None,
"Probabilities": probabilities[jaqpot_id],
}
final_all.append(results)

results["Probabilities"] = [str(prob) for prob in probabilities]
results["AD"] = [None for _ in range(len(prediction))]

final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}

return final_all
return {"predictions": final_all}
6 changes: 4 additions & 2 deletions src/helpers/json_to_predreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

def decode(request):
df = pd.DataFrame(request.dataset["input"])
jaqpot_Internal_Id = []
for i in range(len(df)):
jaqpot_Internal_Id.append(df.iloc[i]["jaqpotInternalId"])
independent_features = request.model["independentFeatures"]
# smiles_cols = [feature['key'] for feature in independent_features if feature['featureType'] == 'SMILES']
smiles_cols = [
feature["key"]
for feature in independent_features
Expand Down Expand Up @@ -34,4 +36,4 @@ def decode(request):
task=request.model["task"].lower(),
featurizer=featurizers,
)
return dataset
return dataset, jaqpot_Internal_Id
2 changes: 1 addition & 1 deletion src/helpers/predict_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ def predict_proba_onnx(model, dataset: JaqpotpyDataset, request):
)
onnx_probs = sess.run(None, input_feed)
onnx_probs_list = [
max(onnx_probs[1][instance].values()) for instance in range(len(onnx_probs[1]))
onnx_probs[1][instance] for instance in range(len(onnx_probs[1]))
]
return onnx_probs_list

0 comments on commit 8cce802

Please sign in to comment.