-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(JAQPOT-386): onnx_sequence (#52)
* predict_sequence_function * . * torch_utils_folder * . --------- Co-authored-by: Alex Arvanitidis <[email protected]>
- Loading branch information
1 parent
dbe3675
commit b49b4c9
Showing
4 changed files
with
100 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import base64 | ||
import onnxruntime | ||
import torch | ||
import io | ||
import numpy as np | ||
import torch.nn.functional as f | ||
from src.helpers.torch_utils import to_numpy, generate_prediction_response | ||
from jaqpotpy.descriptors.tokenizer import SmilesVectorizer | ||
from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse | ||
from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer | ||
|
||
|
||
def torch_sequence_post_handler(request: PredictionRequest) -> PredictionResponse: | ||
feat_config = request.model.torch_config | ||
featurizer = _load_featurizer(feat_config) | ||
target_name = request.model.dependent_features[0].name | ||
model_task = request.model.task | ||
user_input = request.dataset.input | ||
raw_model = request.model.raw_model | ||
predictions = [] | ||
for inp in user_input: | ||
model_output = onnx_post_handler( | ||
raw_model, featurizer.transform(featurizer.transform([inp["SMILES"]])) | ||
) | ||
predictions.append( | ||
generate_prediction_response(model_task, target_name, model_output, inp) | ||
) | ||
return PredictionResponse(predictions=predictions) | ||
|
||
|
||
def onnx_post_handler(raw_model, data): | ||
onnx_model = base64.b64decode(raw_model) | ||
ort_session = onnxruntime.InferenceSession(onnx_model) | ||
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)} | ||
ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs))) | ||
return ort_outs | ||
|
||
|
||
def _load_featurizer(config): | ||
featurizer = SmilesVectorizer() | ||
featurizer.load_dict(config) | ||
return featurizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch.nn.functional as f | ||
from jaqpotpy.api.openapi.models.model_task import ModelTask | ||
|
||
|
||
def to_numpy(tensor): | ||
return ( | ||
tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() | ||
) | ||
|
||
|
||
def torch_binary_classification(target_name, output, inp): | ||
proba = f.sigmoid(output).squeeze().tolist() | ||
pred = int(proba > 0.5) | ||
# UI Results | ||
results = { | ||
"jaqpotMetadata": { | ||
"probabilities": [round((1 - proba), 3), round(proba, 3)], | ||
"jaqpotRowId": inp["jaqpotRowId"], | ||
} | ||
} | ||
if "jaqpotRowLabel" in inp: | ||
results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] | ||
results[target_name] = pred | ||
return results | ||
|
||
|
||
def torch_regression(target_name, output, inp): | ||
pred = [output.squeeze().tolist()] | ||
results = {"jaqpotMetadata": {"jaqpotRowId": inp["jaqpotRowId"]}} | ||
if "jaqpotRowLabel" in inp: | ||
results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] | ||
results[target_name] = pred | ||
return results | ||
|
||
|
||
def generate_prediction_response(model_task, target_name, out, row_id): | ||
if model_task == ModelTask.BINARY_CLASSIFICATION: | ||
return torch_binary_classification(target_name, out, row_id) | ||
elif model_task == ModelTask.REGRESSION: | ||
return torch_regression(target_name, out, row_id) | ||
else: | ||
raise ValueError( | ||
"Only BINARY_CLASSIFICATION and REGRESSION tasks are supported" | ||
) |