-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: jaqpot-223 #14
Merged
Merged
Feat: jaqpot-223 #14
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2cc295c
starting_code_for_graph_inference
johnsaveus 41c7a45
no_pickling_for_graph
johnsaveus f703a79
remove predict_graph
johnsaveus df939df
minor_fix
johnsaveus c7d95f3
remove serializers
johnsaveus 14db7d7
comment_out
johnsaveus bccc3b6
ruff_format
johnsaveus 8e113a0
Merge branch 'main' into feat/JAQPOT-223/Inference_for_graphs
johnsaveus 01a99c6
Update predict.py
johnsaveus 0aa0a0b
Update predict.py
johnsaveus d00d2a6
fix: build
alarv ba0bd21
input fix
johnsaveus File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[lint] | ||
ignore = ["D100", "E722", "F401", "F403"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,8 @@ | ||
from pydantic import BaseModel | ||
from typing import Any | ||
|
||
|
||
class PredictionRequest:
    """Container for a single prediction request.

    Bundles the dataset, the model object, and an optional
    domain-of-applicability (DoA) descriptor. The model's raw payload and
    extra metadata are lifted onto the request for convenient access.
    """

    def __init__(self, dataset, model, doa=None):
        # Keep direct references to the originals.
        self.dataset = dataset
        self.model = model
        self.doa = doa
        # Convenience aliases pulled off the model object.
        self.rawModel = model.rawModel
        self.additionalInfo = model.additionalInfo
|
||
|
||
class PredictionRequestPydantic(BaseModel):
    """Pydantic schema for an incoming prediction request payload.

    All fields are typed ``Any`` because the payloads are raw JSON
    structures whose exact shape depends on the model type.
    """

    dataset: Any  # input data to run predictions on
    model: Any  # model payload; handlers read e.g. model["rawModel"]
    doa: Any = None  # optional domain-of-applicability info
    extraConfig: Any = None  # extra config, e.g. torchConfig for graph models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,104 @@ | ||
from ..entities.prediction_request import PredictionRequestPydantic | ||
from ..helpers import model_decoder, json_to_predreq | ||
import base64 | ||
import onnxruntime | ||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
# import sys | ||
# import os | ||
|
||
# current_dir = os.path.dirname(__file__) | ||
# software_dir = os.path.abspath(os.path.join(current_dir, '../../../../../JQP')) | ||
# sys.path.append(software_dir) | ||
from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer | ||
|
||
|
||
def model_post_handler(request: PredictionRequestPydantic):
    """Run ONNX inference for a classic (non-graph) model request.

    Decodes the raw model, rebuilds the input dataset, predicts with the
    ONNX model and returns stringified results in the UI shape
    ``{"predictions": [row_dict, ...]}``.
    """
    decoded_model = model_decoder.decode(request.model["rawModel"])
    dataset = json_to_predreq.decode(request, decoded_model)
    preds = decoded_model.predict_onnx(dataset)

    # Probabilities only exist for classifiers; pad with None otherwise.
    if decoded_model.task == "classification":
        probs = decoded_model.predict_proba_onnx(dataset)
    else:
        probs = [None] * len(preds)

    # One stringified column per dependent feature. With a single output
    # the prediction array is 1-D, so no column indexing is needed.
    single_output = len(decoded_model.dependentFeatures) == 1
    results = {}
    for idx, feature in enumerate(decoded_model.dependentFeatures):
        column = preds if single_output else preds[:, idx]
        results[feature["key"]] = [str(value) for value in column]

    results["Probabilities"] = [str(prob) for prob in probs]
    results["AD"] = [None] * len(preds)

    # Transpose the column-wise results into one dict per prediction row.
    rows = [dict(zip(results, row)) for row in zip(*results.values())]
    return {"predictions": rows}
|
||
|
||
def graph_post_handler(request: PredictionRequestPydantic):
    """Run ONNX inference for a torch graph model request.

    Decodes the base64 ONNX model, rebuilds the SMILES featurizer from the
    request's torch config, featurizes the input SMILES and runs the ONNX
    session. Classification outputs are post-processed into UI results;
    regression outputs are returned as a raw tensor.

    Raises:
        ValueError: if the task is unsupported, or if the request does not
            contain exactly one SMILES input.
    """
    onnx_model = base64.b64decode(request.model["rawModel"])
    ort_session = onnxruntime.InferenceSession(onnx_model)
    torch_config = request.extraConfig["torchConfig"]
    # Rebuild the featurizer exactly as it was configured at train time.
    featurizer = SmilesGraphFeaturizer()
    featurizer.load_json_rep(torch_config["featurizer"])

    inputs = request.dataset["input"]
    # NOTE(review): inference currently supports a single SMILES per
    # request; fail loudly instead of silently ignoring extra inputs
    # (previously this hardcoded inputs[0]).
    if len(inputs) != 1:
        raise ValueError(
            f"Graph inference expects exactly one SMILES input, got {len(inputs)}"
        )
    smiles = inputs[0]

    def to_numpy(tensor):
        # ONNX Runtime consumes numpy arrays, not torch tensors.
        return (
            tensor.detach().cpu().numpy()
            if tensor.requires_grad
            else tensor.cpu().numpy()
        )

    data = featurizer.featurize(smiles)
    # ONNX inference: node features, edge index, and a batch vector of
    # zeros (a single graph, so every node belongs to graph 0).
    ort_inputs = {
        ort_session.get_inputs()[0].name: to_numpy(data.x),
        ort_session.get_inputs()[1].name: to_numpy(data.edge_index),
        ort_session.get_inputs()[2].name: to_numpy(
            torch.zeros(data.x.shape[0], dtype=torch.int64)
        ),
    }
    ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs)))

    task = torch_config["task"]
    if task == "classification":
        return graph_binary_classification(request, ort_outs)
    elif task == "regression":
        return ort_outs
    else:
        raise ValueError("Only classification and regression tasks are supported")
|
||
|
||
def graph_binary_classification(request: "PredictionRequestPydantic", onnx_output):
    """Convert a raw graph-model logit into binary-classification results.

    Applies a sigmoid to the ONNX output logit, thresholds at 0.5, and
    packages probability and predicted class (stringified) for the UI.

    Args:
        request: incoming prediction request; only
            ``model["dependentFeatures"][0]["name"]`` is read.
        onnx_output: raw logit tensor produced by the ONNX session.

    Returns:
        dict: ``{"predictions": [{"Probabilities": ..., <target>: ...}]}``.
    """
    target_name = request.model["dependentFeatures"][0]["name"]
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    # NOTE(review): squeeze().tolist() only yields a scalar for a
    # single-sample output, so this supports one molecule per request.
    probs = [torch.sigmoid(onnx_output).squeeze().tolist()]
    preds = [int(prob > 0.5) for prob in probs]
    # UI results: everything stringified, one dict per prediction row.
    # (Debug print removed.)
    results = {}
    results["Probabilities"] = [str(prob) for prob in probs]
    results[target_name] = [str(pred) for pred in preds]
    final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}
    return final_all
|
||
|
||
# def graph_regression(request: PredictionRequestPydantic, onnx_output): | ||
# # Regression | ||
# target_name = request.model['dependentFeatures'][0]['name'] | ||
# preds = [onnx_output.squeeze().tolist()] | ||
# # UI Results | ||
# results = {} | ||
# results[target_name] = [str(pred) for pred in preds] | ||
# final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} | ||
# print(final_all) | ||
|
||
# return final_all |
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so that this stays as generic as possible and we don't need to change it in the future. Hardcoding index [0] of the input means that if we ever upload two SMILES inputs, this won't work, and we won't know why until we find this [0] here 😄