Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use run_in_threadpool make server async #62

Merged
merged 2 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions duetector/service/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Any, Dict, Optional

try:
from functools import cache
except ImportError:
from functools import lru_cache as cache

from fastapi import Depends

from duetector.config import Configuable
Expand All @@ -15,6 +20,7 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, *args, **kwargs):
super().__init__(config, *args, **kwargs)


@cache
def get_controller(controller_type: type):
def _(config: dict = Depends(get_config)) -> Controller:
return controller_type(config)
Expand Down
9 changes: 6 additions & 3 deletions duetector/service/query/routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from asyncio import sleep

from fastapi import APIRouter, Body, Depends
from fastapi.concurrency import run_in_threadpool

from duetector.service.base import get_controller
from duetector.service.query.controller import AnalyzerController
Expand Down Expand Up @@ -35,7 +38,7 @@ async def query(
Query data from analyzer
"""
analyzer = controller.get_analyzer(analyzer_name)
trackings = analyzer.query(**query_param.model_dump())
trackings = await run_in_threadpool(analyzer.query, **query_param.model_dump())

return QueryResult(
trackings=trackings,
Expand All @@ -50,8 +53,8 @@ async def query_brief(
):
# type is not serializable, so we need to get analyzer without inspect type
analyzer = controller.get_analyzer(analyzer_name)

brief = await run_in_threadpool(analyzer.brief, inspect_type=False)
return BriefResult(
brief=analyzer.brief(inspect_type=False),
brief=brief,
analyzer_name=analyzer_name,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
# Following are for web server
"fastapi",
"uvicorn[standard]",
"anyio"
]
dynamic = ["version"]
classifiers = [
Expand Down
Loading