Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not create FastAPI app until config loaded #702

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import asynccontextmanager

from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
Expand Down Expand Up @@ -80,25 +81,34 @@
teardown_runner()


app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
)
router = APIRouter()


def get_app():
DiamondJoseph marked this conversation as resolved.
Show resolved Hide resolved
app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
lifespan=lifespan,
version=REST_API_VERSION,
)
app.include_router(router)
app.add_exception_handler(KeyError, on_key_error_404)
app.middleware("http")(add_api_version_header)
app.middleware("http")(inject_propagated_observability_context)
return app


TRACER = get_tracer("interface")


@app.exception_handler(KeyError)
async def on_key_error_404(_: Request, __: KeyError):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Item not found"},
)


@app.get("/environment", response_model=EnvironmentResponse)
@router.get("/environment", response_model=EnvironmentResponse)
@start_as_current_span(TRACER, "runner")
def get_environment(
runner: WorkerDispatcher = Depends(_runner),
Expand All @@ -107,7 +117,7 @@
return runner.state


@app.delete("/environment", response_model=EnvironmentResponse)
@router.delete("/environment", response_model=EnvironmentResponse)
async def delete_environment(
background_tasks: BackgroundTasks,
runner: WorkerDispatcher = Depends(_runner),
Expand All @@ -119,14 +129,14 @@
return EnvironmentResponse(initialized=False)


@app.get("/plans", response_model=PlanResponse)
@router.get("/plans", response_model=PlanResponse)
@start_as_current_span(TRACER)
def get_plans(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available plans."""
return PlanResponse(plans=runner.run(interface.get_plans))


@app.get(
@router.get(
"/plans/{name}",
response_model=PlanModel,
)
Expand All @@ -136,14 +146,14 @@
return runner.run(interface.get_plan, name)


@app.get("/devices", response_model=DeviceResponse)
@router.get("/devices", response_model=DeviceResponse)
@start_as_current_span(TRACER)
def get_devices(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available devices."""
return DeviceResponse(devices=runner.run(interface.get_devices))


@app.get(
@router.get(
"/devices/{name}",
response_model=DeviceModel,
)
Expand All @@ -156,7 +166,7 @@
example_task = Task(name="count", params={"detectors": ["x"]})


@app.post(
@router.post(
"/tasks",
response_model=TaskResponse,
status_code=status.HTTP_201_CREATED,
Expand Down Expand Up @@ -190,7 +200,7 @@
) from e


@app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
@router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
@start_as_current_span(TRACER, "task_id")
def delete_submitted_task(
task_id: str,
Expand All @@ -207,7 +217,7 @@
return TaskStatusEnum(v_upper)


@app.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
@router.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
@start_as_current_span(TRACER)
def get_tasks(
task_status: str | None = None,
Expand All @@ -234,7 +244,7 @@
return TasksListResponse(tasks=tasks)


@app.put(
@router.put(
"/worker/task",
response_model=WorkerTask,
responses={status.HTTP_409_CONFLICT: {"worker": "already active"}},
Expand All @@ -255,7 +265,7 @@
return task


@app.get(
@router.get(
"/tasks/{task_id}",
response_model=TrackableTask,
)
Expand All @@ -271,7 +281,7 @@
return task


@app.get("/worker/task")
@router.get("/worker/task")
@start_as_current_span(TRACER)
def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask:
active = runner.run(interface.get_active_task)
Expand All @@ -281,7 +291,7 @@
return WorkerTask(task_id=None)


@app.get("/worker/state")
@router.get("/worker/state")
@start_as_current_span(TRACER)
def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState:
"""Get the State of the Worker"""
Expand All @@ -303,7 +313,7 @@
}


@app.put(
@router.put(
"/worker/state",
status_code=status.HTTP_202_ACCEPTED,
responses={
Expand Down Expand Up @@ -372,29 +382,32 @@
"%(asctime)s %(levelprefix)s %(client_addr)s"
+ " - '%(request_line)s' %(status_code)s"
)
app = get_app()

Check warning on line 385 in src/blueapi/service/main.py

View check run for this annotation

Codecov / codecov/patch

src/blueapi/service/main.py#L385

Added line #L385 was not covered by tests

FastAPIInstrumentor().instrument_app(
app,
tracer_provider=get_tracer_provider(),
http_capture_headers_server_request=[",*"],
http_capture_headers_server_response=[",*"],
)
app.state.config = config

uvicorn.run(app, host=config.api.host, port=config.api.port)


@app.middleware("http")
async def add_api_version_header(request: Request, call_next):
async def add_api_version_header(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["X-API-Version"] = REST_API_VERSION
return response


@app.middleware("http")
async def inject_propagated_observability_context(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware to extract the any prorpagated observability context from the
HTTP headers and attatch it to the local one.
"""Middleware to extract the any propagated observability context from the
HTTP headers and attach it to the local one.
"""
if CONTEXT_HEADER in request.headers:
ctx = get_global_textmap().extract(
Expand Down
3 changes: 2 additions & 1 deletion src/blueapi/service/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from fastapi.openapi.utils import get_openapi
from pyparsing import Any

from blueapi.service.main import app
from blueapi.service.main import get_app

DOCS_SCHEMA_LOCATION = Path(__file__).parents[3] / "docs" / "reference" / "openapi.yaml"


def generate_schema() -> Mapping[str, Any]:
app = get_app()
return get_openapi(
title=app.title,
version=app.version,
Expand Down
12 changes: 7 additions & 5 deletions tests/unit_tests/service/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from blueapi.service.openapi import DOCS_SCHEMA_LOCATION, generate_schema


@mock.patch("blueapi.service.openapi.app")
def test_generate_schema(mock_app: Mock) -> None:
from blueapi.service.main import app
@mock.patch("blueapi.service.openapi.get_app")
def test_generate_schema(mock_get_app: Mock) -> None:
mock_app = mock_get_app()

from blueapi.service.main import get_app

app = get_app()

title = PropertyMock(return_value="title")
version = PropertyMock(return_value=app.version)
Expand All @@ -23,8 +27,6 @@ def test_generate_schema(mock_app: Mock) -> None:
type(mock_app).description = description
type(mock_app).routes = routes

# from blueapi.service.openapi import generate_schema

assert generate_schema() == {
"openapi": openapi_version(),
"info": {
Expand Down
6 changes: 2 additions & 4 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@

@pytest.fixture
def client() -> Iterator[TestClient]:
with (
patch("blueapi.service.interface.worker"),
):
with patch("blueapi.service.interface.worker"):
main.setup_runner(use_subprocess=False)
yield TestClient(main.app)
yield TestClient(main.get_app())
main.teardown_runner()


Expand Down
Loading