Skip to content

Commit

Permalink
feat(JAQPOT-434): add new model types
Browse files Browse the repository at this point in the history
  • Loading branch information
alarv committed Nov 28, 2024
1 parent 8483a3c commit c319a39
Show file tree
Hide file tree
Showing 79 changed files with 27 additions and 17,619 deletions.
40 changes: 0 additions & 40 deletions etc/openapi-generate.sh

This file was deleted.

25 changes: 16 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

import uvicorn
from fastapi import FastAPI
from jaqpotpy.api.openapi import PredictionRequest, PredictionResponse, ModelType

from src.api.openapi import PredictionResponse
from src.api.openapi.models.prediction_request import PredictionRequest
from src.handlers.predict_sklearn import sklearn_post_handler
from src.handlers.predict_pyg import graph_post_handler
from src.handlers.predict_sklearn_onnx import sklearn_onnx_post_handler
from src.handlers.predict_torch import torch_post_handler

from src.loggers.logger import logger
from src.loggers.log_middleware import LogMiddleware
Expand All @@ -26,13 +25,21 @@ def health_check():
return {"status": "UP"}


@app.post("/predict/")
@app.post("/predict")
def predict(req: PredictionRequest) -> PredictionResponse:
logger.info("Prediction request for model " + str(req.model.id))
if req.model.type == "SKLEARN":
return sklearn_post_handler(req)
else:
return graph_post_handler(req)

match req.model.type:
case ModelType.SKLEARN_ONNX:
return sklearn_onnx_post_handler(req)
case (
ModelType.TORCH_GEOMETRIC_ONNX,
ModelType.TORCHSCRIPT,
ModelType.TORCH_SEQUENCE_ONNX,
):
return torch_post_handler(req)
case _:
raise Exception("Model type not supported")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ fastapi==0.111.0
pydantic==2.7.1
uvicorn==0.29.0
starlette~=0.37.2
jaqpotpy==6.11.0
jaqpotpy==6.17.0
pre-commit==4.0.1
ruff==0.6.3
95 changes: 0 additions & 95 deletions src/api/openapi/__init__.py

This file was deleted.

13 changes: 0 additions & 13 deletions src/api/openapi/api/__init__.py

This file was deleted.

Loading

0 comments on commit c319a39

Please sign in to comment.