
Feature: background ingestion in LOCAL mode #321

Merged 13 commits on Sep 6, 2024
backend/indexer/indexer.py (18 additions, 14 deletions)
@@ -1,6 +1,7 @@
 import os
 import tempfile
-from typing import Dict, List
+from concurrent.futures import Executor
+from typing import Dict, List, Optional
 
 from fastapi import HTTPException
 from fastapi.responses import JSONResponse
@@ -299,7 +300,7 @@ async def ingest_data_points(
     )
 
 
-async def ingest_data(request: IngestDataToCollectionDto):
+async def ingest_data(request: IngestDataToCollectionDto, pool: Optional[Executor]):
     """Ingest data into the collection"""
     try:
         client = await get_client()
@@ -353,19 +354,22 @@ async def ingest_data(request: IngestDataToCollectionDto):
             created_data_ingestion_run = await client.acreate_data_ingestion_run(
                 data_ingestion_run=data_ingestion_run
             )
-            await sync_data_source_to_collection(
-                inputs=DataIngestionConfig(
-                    collection_name=created_data_ingestion_run.collection_name,
-                    data_ingestion_run_name=created_data_ingestion_run.name,
-                    data_source=associated_data_source.data_source,
-                    embedder_config=collection.embedder_config,
-                    parser_config=created_data_ingestion_run.parser_config,
-                    data_ingestion_mode=created_data_ingestion_run.data_ingestion_mode,
-                    raise_error_on_failure=created_data_ingestion_run.raise_error_on_failure,
-                    batch_size=request.batch_size,
-                )
-            )
+            ingestion_config = DataIngestionConfig(
+                collection_name=created_data_ingestion_run.collection_name,
+                data_ingestion_run_name=created_data_ingestion_run.name,
+                data_source=associated_data_source.data_source,
+                embedder_config=collection.embedder_config,
+                parser_config=created_data_ingestion_run.parser_config,
+                data_ingestion_mode=created_data_ingestion_run.data_ingestion_mode,
+                raise_error_on_failure=created_data_ingestion_run.raise_error_on_failure,
+                batch_size=request.batch_size,
+            )
-            created_data_ingestion_run.status = DataIngestionRunStatus.COMPLETED
+            if pool:
+                # future of this submission is ignored, failures not tracked
+                pool.submit(sync_data_source_to_collection, ingestion_config)
+            else:
+                await sync_data_source_to_collection(ingestion_config)
+            created_data_ingestion_run.status = DataIngestionRunStatus.INITIALIZED
         else:
             if not settings.JOB_FQN:
                 logger.error("Job FQN is required to trigger the job")
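For context, here is a minimal, self-contained sketch of the dispatch pattern this hunk introduces, with assumed stand-in names (`do_work`, `dispatch`) rather than the PR's own: when an executor is available, the coroutine function is submitted fire-and-forget; otherwise it is awaited inline. Note that submitting a coroutine function to a plain executor would never actually run it; the pattern relies on the `AsyncProcessPoolExecutor` added in `backend/utils.py` below, which wraps the callable in `asyncio.run` inside the worker process.

```python
import asyncio
from concurrent.futures import Executor
from typing import Optional

async def do_work(name: str) -> None:
    # Stand-in for sync_data_source_to_collection: any awaitable unit of work.
    await asyncio.sleep(0.1)
    print(f"ingested {name}")

async def dispatch(name: str, pool: Optional[Executor]) -> None:
    if pool:
        # Fire-and-forget: the returned Future is dropped, so exceptions
        # raised inside the worker are never observed by the caller.
        # Assumes pool is an AsyncProcessPoolExecutor that can run coroutines.
        pool.submit(do_work, name)
    else:
        # Without a pool, run the work inline and propagate any errors.
        await do_work(name)
```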
backend/server/app.py (14 additions, 0 deletions)
@@ -1,3 +1,5 @@
+from contextlib import asynccontextmanager
+
 from fastapi import APIRouter, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
@@ -9,12 +11,24 @@
 from backend.server.routers.internal import router as internal_router
 from backend.server.routers.rag_apps import router as rag_apps_router
 from backend.settings import settings
+from backend.utils import AsyncProcessPoolExecutor
+
+
+@asynccontextmanager
+async def _process_pool_lifespan_manager(app: FastAPI):
+    app.state.process_pool = AsyncProcessPoolExecutor(
+        max_workers=settings.PROCESS_POOL_WORKERS
+    )
+    yield  # FastAPI runs here
+    app.state.process_pool.shutdown(wait=True)
 
 
 # FastAPI Initialization
 app = FastAPI(
     title="Backend for RAG",
     root_path=settings.TFY_SERVICE_ROOT_PATH,
     docs_url="/",
+    lifespan=_process_pool_lifespan_manager,
 )
 
 app.add_middleware(
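A usage note on FastAPI's lifespan protocol, which this hunk adopts: everything before the `yield` runs once at startup, and everything after it runs at shutdown, so the pool is created once per server process and `shutdown(wait=True)` drains in-flight work before the process exits. A minimal sketch of the same pattern (generic illustration, not the PR's code):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs before the app begins serving requests.
    app.state.resource = "connected"
    yield
    # Shutdown: runs after the server stops accepting requests.
    app.state.resource = None

app = FastAPI(lifespan=lifespan)
```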
backend/server/routers/collection.py (11 additions, 3 deletions)
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, Path
+from fastapi import APIRouter, HTTPException, Path, Request
 from fastapi.responses import JSONResponse
 
 from backend.indexer.indexer import ingest_data as ingest_data_to_collection
@@ -149,10 +149,18 @@ async def unassociate_data_source_from_collection(
 
 
 @router.post("/ingest")
-async def ingest_data(request: IngestDataToCollectionDto):
+async def ingest_data(
+    ingest_data_to_collection_dto: IngestDataToCollectionDto, request: Request
+):
     """Ingest data into the collection"""
     try:
-        return await ingest_data_to_collection(request)
+        process_pool = request.app.state.process_pool
+    except AttributeError:
+        process_pool = None
+    try:
+        return await ingest_data_to_collection(
+            ingest_data_to_collection_dto, process_pool
+        )
     except HTTPException as exp:
         raise exp
     except Exception as exp:
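For illustration, a hypothetical client call against this route. The host, port, route prefix, and every payload field except `batch_size` (which appears in the diff) are assumptions, not taken from the PR:

```python
import httpx

# Hypothetical payload; the exact fields of IngestDataToCollectionDto
# are assumed for illustration.
payload = {
    "collection_name": "my-collection",
    "data_source_fqn": "localdir::sample-data",
    "batch_size": 100,
}

# With a process pool available, the route returns once the run is marked
# INITIALIZED; ingestion itself continues in a worker process.
response = httpx.post("http://localhost:8000/collections/ingest", json=payload)
print(response.status_code, response.json())
```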
backend/settings.py (1 addition, 0 deletions)

@@ -31,6 +31,7 @@ class Settings(BaseSettings):
     UNSTRUCTURED_IO_URL: str = ""
 
     UNSTRUCTURED_IO_API_KEY: str = ""
+    PROCESS_POOL_WORKERS: int = 1
 
     @model_validator(mode="before")
     @classmethod
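Because `Settings` extends `BaseSettings`, the new field is populated from the `PROCESS_POOL_WORKERS` environment variable automatically, falling back to 1 when unset. A minimal sketch of that behavior, assuming pydantic v2's `pydantic-settings` package (consistent with the `model_validator(mode="before")` usage visible above):

```python
import os

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    PROCESS_POOL_WORKERS: int = 1

os.environ["PROCESS_POOL_WORKERS"] = "4"
print(Settings().PROCESS_POOL_WORKERS)  # 4; would print 1 if the variable were unset
```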
backend/utils.py (10 additions, 1 deletion)
@@ -1,6 +1,6 @@
 import asyncio
 import zipfile
-from concurrent.futures import Executor
+from concurrent.futures import Executor, ProcessPoolExecutor
 from contextvars import copy_context
 from functools import partial
 from typing import Callable, Optional, TypeVar, cast
@@ -86,3 +86,12 @@ def wrapper() -> T:
     )
 
     return await asyncio.get_running_loop().run_in_executor(executor, wrapper)
+
+
+class AsyncProcessPoolExecutor(ProcessPoolExecutor):
+    @staticmethod
+    def _async_to_sync(__fn, *args, **kwargs):
+        return asyncio.run(__fn(*args, **kwargs))
+
+    def submit(self, __fn, *args, **kwargs):
+        return super().submit(self._async_to_sync, __fn, *args, **kwargs)
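This `submit` override is what makes coroutine functions usable with a process pool: each worker receives the coroutine function plus its arguments, and `asyncio.run` gives it a fresh event loop in that process. A standalone usage sketch with an assumed example task (`double`); note that the callable and its arguments must be picklable, since they cross a process boundary:

```python
import asyncio
from concurrent.futures import ProcessPoolExecutor

class AsyncProcessPoolExecutor(ProcessPoolExecutor):
    @staticmethod
    def _async_to_sync(fn, *args, **kwargs):
        # Runs in the worker process: start an event loop and drive
        # the coroutine to completion.
        return asyncio.run(fn(*args, **kwargs))

    def submit(self, fn, *args, **kwargs):
        return super().submit(self._async_to_sync, fn, *args, **kwargs)

async def double(x: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for real async work
    return 2 * x

if __name__ == "__main__":  # guard required for spawn-based platforms
    with AsyncProcessPoolExecutor(max_workers=2) as pool:
        future = pool.submit(double, 21)
        print(future.result())  # 42
```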
compose.env (1 addition, 0 deletions)
@@ -1,4 +1,5 @@
 LOCAL=true
+PROCESS_POOL_WORKERS=4
 
 ## POSTGRES
 POSTGRES_PORT=5432
docker-compose.yaml (1 addition, 0 deletions)

@@ -162,6 +162,7 @@ services:
     environment:
       - DEBUG_MODE=true
       - LOCAL=${LOCAL}
+      - PROCESS_POOL_WORKERS=${PROCESS_POOL_WORKERS}
       - LOG_LEVEL=DEBUG
       - DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@cognita-db:5432/cognita-config
       - METADATA_STORE_CONFIG=${METADATA_STORE_CONFIG}