Skip to content

Commit

Permalink
feat(JAQPOT-361): support jaqpot metadata (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
alarv authored Oct 16, 2024
1 parent bb1e345 commit 561ccc7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
20 changes: 10 additions & 10 deletions src/handlers/predict_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def sklearn_post_handler(request: PredictionRequestPydantic):
model = model_decoder.decode(request.model["rawModel"])
data_entry_all, JaqpotInternalId = json_to_predreq.decode(request)
data_entry_all, jaqpot_row_ids = json_to_predreq.decode(request)
prediction = predict_onnx(model, data_entry_all, request)
task = request.model["task"].lower()
if task == "binary_classification" or task == "multiclass_classification":
Expand All @@ -15,26 +15,26 @@ def sklearn_post_handler(request: PredictionRequestPydantic):
probabilities = [None for _ in range(len(prediction))]

final_all = []
for jaqpot_id in JaqpotInternalId:
for jaqpot_row_id in jaqpot_row_ids:
if len(request.model["dependentFeatures"]) == 1:
prediction = prediction.reshape(-1, 1)
jaqpot_id = int(jaqpot_id)
jaqpot_row_id = int(jaqpot_row_id)
results = {
feature["key"]: int(prediction[jaqpot_id, i])
feature["key"]: int(prediction[jaqpot_row_id, i])
if isinstance(
prediction[jaqpot_id, i], (np.int16, np.int32, np.int64, np.longlong)
prediction[jaqpot_row_id, i], (np.int16, np.int32, np.int64, np.longlong)
)
else float(prediction[jaqpot_id, i])
else float(prediction[jaqpot_row_id, i])
if isinstance(
prediction[jaqpot_id, i], (np.float16, np.float32, np.float64)
prediction[jaqpot_row_id, i], (np.float16, np.float32, np.float64)
)
else prediction[jaqpot_id, i]
else prediction[jaqpot_row_id, i]
for i, feature in enumerate(request.model["dependentFeatures"])
}
results["jaqpotInternalId"] = jaqpot_id
results["jaqpotInternalMetadata"] = {
"AD": None,
"Probabilities": probabilities[jaqpot_id],
"probabilities": probabilities[jaqpot_row_id],
"jaqpotRowId": jaqpot_row_id
}
final_all.append(results)

Expand Down
6 changes: 3 additions & 3 deletions src/helpers/json_to_predreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

def decode(request):
df = pd.DataFrame(request.dataset["input"])
jaqpot_Internal_Id = []
jaqpot_row_ids = []
for i in range(len(df)):
jaqpot_Internal_Id.append(df.iloc[i]["jaqpotInternalId"])
jaqpot_row_ids.append(df.iloc[i]["jaqpotRowId"])
independent_features = request.model["independentFeatures"]
smiles_cols = [
feature["key"]
Expand Down Expand Up @@ -36,4 +36,4 @@ def decode(request):
task=request.model["task"].lower(),
featurizer=featurizers,
)
return dataset, jaqpot_Internal_Id
return dataset, jaqpot_row_ids

0 comments on commit 561ccc7

Please sign in to comment.