feat(JAQPOT-386): onnx_sequence #52

Merged: 6 commits, Dec 5, 2024
Changes from all commits
main.py (15 changes: 6 additions & 9 deletions)
@@ -9,10 +9,9 @@
 import uvicorn
 from fastapi import FastAPI
 from jaqpotpy.api.openapi import PredictionRequest, PredictionResponse, ModelType
-
 from src.handlers.predict_sklearn_onnx import sklearn_onnx_post_handler
-from src.handlers.predict_torch import torch_post_handler
-
+from src.handlers.predict_torch_geometric import torch_geometric_post_handler
+from src.handlers.predict_torch_sequence import torch_sequence_post_handler
 from src.loggers.logger import logger
 from src.loggers.log_middleware import LogMiddleware

@@ -32,12 +31,10 @@ def predict(req: PredictionRequest) -> PredictionResponse:
     match req.model.type:
         case ModelType.SKLEARN_ONNX:
             return sklearn_onnx_post_handler(req)
-        case (
-            ModelType.TORCH_GEOMETRIC_ONNX
-            | ModelType.TORCHSCRIPT
-            | ModelType.TORCH_SEQUENCE_ONNX
-        ):
-            return torch_post_handler(req)
+        case ModelType.TORCH_SEQUENCE_ONNX:
+            return torch_sequence_post_handler(req)
+        case ModelType.TORCH_GEOMETRIC_ONNX | ModelType.TORCHSCRIPT:
+            return torch_geometric_post_handler(req)
         case _:
             raise Exception("Model type not supported")

src/handlers/{predict_torch.py → predict_torch_geometric.py}
@@ -3,12 +3,13 @@
 import torch
 import io
 import numpy as np
-import torch.nn.functional as f
+from src.helpers.torch_utils import to_numpy, generate_prediction_response
 from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse
 from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer
+from jaqpotpy.api.openapi.models.model_task import ModelTask


-def torch_post_handler(request: PredictionRequest) -> PredictionResponse:
+def torch_geometric_post_handler(request: PredictionRequest) -> PredictionResponse:
     feat_config = request.model.torch_config
     featurizer = _load_featurizer(feat_config)
     target_name = request.model.dependent_features[0].name
@@ -22,15 +23,15 @@ def torch_post_handler(request: PredictionRequest) -> PredictionResponse:
                 raw_model, featurizer.featurize(inp["SMILES"])
             )
             predictions.append(
-                check_model_task(model_task, target_name, model_output, inp)
+                generate_prediction_response(model_task, target_name, model_output, inp)
             )
     elif request.model.type == ModelType.TORCHSCRIPT:
         for inp in user_input:
             model_output = torchscript_post_handler(
                 raw_model, featurizer.featurize(inp["SMILES"])
             )
             predictions.append(
-                check_model_task(model_task, target_name, model_output, inp)
+                generate_prediction_response(model_task, target_name, model_output, inp)
             )
     return PredictionResponse(predictions=predictions)

@@ -39,9 +40,9 @@ def torch_geometric_onnx_post_handler(raw_model, data):
     onnx_model = base64.b64decode(raw_model)
     ort_session = onnxruntime.InferenceSession(onnx_model)
     ort_inputs = {
-        ort_session.get_inputs()[0].name: _to_numpy(data.x),
-        ort_session.get_inputs()[1].name: _to_numpy(data.edge_index),
-        ort_session.get_inputs()[2].name: _to_numpy(
+        ort_session.get_inputs()[0].name: to_numpy(data.x),
+        ort_session.get_inputs()[1].name: to_numpy(data.edge_index),
+        ort_session.get_inputs()[2].name: to_numpy(
             torch.zeros(data.x.shape[0], dtype=torch.int64)
         ),
     }
@@ -63,50 +64,8 @@ def torchscript_post_handler(raw_model, data):
     return out


-def _to_numpy(tensor):
-    return (
-        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-    )
-
-
 def _load_featurizer(config):
     featurizer = SmilesGraphFeaturizer()
     featurizer.load_dict(config)
     featurizer.sort_allowable_sets()
     return featurizer
-
-
-def graph_regression(target_name, output, inp):
-    pred = [output.squeeze().tolist()]
-    results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}}
-    if "jaqpotRowLabel" in inp:
-        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
-    results[target_name] = pred
-    return results
-
-
-def graph_binary_classification(target_name, output, inp):
-    proba = f.sigmoid(output).squeeze().tolist()
-    pred = int(proba > 0.5)
-    # UI Results
-    results = {
-        "jaqpotMetadata": {
-            "probabilities": [round((1 - proba), 3), round(proba, 3)],
-            "jaqpotRowId": inp["jaqpotRowId"],
-        }
-    }
-    if "jaqpotRowLabel" in inp:
-        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
-    results[target_name] = pred
-    return results
-
-
-def check_model_task(model_task, target_name, out, row_id):
-    if model_task == "BINARY_CLASSIFICATION":
-        return graph_binary_classification(target_name, out, row_id)
-    elif model_task == "REGRESSION":
-        return graph_regression(target_name, out, row_id)
-    else:
-        raise ValueError(
-            "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
-        )
src/handlers/predict_torch_sequence.py (new file: 42 additions & 0 deletions)
@@ -0,0 +1,42 @@
import base64
import onnxruntime
import torch
import io
import numpy as np
import torch.nn.functional as f
from src.helpers.torch_utils import to_numpy, generate_prediction_response
from jaqpotpy.descriptors.tokenizer import SmilesVectorizer
from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse
from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer


def torch_sequence_post_handler(request: PredictionRequest) -> PredictionResponse:
    feat_config = request.model.torch_config
    featurizer = _load_featurizer(feat_config)
    target_name = request.model.dependent_features[0].name
    model_task = request.model.task
    user_input = request.dataset.input
    raw_model = request.model.raw_model
    predictions = []
    for inp in user_input:
        model_output = onnx_post_handler(
            raw_model, featurizer.transform(featurizer.transform([inp["SMILES"]]))
        )
        predictions.append(
            generate_prediction_response(model_task, target_name, model_output, inp)
        )
    return PredictionResponse(predictions=predictions)


def onnx_post_handler(raw_model, data):
    onnx_model = base64.b64decode(raw_model)
    ort_session = onnxruntime.InferenceSession(onnx_model)
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
    ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs)))
    return ort_outs


def _load_featurizer(config):
    featurizer = SmilesVectorizer()
    featurizer.load_dict(config)
    return featurizer

Review comment (Member), on the line `target_name = request.model.dependent_features[0].name`: Is it hard to allow multiple y target names, instead of always hardcoding a single target_name by taking dependent_features[0] here?
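As a rough, hypothetical sketch of the reviewer's suggestion above (not part of this PR): a response builder could emit one value per dependent feature instead of only dependent_features[0]. It assumes the model output carries one value per target, ordered like request.model.dependent_features, and it mirrors the jaqpotMetadata layout used in src/helpers/torch_utils.py.

# Hypothetical sketch only: one prediction row covering every dependent feature.
def multi_target_prediction(dependent_features, output, inp):
    values = output.squeeze().tolist()
    if not isinstance(values, list):
        values = [values]  # single-target outputs squeeze down to a scalar
    results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}}
    if "jaqpotRowLabel" in inp:
        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
    # pair each declared target name with its value, in declaration order
    for feature, value in zip(dependent_features, values):
        results[feature.name] = value
    return results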
src/helpers/torch_utils.py (new file: 44 additions & 0 deletions)
@@ -0,0 +1,44 @@
import torch.nn.functional as f
from jaqpotpy.api.openapi.models.model_task import ModelTask


def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    )


def torch_binary_classification(target_name, output, inp):
    proba = f.sigmoid(output).squeeze().tolist()
    pred = int(proba > 0.5)
    # UI Results
    results = {
        "jaqpotMetadata": {
            "probabilities": [round((1 - proba), 3), round(proba, 3)],
            "jaqpotRowId": inp["jaqpotRowId"],
        }
    }
    if "jaqpotRowLabel" in inp:
        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
    results[target_name] = pred
    return results


def torch_regression(target_name, output, inp):
    pred = [output.squeeze().tolist()]
    results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}}
    if "jaqpotRowLabel" in inp:
        results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
    results[target_name] = pred
    return results


def generate_prediction_response(model_task, target_name, out, row_id):
    if model_task == ModelTask.BINARY_CLASSIFICATION:
        return torch_binary_classification(target_name, out, row_id)
    elif model_task == ModelTask.REGRESSION:
        return torch_regression(target_name, out, row_id)
    else:
        raise ValueError(
            "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
        )
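For context, a minimal usage sketch of the helpers above (not part of the PR; the row dict, the "activity" target name, and the logit value are invented for illustration):

# Hypothetical usage of generate_prediction_response for a binary classifier.
import torch
from jaqpotpy.api.openapi.models.model_task import ModelTask
from src.helpers.torch_utils import generate_prediction_response

row = {"jaqpotRowId": 0, "jaqpotRowLabel": "mol-0", "SMILES": "CCO"}
logit = torch.tensor([0.8])  # raw model output, before the sigmoid applied by the helper

prediction = generate_prediction_response(
    ModelTask.BINARY_CLASSIFICATION, "activity", logit, row
)
# prediction == {
#     "jaqpotMetadata": {"probabilities": [0.31, 0.69], "jaqpotRowId": 0,
#                        "jaqpotRowLabel": "mol-0"},
#     "activity": 1,
# }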