From 6a2b4de6150b8c3f279c439d085c0292ae101097 Mon Sep 17 00:00:00 2001 From: john savvas <149728671+johnsaveus@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:10:21 +0300 Subject: [PATCH] inference (#17) handling regression --- src/handlers/predict.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/handlers/predict.py b/src/handlers/predict.py index 6b88ee5..ff4d154 100644 --- a/src/handlers/predict.py +++ b/src/handlers/predict.py @@ -63,8 +63,24 @@ def to_numpy(tensor): ), } ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs))) + if request.model["task"] == "BINARY_CLASSIFICATION": + return graph_binary_classification(request, ort_outs) + elif request.model["task"] == "REGRESSION": + return graph_regression(request, ort_outs) + else: + raise ValueError( + "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported" + ) + - return graph_binary_classification(request, ort_outs) +def graph_regression(request: PredictionRequestPydantic, onnx_output): + target_name = request.model["dependentFeatures"][0]["name"] + preds = [onnx_output.squeeze().tolist()] + results = {} + results[target_name] = [str(pred) for pred in preds] + final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} + print(final_all) + return final_all def graph_binary_classification(request: PredictionRequestPydantic, onnx_output): @@ -78,5 +94,4 @@ def graph_binary_classification(request: PredictionRequestPydantic, onnx_output) results[target_name] = [str(pred) for pred in preds] final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} print(final_all) - return final_all