Skip to content

Commit

Permalink
Do not create FastAPI app until config loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Nov 7, 2024
1 parent 1c50a4b commit 1ed5c2e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 36 deletions.
65 changes: 39 additions & 26 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import asynccontextmanager

from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
Expand Down Expand Up @@ -80,25 +81,34 @@ async def lifespan(app: FastAPI):
teardown_runner()


app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
)
router = APIRouter()


def get_app():
app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
)
app.include_router(router)
app.add_exception_handler(KeyError, on_key_error_404)
app.middleware("http")(add_api_version_header)
app.middleware("http")(inject_propagated_observability_context)
return app


TRACER = get_tracer("interface")


@app.exception_handler(KeyError)
async def on_key_error_404(_: Request, __: KeyError):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Item not found"},
)


@app.get("/environment", response_model=EnvironmentResponse)
@router.get("/environment", response_model=EnvironmentResponse)
@start_as_current_span(TRACER, "runner")
def get_environment(
runner: WorkerDispatcher = Depends(_runner),
Expand All @@ -107,7 +117,7 @@ def get_environment(
return runner.state


@app.delete("/environment", response_model=EnvironmentResponse)
@router.delete("/environment", response_model=EnvironmentResponse)
async def delete_environment(
background_tasks: BackgroundTasks,
runner: WorkerDispatcher = Depends(_runner),
Expand All @@ -119,14 +129,14 @@ async def delete_environment(
return EnvironmentResponse(initialized=False)


@app.get("/plans", response_model=PlanResponse)
@router.get("/plans", response_model=PlanResponse)
@start_as_current_span(TRACER)
def get_plans(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available plans."""
return PlanResponse(plans=runner.run(interface.get_plans))


@app.get(
@router.get(
"/plans/{name}",
response_model=PlanModel,
)
Expand All @@ -136,14 +146,14 @@ def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
return runner.run(interface.get_plan, name)


@app.get("/devices", response_model=DeviceResponse)
@router.get("/devices", response_model=DeviceResponse)
@start_as_current_span(TRACER)
def get_devices(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available devices."""
return DeviceResponse(devices=runner.run(interface.get_devices))


@app.get(
@router.get(
"/devices/{name}",
response_model=DeviceModel,
)
Expand All @@ -156,7 +166,7 @@ def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
example_task = Task(name="count", params={"detectors": ["x"]})


@app.post(
@router.post(
"/tasks",
response_model=TaskResponse,
status_code=status.HTTP_201_CREATED,
Expand Down Expand Up @@ -190,7 +200,7 @@ def submit_task(
) from e


@app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
@router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
@start_as_current_span(TRACER, "task_id")
def delete_submitted_task(
task_id: str,
Expand All @@ -207,7 +217,7 @@ def validate_task_status(v: str) -> TaskStatusEnum:
return TaskStatusEnum(v_upper)


@app.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
@router.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
@start_as_current_span(TRACER)
def get_tasks(
task_status: str | None = None,
Expand All @@ -234,7 +244,7 @@ def get_tasks(
return TasksListResponse(tasks=tasks)


@app.put(
@router.put(
"/worker/task",
response_model=WorkerTask,
responses={status.HTTP_409_CONFLICT: {"worker": "already active"}},
Expand All @@ -255,7 +265,7 @@ def set_active_task(
return task


@app.get(
@router.get(
"/tasks/{task_id}",
response_model=TrackableTask,
)
Expand All @@ -271,7 +281,7 @@ def get_task(
return task


@app.get("/worker/task")
@router.get("/worker/task")
@start_as_current_span(TRACER)
def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask:
active = runner.run(interface.get_active_task)
Expand All @@ -281,7 +291,7 @@ def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask:
return WorkerTask(task_id=None)


@app.get("/worker/state")
@router.get("/worker/state")
@start_as_current_span(TRACER)
def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState:
"""Get the State of the Worker"""
Expand All @@ -303,7 +313,7 @@ def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState:
}


@app.put(
@router.put(
"/worker/state",
status_code=status.HTTP_202_ACCEPTED,
responses={
Expand Down Expand Up @@ -372,29 +382,32 @@ def start(config: ApplicationConfig):
"%(asctime)s %(levelprefix)s %(client_addr)s"
+ " - '%(request_line)s' %(status_code)s"
)
app = get_app()

Check warning on line 385 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L385

Added line #L385 was not covered by tests

FastAPIInstrumentor().instrument_app(
app,
tracer_provider=get_tracer_provider(),
http_capture_headers_server_request=[",*"],
http_capture_headers_server_response=[",*"],
)
app.state.config = config

uvicorn.run(app, host=config.api.host, port=config.api.port)


@app.middleware("http")
async def add_api_version_header(request: Request, call_next):
async def add_api_version_header(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["X-API-Version"] = REST_API_VERSION
return response


@app.middleware("http")
async def inject_propagated_observability_context(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware to extract the any prorpagated observability context from the
HTTP headers and attatch it to the local one.
"""Middleware to extract the any propagated observability context from the
HTTP headers and attach it to the local one.
"""
if CONTEXT_HEADER in request.headers:
ctx = get_global_textmap().extract(
Expand Down
3 changes: 2 additions & 1 deletion src/blueapi/service/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from fastapi.openapi.utils import get_openapi
from pyparsing import Any

from blueapi.service.main import app
from blueapi.service.main import get_app

DOCS_SCHEMA_LOCATION = Path(__file__).parents[3] / "docs" / "reference" / "openapi.yaml"


def generate_schema() -> Mapping[str, Any]:
app = get_app()
return get_openapi(
title=app.title,
version=app.version,
Expand Down
12 changes: 7 additions & 5 deletions tests/unit_tests/service/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from blueapi.service.openapi import DOCS_SCHEMA_LOCATION, generate_schema


@mock.patch("blueapi.service.openapi.app")
def test_generate_schema(mock_app: Mock) -> None:
from blueapi.service.main import app
@mock.patch("blueapi.service.openapi.get_app")
def test_generate_schema(mock_get_app: Mock) -> None:
mock_app = mock_get_app()

from blueapi.service.main import get_app

app = get_app()

title = PropertyMock(return_value="title")
version = PropertyMock(return_value=app.version)
Expand All @@ -23,8 +27,6 @@ def test_generate_schema(mock_app: Mock) -> None:
type(mock_app).description = description
type(mock_app).routes = routes

# from blueapi.service.openapi import generate_schema

assert generate_schema() == {
"openapi": openapi_version(),
"info": {
Expand Down
6 changes: 2 additions & 4 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@

@pytest.fixture
def client() -> Iterator[TestClient]:
with (
patch("blueapi.service.interface.worker"),
):
with patch("blueapi.service.interface.worker"):
main.setup_runner(use_subprocess=False)
yield TestClient(main.app)
yield TestClient(main.get_app())
main.teardown_runner()


Expand Down

0 comments on commit 1ed5c2e

Please sign in to comment.