diff --git a/src/helpers/predict_methods.py b/src/helpers/predict_methods.py index 7259ac9..c19cf1e 100644 --- a/src/helpers/predict_methods.py +++ b/src/helpers/predict_methods.py @@ -14,6 +14,8 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request): # onnx_prediction is being reshaped to a 2D array to avoid errors # when the model has only one dependent feature. In multi-output models, # onnx_prediction is already a 2D array. + else: + onnx_prediction = onnx_prediction[0] if request.model["extraConfig"]["preprocessors"]: for i in reversed(range(len(request.model["extraConfig"]["preprocessors"]))): @@ -28,9 +30,8 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request): if len(request.model["dependentFeatures"]) == 1: onnx_prediction = onnx_prediction.flatten() - return onnx_prediction - else: - return onnx_prediction[0] + + return onnx_prediction def predict_proba_onnx(model, dataset: JaqpotpyDataset, request):