Skip to content

Commit

Permalink
Add support for preloading models
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnorell committed Nov 20, 2024
1 parent 83e9220 commit 469fbcc
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
3 changes: 3 additions & 0 deletions inference/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,6 @@
WORKFLOW_BLOCKS_WRITE_DIRECTORY = os.getenv("WORKFLOW_BLOCKS_WRITE_DIRECTORY")

DEDICATED_DEPLOYMENT_ID = os.getenv("DEDICATED_DEPLOYMENT_ID")

# Preload Models
# PRELOAD_MODELS is a comma-separated list of model IDs to load at server
# startup (e.g. "yolov8n-640,my-project/3"). Parsed into a list of cleaned
# IDs, or None when the variable is unset or contains no usable entries.
_preload_models_raw = os.getenv("PRELOAD_MODELS")
PRELOAD_MODELS = (
    # Strip surrounding whitespace and drop empty entries so values like
    # "a, b" or trailing commas ("a,b,") do not produce bogus model IDs.
    ([model_id.strip() for model_id in _preload_models_raw.split(",") if model_id.strip()] or None)
    if _preload_models_raw
    else None
)
56 changes: 55 additions & 1 deletion inference/core/interfaces/http/http_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import base64
import os
import traceback
Expand All @@ -7,13 +8,14 @@

import asgi_correlation_id
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Path, Query, Request
from fastapi import BackgroundTasks, FastAPI, Path, Query, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi_cprofile.profiler import CProfileMiddleware
from starlette.convertors import StringConvertor, register_url_convertor
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel

from inference.core import logger
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
Expand Down Expand Up @@ -129,6 +131,7 @@
WORKFLOWS_MAX_CONCURRENT_STEPS,
WORKFLOWS_PROFILER_BUFFER_SIZE,
WORKFLOWS_STEP_EXECUTION_MODE,
PRELOAD_MODELS,
)
from inference.core.exceptions import (
ContentTypeInvalid,
Expand Down Expand Up @@ -2384,5 +2387,56 @@ async def model_add(dataset_id: str, version_id: str, api_key: str = None):
name="static",
)

# Enable preloading models at startup.
# Only active when PRELOAD_MODELS and an API key are configured, and neither
# the lambda deployment mode nor the legacy routes are in use.
if PRELOAD_MODELS and API_KEY and not (LAMBDA or LEGACY_ROUTE_ENABLED):

    class ModelInitState:
        """Tracks whether startup model preloading has completed."""

        def __init__(self):
            # Flipped to True only after every model in PRELOAD_MODELS
            # has been loaded successfully.
            self.is_ready = False

    model_init_state = ModelInitState()

    async def initialize_models(state: ModelInitState):
        """Load every model listed in PRELOAD_MODELS, then mark the server ready.

        On any failure the exception is logged and ``state.is_ready`` stays
        False, so the readiness probe keeps reporting 503.
        """
        try:
            # One add-model coroutine per preloaded model id.
            # NOTE(review): the hunk context shows a legacy
            # model_add(dataset_id, version_id, api_key) signature; confirm
            # the AddModelRequest-accepting overload is the one in scope here.
            tasks = [
                model_add(
                    AddModelRequest(
                        model_id=model_id, model_type=None, api_key=API_KEY
                    )
                )
                for model_id in PRELOAD_MODELS
            ]

            # Wait for all model loading tasks to complete.
            await asyncio.gather(*tasks)

            # Mark the server as ready for the readiness probe.
            state.is_ready = True
        except Exception:
            # Use the project logger (with traceback) instead of print so the
            # failure reaches the configured log handlers.
            logger.exception("Error during startup model preloading")

    @app.on_event("startup")
    async def startup_model_init():
        """Schedule model preloading without blocking server startup."""
        # Fire-and-forget: completion is observable via /readiness.
        asyncio.create_task(initialize_models(model_init_state))

    @app.get("/readiness", status_code=200)
    async def readiness(
        state: ModelInitState = Depends(lambda: model_init_state),
    ):
        """Readiness endpoint for Kubernetes readiness probe."""
        if state.is_ready:
            return {"status": "ready"}
        return JSONResponse(content={"status": "not ready"}, status_code=503)

    @app.get("/healthz", status_code=200)
    async def healthz():
        """Health endpoint for Kubernetes liveness probe."""
        return {"status": "healthy"}

def run(self):
    """Serve the FastAPI application with uvicorn on 127.0.0.1:8080."""
    bind = {"host": "127.0.0.1", "port": 8080}
    uvicorn.run(self.app, **bind)

0 comments on commit 469fbcc

Please sign in to comment.