diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index ce6fd77..30509f9 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,14 +26,13 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest ruff if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --max-complexity=10 --max-line-length=127 --statistics + ruff check # - name: Test with pytest # run: | # pytest diff --git a/main.py b/main.py index d1f3b3f..d9ab6ef 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,27 @@ import uvicorn from fastapi import FastAPI -from src.handlers.predict import model_post_handler +from src.handlers.predict import model_post_handler, graph_post_handler from src.entities.prediction_request import PredictionRequestPydantic from fastapi.responses import JSONResponse from src.loggers.log_middleware import LogMiddleware - app = FastAPI() app.add_middleware(LogMiddleware) -@app.get('/') +@app.get("/") def health_check(): - return {'status': 'UP'} + return {"status": "UP"} -@app.post('/predict/') +@app.post("/predict/") def predict(req: PredictionRequestPydantic): - return JSONResponse(content=model_post_handler(req)) + if req.model["type"] == "SKLEARN": + return JSONResponse(content=model_post_handler(req)) + elif req.model["type"] == "TORCH": + return JSONResponse(content=graph_post_handler(req)) + else: + raise ValueError("Only SKLEARN and TORCH models are supported") if __name__ == "__main__": diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..f11a08d --- /dev/null +++ b/ruff.toml @@ -0,0 +1,2 @@ +[lint] +ignore = ["D100", "E722", "F401", "F403"] diff --git a/src/entities/prediction_request.py b/src/entities/prediction_request.py index 6f8ff35..6191660 100644 --- a/src/entities/prediction_request.py +++ b/src/entities/prediction_request.py @@ -1,18 +1,8 @@ from pydantic import BaseModel from typing import Any - -class PredictionRequest: - - def __init__(self, dataset, model, doa=None): - self.dataset = dataset - self.model = model - self.rawModel = model.rawModel - self.additionalInfo = model.additionalInfo - self.doa = doa - - class PredictionRequestPydantic(BaseModel): dataset: Any model: Any doa: Any = None + extraConfig: Any = None \ No newline at end of file diff --git a/src/handlers/predict.py b/src/handlers/predict.py index 985d7e4..aedef2a 100644 --- a/src/handlers/predict.py +++ b/src/handlers/predict.py @@ -1,28 +1,104 @@ from ..entities.prediction_request import PredictionRequestPydantic from ..helpers import model_decoder, json_to_predreq +import base64 +import onnxruntime +import numpy as np +import torch +import torch.nn.functional as F + +# import sys +# import os + +# current_dir = os.path.dirname(__file__) +# software_dir = os.path.abspath(os.path.join(current_dir, '../../../../../JQP')) +# sys.path.append(software_dir) +from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer def model_post_handler(request: PredictionRequestPydantic): - model = model_decoder.decode(request.model['rawModel']) + model = model_decoder.decode(request.model["rawModel"]) data_entry_all = json_to_predreq.decode(request, model) prediction = model.predict_onnx(data_entry_all) - if model.task == 'classification': + if model.task == "classification": probabilities = model.predict_proba_onnx(data_entry_all) else: probabilities = [None for _ in range(len(prediction))] results = {} for i, feature in enumerate(model.dependentFeatures): - key = feature['key'] + key = feature["key"] if len(model.dependentFeatures) == 1: values = [str(item) for item in prediction] else: values = [str(item) for item in prediction[:, i]] results[key] = values - results['Probabilities'] = [str(prob) for prob in probabilities] - results['AD'] = [None for _ in range(len(prediction))] + results["Probabilities"] = [str(prob) for prob in probabilities] + results["AD"] = [None for _ in range(len(prediction))] + + final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} + + return final_all + + +def graph_post_handler(request: PredictionRequestPydantic): + # Obtain the request info + onnx_model = base64.b64decode(request.model["rawModel"]) + ort_session = onnxruntime.InferenceSession(onnx_model) + feat_config = request.extraConfig["torchConfig"]["featurizer"] + # Load the featurizer + featurizer = SmilesGraphFeaturizer() + featurizer.load_json_rep(feat_config) + smiles = request.dataset["input"][0] + def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() + if tensor.requires_grad + else tensor.cpu().numpy() + ) + + data = featurizer.featurize(smiles) + # ONNX Inference + ort_inputs = { + ort_session.get_inputs()[0].name: to_numpy(data.x), + ort_session.get_inputs()[1].name: to_numpy(data.edge_index), + ort_session.get_inputs()[2].name: to_numpy( + torch.zeros(data.x.shape[0], dtype=torch.int64) + ), + } + ort_outs = torch.tensor(np.array(ort_session.run(None, ort_inputs))) + if request.extraConfig["torchConfig"]["task"] == "classification": + return graph_binary_classification(request, ort_outs) + elif request.extraConfig["torchConfig"]["task"] == "regression": + return ort_outs + else: + raise ValueError("Only classification and regression tasks are supported") + + +def graph_binary_classification(request: PredictionRequestPydantic, onnx_output): + # Classification + target_name = request.model["dependentFeatures"][0]["name"] + probs = [F.sigmoid(onnx_output).squeeze().tolist()] + preds = [int(prob > 0.5) for prob in probs] + # UI Results + results = {} + results["Probabilities"] = [str(prob) for prob in probs] + results[target_name] = [str(pred) for pred in preds] final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} + print(final_all) return final_all + + +# def graph_regression(request: PredictionRequestPydantic, onnx_output): +# # Regression +# target_name = request.model['dependentFeatures'][0]['name'] +# preds = [onnx_output.squeeze().tolist()] +# # UI Results +# results = {} +# results[target_name] = [str(pred) for pred in preds] +# final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} +# print(final_all) + +# return final_all diff --git a/src/serializers/__init__.py b/src/serializers/__init__.py deleted file mode 100644 index e69de29..0000000