Skip to content

Commit

Permalink
Feat: jaqpot-223 (#14)
Browse files Browse the repository at this point in the history
* starting_code_for_graph_inference

* no_pickling_for_graph

* remove predict_graph

* minor_fix

* remove serializers

* comment_out

* ruff_format

* Update predict.py

* Update predict.py

* fix: build

* input fix

---------

Co-authored-by: alarv <[email protected]>
  • Loading branch information
johnsaveus and alarv authored Sep 10, 2024
1 parent dabb25b commit a960e5e
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 25 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install flake8 pytest ruff
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --max-complexity=10 --max-line-length=127 --statistics
ruff check
# - name: Test with pytest
# run: |
# pytest
16 changes: 10 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import uvicorn
from fastapi import FastAPI
from src.handlers.predict import model_post_handler
from src.handlers.predict import model_post_handler, graph_post_handler
from src.entities.prediction_request import PredictionRequestPydantic
from fastapi.responses import JSONResponse
from src.loggers.log_middleware import LogMiddleware


app = FastAPI()
app.add_middleware(LogMiddleware)


# Liveness endpoint: a GET on the root path returns a constant status payload.
# NOTE(review): this is a rendered diff without +/- markers; where a line
# appears twice differing only in quote style, the single-quoted line looks
# like the removed version and the double-quoted line its replacement.
@app.get('/')
@app.get("/")
def health_check():
return {'status': 'UP'}
return {"status": "UP"}


# Prediction endpoint: dispatches on req.model["type"] and wraps the handler's
# result in a JSONResponse.
#   "SKLEARN" -> model_post_handler
#   "TORCH"   -> graph_post_handler
# Any other type raises ValueError.
# NOTE(review): rendered diff without +/- markers; the single-quoted decorator
# and the first unconditional `return` appear to be the removed lines that the
# type dispatch below replaces.
@app.post('/predict/')
@app.post("/predict/")
def predict(req: PredictionRequestPydantic):
return JSONResponse(content=model_post_handler(req))
if req.model["type"] == "SKLEARN":
return JSONResponse(content=model_post_handler(req))
elif req.model["type"] == "TORCH":
return JSONResponse(content=graph_post_handler(req))
else:
# NOTE(review): the handler results are passed straight to JSONResponse, so
# they must be JSON-serializable.
raise ValueError("Only SKLEARN and TORCH models are supported")


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[lint]
ignore = ["D100", "E722", "F401", "F403"]
12 changes: 1 addition & 11 deletions src/entities/prediction_request.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
from pydantic import BaseModel
from typing import Any


class PredictionRequest:

def __init__(self, dataset, model, doa=None):
self.dataset = dataset
self.model = model
self.rawModel = model.rawModel
self.additionalInfo = model.additionalInfo
self.doa = doa


# Request body schema for the prediction endpoint. Every field is typed Any,
# so pydantic performs no structural validation of the payload contents.
class PredictionRequestPydantic(BaseModel):
# Input data; the graph handler reads dataset["input"][0].
dataset: Any
# Model description; handlers read keys such as "type", "rawModel" and
# "dependentFeatures" from it.
model: Any
# Optional; defaults to None.
doa: Any = None
# Optional extra configuration; the graph handler reads
# extraConfig["torchConfig"] ("featurizer" and "task"). Defaults to None.
extraConfig: Any = None
86 changes: 81 additions & 5 deletions src/handlers/predict.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,104 @@
from ..entities.prediction_request import PredictionRequestPydantic
from ..helpers import model_decoder, json_to_predreq
import base64
import onnxruntime
import numpy as np
import torch
import torch.nn.functional as F

# import sys
# import os

# current_dir = os.path.dirname(__file__)
# software_dir = os.path.abspath(os.path.join(current_dir, '../../../../../JQP'))
# sys.path.append(software_dir)
from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer


# Handler for models decoded via model_decoder: decodes request.model["rawModel"],
# builds the input with json_to_predreq.decode, predicts through the model's
# ONNX methods and returns {"predictions": [...]} with one dict per prediction.
# NOTE(review): rendered diff without +/- markers; where a line appears twice
# differing only in quote style, the single-quoted line looks like the removed
# version and the double-quoted line its replacement.
def model_post_handler(request: PredictionRequestPydantic):
model = model_decoder.decode(request.model['rawModel'])
model = model_decoder.decode(request.model["rawModel"])
data_entry_all = json_to_predreq.decode(request, model)
prediction = model.predict_onnx(data_entry_all)
# Probabilities are only computed for classification; any other task gets a
# None placeholder per prediction (stringified to "None" further down).
if model.task == 'classification':
if model.task == "classification":
probabilities = model.predict_proba_onnx(data_entry_all)
else:
probabilities = [None for _ in range(len(prediction))]

# Column-wise results: one list of stringified values per dependent feature.
results = {}
for i, feature in enumerate(model.dependentFeatures):
key = feature['key']
key = feature["key"]
# With a single dependent feature the prediction is iterated directly;
# with several, column i of the prediction is selected.
if len(model.dependentFeatures) == 1:
values = [str(item) for item in prediction]
else:
values = [str(item) for item in prediction[:, i]]
results[key] = values

# "AD" is always filled with None in this handler.
results['Probabilities'] = [str(prob) for prob in probabilities]
results['AD'] = [None for _ in range(len(prediction))]
results["Probabilities"] = [str(prob) for prob in probabilities]
results["AD"] = [None for _ in range(len(prediction))]

# Transpose the column-wise dict into a list of per-row dicts.
final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}

return final_all


def graph_post_handler(request: PredictionRequestPydantic):
    """Run ONNX inference for a graph model on the first SMILES of the request.

    The base64-encoded ONNX model is read from ``request.model["rawModel"]``,
    the featurizer configuration and task from
    ``request.extraConfig["torchConfig"]``, and the molecule from
    ``request.dataset["input"][0]``.

    Args:
        request: The incoming prediction request.

    Returns:
        A JSON-serializable dict of the form ``{"predictions": [...]}``.
        Classification results come from ``graph_binary_classification``;
        regression results map the first dependent feature's name to the
        stringified predicted value.

    Raises:
        ValueError: If the configured task is neither ``"classification"``
            nor ``"regression"``.
    """
    torch_config = request.extraConfig["torchConfig"]
    task = torch_config["task"]
    # Reject unsupported tasks before paying for model loading and inference.
    if task not in ("classification", "regression"):
        raise ValueError("Only classification and regression tasks are supported")

    # Decode the model and open an inference session on the raw bytes.
    onnx_model = base64.b64decode(request.model["rawModel"])
    ort_session = onnxruntime.InferenceSession(onnx_model)

    # Rebuild the featurizer from its JSON representation.
    featurizer = SmilesGraphFeaturizer()
    featurizer.load_json_rep(torch_config["featurizer"])

    # Only the first entry of the dataset input is featurized.
    smiles = request.dataset["input"][0]
    data = featurizer.featurize(smiles)

    def to_numpy(tensor):
        # detach() is a no-op view for tensors that do not require grad, so a
        # single code path covers both cases.
        return tensor.detach().cpu().numpy()

    session_inputs = ort_session.get_inputs()
    ort_inputs = {
        session_inputs[0].name: to_numpy(data.x),
        session_inputs[1].name: to_numpy(data.edge_index),
        # One int64 zero per row of data.x -- presumably the batch-assignment
        # vector placing every node in graph 0; confirm against the exported
        # model's input signature.
        session_inputs[2].name: to_numpy(
            torch.zeros(data.x.shape[0], dtype=torch.int64)
        ),
    }
    ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs)))

    if task == "classification":
        return graph_binary_classification(request, ort_outs)
    # The caller wraps the result in a JSONResponse, so a raw tensor cannot be
    # returned here; convert it to the same payload shape as classification.
    return _graph_regression(request, ort_outs)


def _graph_regression(request, onnx_output):
    """Format a regression output as a JSON-serializable predictions payload."""
    target_name = request.model["dependentFeatures"][0]["name"]
    preds = onnx_output.reshape(-1).tolist()
    return {"predictions": [{target_name: str(pred)} for pred in preds]}


def graph_binary_classification(request: PredictionRequestPydantic, onnx_output):
    """Turn raw binary-classification logits into the predictions payload.

    Args:
        request: The prediction request; the target name is read from
            ``request.model["dependentFeatures"][0]["name"]``.
        onnx_output: Tensor of raw model outputs (logits). It is flattened,
            so every element yields one prediction entry.

    Returns:
        ``{"predictions": [...]}`` with one dict per output element, holding
        the stringified sigmoid probability under ``"Probabilities"`` and the
        stringified 0/1 label (probability strictly above 0.5) under the
        target name.
    """
    target_name = request.model["dependentFeatures"][0]["name"]
    # Flatten instead of squeeze-and-wrap: squeeze() on a multi-element output
    # produced a nested list, and comparing that list with 0.5 raised TypeError.
    probs = torch.sigmoid(onnx_output).reshape(-1).tolist()
    return {
        "predictions": [
            {"Probabilities": str(prob), target_name: str(int(prob > 0.5))}
            for prob in probs
        ]
    }


# def graph_regression(request: PredictionRequestPydantic, onnx_output):
# # Regression
# target_name = request.model['dependentFeatures'][0]['name']
# preds = [onnx_output.squeeze().tolist()]
# # UI Results
# results = {}
# results[target_name] = [str(pred) for pred in preds]
# final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}
# print(final_all)

# return final_all
Empty file removed src/serializers/__init__.py
Empty file.

0 comments on commit a960e5e

Please sign in to comment.