feat(JAQPOT-386): onnx_sequence (#52)
* predict_sequence_function

* .

* torch_utils_folder

* .

---------

Co-authored-by: Alex Arvanitidis <[email protected]>
johnsaveus and alarv authored Dec 5, 2024
1 parent dbe3675 · commit b49b4c9
Showing 4 changed files with 100 additions and 58 deletions.
main.py: 15 changes (6 additions & 9 deletions)
@@ -9,10 +9,9 @@
 import uvicorn
 from fastapi import FastAPI
 from jaqpotpy.api.openapi import PredictionRequest, PredictionResponse, ModelType
 
 from src.handlers.predict_sklearn_onnx import sklearn_onnx_post_handler
-from src.handlers.predict_torch import torch_post_handler
-
+from src.handlers.predict_torch_geometric import torch_geometric_post_handler
+from src.handlers.predict_torch_sequence import torch_sequence_post_handler
 from src.loggers.logger import logger
 from src.loggers.log_middleware import LogMiddleware
-
@@ -32,12 +31,10 @@ def predict(req: PredictionRequest) -> PredictionResponse:
     match req.model.type:
         case ModelType.SKLEARN_ONNX:
             return sklearn_onnx_post_handler(req)
-        case (
-            ModelType.TORCH_GEOMETRIC_ONNX
-            | ModelType.TORCHSCRIPT
-            | ModelType.TORCH_SEQUENCE_ONNX
-        ):
-            return torch_post_handler(req)
+        case ModelType.TORCH_SEQUENCE_ONNX:
+            return torch_sequence_post_handler(req)
+        case ModelType.TORCH_GEOMETRIC_ONNX | ModelType.TORCHSCRIPT:
+            return torch_geometric_post_handler(req)
         case _:
             raise Exception("Model type not supported")
 
src/handlers/predict_torch.py → src/handlers/predict_torch_geometric.py: 57 changes (8 additions & 49 deletions)
@@ -3,12 +3,13 @@
 import torch
 import io
 import numpy as np
-import torch.nn.functional as f
+from src.helpers.torch_utils import to_numpy, generate_prediction_response
 from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse
 from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer
+from jaqpotpy.api.openapi.models.model_task import ModelTask
 
 
-def torch_post_handler(request: PredictionRequest) -> PredictionResponse:
+def torch_geometric_post_handler(request: PredictionRequest) -> PredictionResponse:
     feat_config = request.model.torch_config
     featurizer = _load_featurizer(feat_config)
     target_name = request.model.dependent_features[0].name
@@ -22,15 +23,15 @@ def torch_post_handler(request: PredictionRequest) -> PredictionResponse:
                 raw_model, featurizer.featurize(inp["SMILES"])
             )
             predictions.append(
-                check_model_task(model_task, target_name, model_output, inp)
+                generate_prediction_response(model_task, target_name, model_output, inp)
             )
     elif request.model.type == ModelType.TORCHSCRIPT:
         for inp in user_input:
             model_output = torchscript_post_handler(
                 raw_model, featurizer.featurize(inp["SMILES"])
             )
             predictions.append(
-                check_model_task(model_task, target_name, model_output, inp)
+                generate_prediction_response(model_task, target_name, model_output, inp)
             )
     return PredictionResponse(predictions=predictions)
 
@@ -39,9 +40,9 @@ def torch_geometric_onnx_post_handler(raw_model, data):
     onnx_model = base64.b64decode(raw_model)
     ort_session = onnxruntime.InferenceSession(onnx_model)
     ort_inputs = {
-        ort_session.get_inputs()[0].name: _to_numpy(data.x),
-        ort_session.get_inputs()[1].name: _to_numpy(data.edge_index),
-        ort_session.get_inputs()[2].name: _to_numpy(
+        ort_session.get_inputs()[0].name: to_numpy(data.x),
+        ort_session.get_inputs()[1].name: to_numpy(data.edge_index),
+        ort_session.get_inputs()[2].name: to_numpy(
             torch.zeros(data.x.shape[0], dtype=torch.int64)
         ),
     }
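
Aside (not part of the diff): the three ort_inputs above follow the usual PyTorch Geometric layout of node features, COO edge index, and a batch-assignment vector, with the torch.zeros(...) third input marking every node as belonging to a single graph. A minimal sketch with a hypothetical 3-atom molecule and an assumed feature width of 8:

    import torch

    x = torch.randn(3, 8)  # node feature matrix, one row per atom (width 8 is illustrative)
    edge_index = torch.tensor(
        [[0, 1, 1, 2],
         [1, 0, 2, 1]]
    )  # COO connectivity, both edge directions
    batch = torch.zeros(x.shape[0], dtype=torch.int64)  # every node assigned to graph 0, as in the handler above
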
@@ -63,50 +64,8 @@ def torchscript_post_handler(raw_model, data):
     return out
 
 
-def _to_numpy(tensor):
-    return (
-        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-    )
-
-
 def _load_featurizer(config):
     featurizer = SmilesGraphFeaturizer()
     featurizer.load_dict(config)
     featurizer.sort_allowable_sets()
     return featurizer
-
-
-def graph_regression(target_name, output, inp):
-    pred = [output.squeeze().tolist()]
-    results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}}
-    if "jaqpotRowLabel" in inp:
-        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
-    results[target_name] = pred
-    return results
-
-
-def graph_binary_classification(target_name, output, inp):
-    proba = f.sigmoid(output).squeeze().tolist()
-    pred = int(proba > 0.5)
-    # UI Results
-    results = {
-        "jaqpotMetadata": {
-            "probabilities": [round((1 - proba), 3), round(proba, 3)],
-            "jaqpotRowId": inp["jaqpotRowId"],
-        }
-    }
-    if "jaqpotRowLabel" in inp:
-        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
-    results[target_name] = pred
-    return results
-
-
-def check_model_task(model_task, target_name, out, row_id):
-    if model_task == "BINARY_CLASSIFICATION":
-        return graph_binary_classification(target_name, out, row_id)
-    elif model_task == "REGRESSION":
-        return graph_regression(target_name, out, row_id)
-    else:
-        raise ValueError(
-            "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
-        )
src/handlers/predict_torch_sequence.py: 42 changes (42 additions & 0 deletions)
@@ -0,0 +1,42 @@
+import base64
+import onnxruntime
+import torch
+import io
+import numpy as np
+import torch.nn.functional as f
+from src.helpers.torch_utils import to_numpy, generate_prediction_response
+from jaqpotpy.descriptors.tokenizer import SmilesVectorizer
+from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse
+from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer
+
+
+def torch_sequence_post_handler(request: PredictionRequest) -> PredictionResponse:
+    feat_config = request.model.torch_config
+    featurizer = _load_featurizer(feat_config)
+    target_name = request.model.dependent_features[0].name
+    model_task = request.model.task
+    user_input = request.dataset.input
+    raw_model = request.model.raw_model
+    predictions = []
+    for inp in user_input:
+        model_output = onnx_post_handler(
+            raw_model, featurizer.transform(featurizer.transform([inp["SMILES"]]))
+        )
+        predictions.append(
+            generate_prediction_response(model_task, target_name, model_output, inp)
+        )
+    return PredictionResponse(predictions=predictions)
+
+
+def onnx_post_handler(raw_model, data):
+    onnx_model = base64.b64decode(raw_model)
+    ort_session = onnxruntime.InferenceSession(onnx_model)
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
+    ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs)))
+    return ort_outs
+
+
+def _load_featurizer(config):
+    featurizer = SmilesVectorizer()
+    featurizer.load_dict(config)
+    return featurizer
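
For context (not part of the commit), a minimal sketch of the raw_model payload this handler consumes: a base64-encoded ONNX graph that onnx_post_handler decodes and runs through onnxruntime. The throwaway TinyModel, the fixed (1, 8) input shape, and importing onnx_post_handler from the new module are illustrative assumptions:

    import base64
    import io

    import torch

    from src.handlers.predict_torch_sequence import onnx_post_handler


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x.sum(dim=1, keepdim=True)


    # Serialize a throwaway model to ONNX bytes and base64-encode it,
    # mirroring what request.model.raw_model is expected to hold.
    buf = io.BytesIO()
    torch.onnx.export(TinyModel(), torch.zeros(1, 8), buf)
    raw_model = base64.b64encode(buf.getvalue())

    # onnx_post_handler decodes the blob, builds an InferenceSession, feeds the
    # featurized tensor as the sole input, and wraps the outputs in a torch tensor.
    out = onnx_post_handler(raw_model, torch.zeros(1, 8))
    print(out.shape)  # torch.Size([1, 1, 1]): run(None, ...) returns a list, hence the extra axis
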
src/helpers/torch_utils.py: 44 changes (44 additions & 0 deletions)
@@ -0,0 +1,44 @@
+import torch.nn.functional as f
+from jaqpotpy.api.openapi.models.model_task import ModelTask
+
+
+def to_numpy(tensor):
+    return (
+        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+    )
+
+
+def torch_binary_classification(target_name, output, inp):
+    proba = f.sigmoid(output).squeeze().tolist()
+    pred = int(proba > 0.5)
+    # UI Results
+    results = {
+        "jaqpotMetadata": {
+            "probabilities": [round((1 - proba), 3), round(proba, 3)],
+            "jaqpotRowId": inp["jaqpotRowId"],
+        }
+    }
+    if "jaqpotRowLabel" in inp:
+        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
+    results[target_name] = pred
+    return results
+
+
+def torch_regression(target_name, output, inp):
+    pred = [output.squeeze().tolist()]
+    results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}}
+    if "jaqpotRowLabel" in inp:
+        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
+    results[target_name] = pred
+    return results
+
+
+def generate_prediction_response(model_task, target_name, out, row_id):
+    if model_task == ModelTask.BINARY_CLASSIFICATION:
+        return torch_binary_classification(target_name, out, row_id)
+    elif model_task == ModelTask.REGRESSION:
+        return torch_regression(target_name, out, row_id)
+    else:
+        raise ValueError(
+            "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
+        )
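
A quick illustration (not part of the commit; values rounded as in the helpers above) of the response rows these utilities build, assuming they are importable from src.helpers.torch_utils:

    import torch

    from jaqpotpy.api.openapi.models.model_task import ModelTask
    from src.helpers.torch_utils import generate_prediction_response

    row = {"jaqpotRowId": 0, "jaqpotRowLabel": "mol-0"}

    # Binary classification: a single logit is squashed with sigmoid,
    # thresholded at 0.5, and both class probabilities are reported.
    generate_prediction_response(ModelTask.BINARY_CLASSIFICATION, "activity", torch.tensor([0.7]), row)
    # -> {'jaqpotMetadata': {'probabilities': [0.332, 0.668], 'jaqpotRowId': 0,
    #     'jaqpotRowLabel': 'mol-0'}, 'activity': 1}

    # Regression: the squeezed output is returned as a single-element list.
    generate_prediction_response(ModelTask.REGRESSION, "solubility", torch.tensor([1.25]), row)
    # -> {'jaqpotMetadata': {'jaqpotRowId': 0, 'jaqpotRowLabel': 'mol-0'}, 'solubility': [1.25]}
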
