From d85af6a92ef7dec88ee3a23dccfb6551105d16ef Mon Sep 17 00:00:00 2001 From: john savvas Date: Thu, 5 Dec 2024 13:47:24 +0200 Subject: [PATCH] . --- src/handlers/predict_torch_geometric.py | 6 +++--- src/handlers/predict_torch_sequence.py | 6 ++++-- src/helpers/torch_utils.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/handlers/predict_torch_geometric.py b/src/handlers/predict_torch_geometric.py index f621e46..cecc27d 100644 --- a/src/handlers/predict_torch_geometric.py +++ b/src/handlers/predict_torch_geometric.py @@ -3,7 +3,7 @@ import torch import io import numpy as np -from src.helpers.torch_utils import to_numpy, check_model_task +from src.helpers.torch_utils import to_numpy, generate_prediction_response from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer from jaqpotpy.api.openapi.models.model_task import ModelTask @@ -23,7 +23,7 @@ def torch_geometric_post_handler(request: PredictionRequest) -> PredictionRespon raw_model, featurizer.featurize(inp["SMILES"]) ) predictions.append( - check_model_task(model_task, target_name, model_output, inp) + generate_prediction_response(model_task, target_name, model_output, inp) ) elif request.model.type == ModelType.TORCHSCRIPT: for inp in user_input: @@ -31,7 +31,7 @@ def torch_geometric_post_handler(request: PredictionRequest) -> PredictionRespon raw_model, featurizer.featurize(inp["SMILES"]) ) predictions.append( - check_model_task(model_task, target_name, model_output, inp) + generate_prediction_response(model_task, target_name, model_output, inp) ) return PredictionResponse(predictions=predictions) diff --git a/src/handlers/predict_torch_sequence.py b/src/handlers/predict_torch_sequence.py index 59743ac..9775c32 100644 --- a/src/handlers/predict_torch_sequence.py +++ b/src/handlers/predict_torch_sequence.py @@ -4,7 +4,7 @@ import io import numpy as np import torch.nn.functional as f -from src.helpers.torch_utils import to_numpy, check_model_task +from src.helpers.torch_utils import to_numpy, generate_prediction_response from jaqpotpy.descriptors.tokenizer import SmilesVectorizer from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer @@ -22,7 +22,9 @@ def torch_sequence_post_handler(request: PredictionRequest) -> PredictionRespons model_output = onnx_post_handler( raw_model, featurizer.transform(featurizer.transform([inp["SMILES"]])) ) - predictions.append(check_model_task(model_task, target_name, model_output, inp)) + predictions.append( + generate_prediction_response(model_task, target_name, model_output, inp) + ) return PredictionResponse(predictions=predictions) diff --git a/src/helpers/torch_utils.py b/src/helpers/torch_utils.py index d8a9084..5e8e027 100644 --- a/src/helpers/torch_utils.py +++ b/src/helpers/torch_utils.py @@ -33,7 +33,7 @@ def torch_regression(target_name, output, inp): return results -def check_model_task(model_task, target_name, out, row_id): +def generate_prediction_response(model_task, target_name, out, row_id): if model_task == ModelTask.BINARY_CLASSIFICATION: return torch_binary_classification(target_name, out, row_id) elif model_task == ModelTask.REGRESSION: