Skip to content

Commit

Permalink
Merge branch 'main' into feat/JAQPOT-238/support_onnx_models
Browse files Browse the repository at this point in the history
  • Loading branch information
vassilismin authored Sep 17, 2024
2 parents b6933ff + 6a2b4de commit 3b8e795
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/handlers/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,24 @@ def to_numpy(tensor):
),
}
ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs)))
if request.model["task"] == "BINARY_CLASSIFICATION":
return graph_binary_classification(request, ort_outs)
elif request.model["task"] == "REGRESSION":
return graph_regression(request, ort_outs)
else:
raise ValueError(
"Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
)


return graph_binary_classification(request, ort_outs)
def graph_regression(request: PredictionRequestPydantic, onnx_output):
target_name = request.model["dependentFeatures"][0]["name"]
preds = [onnx_output.squeeze().tolist()]
results = {}
results[target_name] = [str(pred) for pred in preds]
final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}
print(final_all)
return final_all


def graph_binary_classification(request: PredictionRequestPydantic, onnx_output):
Expand All @@ -86,5 +102,4 @@ def graph_binary_classification(request: PredictionRequestPydantic, onnx_output)
results[target_name] = [str(pred) for pred in preds]
final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}
print(final_all)

return final_all

0 comments on commit 3b8e795

Please sign in to comment.