From b49b4c92822d87c382da55fcc6c43bd7e2566887 Mon Sep 17 00:00:00 2001 From: john savvas <149728671+johnsaveus@users.noreply.github.com> Date: Thu, 5 Dec 2024 14:12:22 +0200 Subject: [PATCH] feat(JAQPOT-386): onnx_sequence (#52) * predict_sequence_function * . * torch_utils_folder * . --------- Co-authored-by: Alex Arvanitidis --- main.py | 15 ++--- ...ct_torch.py => predict_torch_geometric.py} | 57 +++---------------- src/handlers/predict_torch_sequence.py | 42 ++++++++++++++ src/helpers/torch_utils.py | 44 ++++++++++++++ 4 files changed, 100 insertions(+), 58 deletions(-) rename src/handlers/{predict_torch.py => predict_torch_geometric.py} (54%) create mode 100644 src/handlers/predict_torch_sequence.py create mode 100644 src/helpers/torch_utils.py diff --git a/main.py b/main.py index 0d57c57..81e4a26 100644 --- a/main.py +++ b/main.py @@ -9,10 +9,9 @@ import uvicorn from fastapi import FastAPI from jaqpotpy.api.openapi import PredictionRequest, PredictionResponse, ModelType - from src.handlers.predict_sklearn_onnx import sklearn_onnx_post_handler -from src.handlers.predict_torch import torch_post_handler - +from src.handlers.predict_torch_geometric import torch_geometric_post_handler +from src.handlers.predict_torch_sequence import torch_sequence_post_handler from src.loggers.logger import logger from src.loggers.log_middleware import LogMiddleware @@ -32,12 +31,10 @@ def predict(req: PredictionRequest) -> PredictionResponse: match req.model.type: case ModelType.SKLEARN_ONNX: return sklearn_onnx_post_handler(req) - case ( - ModelType.TORCH_GEOMETRIC_ONNX - | ModelType.TORCHSCRIPT - | ModelType.TORCH_SEQUENCE_ONNX - ): - return torch_post_handler(req) + case ModelType.TORCH_SEQUENCE_ONNX: + return torch_sequence_post_handler(req) + case ModelType.TORCH_GEOMETRIC_ONNX | ModelType.TORCHSCRIPT: + return torch_geometric_post_handler(req) case _: raise Exception("Model type not supported") diff --git a/src/handlers/predict_torch.py b/src/handlers/predict_torch_geometric.py similarity index 54% rename from src/handlers/predict_torch.py rename to src/handlers/predict_torch_geometric.py index 9bf8b76..cecc27d 100644 --- a/src/handlers/predict_torch.py +++ b/src/handlers/predict_torch_geometric.py @@ -3,12 +3,13 @@ import torch import io import numpy as np -import torch.nn.functional as f +from src.helpers.torch_utils import to_numpy, generate_prediction_response from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer +from jaqpotpy.api.openapi.models.model_task import ModelTask -def torch_post_handler(request: PredictionRequest) -> PredictionResponse: +def torch_geometric_post_handler(request: PredictionRequest) -> PredictionResponse: feat_config = request.model.torch_config featurizer = _load_featurizer(feat_config) target_name = request.model.dependent_features[0].name @@ -22,7 +23,7 @@ def torch_post_handler(request: PredictionRequest) -> PredictionResponse: raw_model, featurizer.featurize(inp["SMILES"]) ) predictions.append( - check_model_task(model_task, target_name, model_output, inp) + generate_prediction_response(model_task, target_name, model_output, inp) ) elif request.model.type == ModelType.TORCHSCRIPT: for inp in user_input: @@ -30,7 +31,7 @@ def torch_post_handler(request: PredictionRequest) -> PredictionResponse: raw_model, featurizer.featurize(inp["SMILES"]) ) predictions.append( - check_model_task(model_task, target_name, model_output, inp) + generate_prediction_response(model_task, target_name, model_output, inp) ) return PredictionResponse(predictions=predictions) @@ -39,9 +40,9 @@ def torch_geometric_onnx_post_handler(raw_model, data): onnx_model = base64.b64decode(raw_model) ort_session = onnxruntime.InferenceSession(onnx_model) ort_inputs = { - ort_session.get_inputs()[0].name: _to_numpy(data.x), - ort_session.get_inputs()[1].name: _to_numpy(data.edge_index), - ort_session.get_inputs()[2].name: _to_numpy( + ort_session.get_inputs()[0].name: to_numpy(data.x), + ort_session.get_inputs()[1].name: to_numpy(data.edge_index), + ort_session.get_inputs()[2].name: to_numpy( torch.zeros(data.x.shape[0], dtype=torch.int64) ), } @@ -63,50 +64,8 @@ def torchscript_post_handler(raw_model, data): return out -def _to_numpy(tensor): - return ( - tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() - ) - - def _load_featurizer(config): featurizer = SmilesGraphFeaturizer() featurizer.load_dict(config) featurizer.sort_allowable_sets() return featurizer - - -def graph_regression(target_name, output, inp): - pred = [output.squeeze().tolist()] - results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}} - if "jaqpotRowLabel" in inp: - results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] - results[target_name] = pred - return results - - -def graph_binary_classification(target_name, output, inp): - proba = f.sigmoid(output).squeeze().tolist() - pred = int(proba > 0.5) - # UI Results - results = { - "jaqpotMetadata": { - "probabilities": [round((1 - proba), 3), round(proba, 3)], - "jaqpotRowId": inp["jaqpotRowId"], - } - } - if "jaqpotRowLabel" in inp: - results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] - results[target_name] = pred - return results - - -def check_model_task(model_task, target_name, out, row_id): - if model_task == "BINARY_CLASSIFICATION": - return graph_binary_classification(target_name, out, row_id) - elif model_task == "REGRESSION": - return graph_regression(target_name, out, row_id) - else: - raise ValueError( - "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported" - ) diff --git a/src/handlers/predict_torch_sequence.py b/src/handlers/predict_torch_sequence.py new file mode 100644 index 0000000..9775c32 --- /dev/null +++ b/src/handlers/predict_torch_sequence.py @@ -0,0 +1,42 @@ +import base64 +import onnxruntime +import torch +import io +import numpy as np +import torch.nn.functional as f +from src.helpers.torch_utils import to_numpy, generate_prediction_response +from jaqpotpy.descriptors.tokenizer import SmilesVectorizer +from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse +from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer + + +def torch_sequence_post_handler(request: PredictionRequest) -> PredictionResponse: + feat_config = request.model.torch_config + featurizer = _load_featurizer(feat_config) + target_name = request.model.dependent_features[0].name + model_task = request.model.task + user_input = request.dataset.input + raw_model = request.model.raw_model + predictions = [] + for inp in user_input: + model_output = onnx_post_handler( + raw_model, featurizer.transform(featurizer.transform([inp["SMILES"]])) + ) + predictions.append( + generate_prediction_response(model_task, target_name, model_output, inp) + ) + return PredictionResponse(predictions=predictions) + + +def onnx_post_handler(raw_model, data): + onnx_model = base64.b64decode(raw_model) + ort_session = onnxruntime.InferenceSession(onnx_model) + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)} + ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs))) + return ort_outs + + +def _load_featurizer(config): + featurizer = SmilesVectorizer() + featurizer.load_dict(config) + return featurizer diff --git a/src/helpers/torch_utils.py b/src/helpers/torch_utils.py new file mode 100644 index 0000000..5e8e027 --- /dev/null +++ b/src/helpers/torch_utils.py @@ -0,0 +1,44 @@ +import torch.nn.functional as f +from jaqpotpy.api.openapi.models.model_task import ModelTask + + +def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + ) + + +def torch_binary_classification(target_name, output, inp): + proba = f.sigmoid(output).squeeze().tolist() + pred = int(proba > 0.5) + # UI Results + results = { + "jaqpotMetadata": { + "probabilities": [round((1 - proba), 3), round(proba, 3)], + "jaqpotRowId": inp["jaqpotRowId"], + } + } + if "jaqpotRowLabel" in inp: + results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] + results[target_name] = pred + return results + + +def torch_regression(target_name, output, inp): + pred = [output.squeeze().tolist()] + results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}} + if "jaqpotRowLabel" in inp: + results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] + results[target_name] = pred + return results + + +def generate_prediction_response(model_task, target_name, out, row_id): + if model_task == ModelTask.BINARY_CLASSIFICATION: + return torch_binary_classification(target_name, out, row_id) + elif model_task == ModelTask.REGRESSION: + return torch_regression(target_name, out, row_id) + else: + raise ValueError( + "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported" + )