diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 80c0a330e..05c529d7d 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,6 +2,7 @@ from contextlib import asynccontextmanager from fastapi import ( + APIRouter, BackgroundTasks, Body, Depends, @@ -80,17 +81,26 @@ async def lifespan(app: FastAPI): teardown_runner() -app = FastAPI( - docs_url="/docs", - title="BlueAPI Control", - lifespan=lifespan, - version=REST_API_VERSION, -) +router = APIRouter() + + +def get_app(): + app = FastAPI( + docs_url="/docs", + title="BlueAPI Control", + lifespan=lifespan, + version=REST_API_VERSION, + ) + app.include_router(router) + app.add_exception_handler(KeyError, on_key_error_404) + app.middleware("http")(add_api_version_header) + app.middleware("http")(inject_propagated_observability_context) + return app + TRACER = get_tracer("interface") -@app.exception_handler(KeyError) async def on_key_error_404(_: Request, __: KeyError): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -98,7 +108,7 @@ async def on_key_error_404(_: Request, __: KeyError): ) -@app.get("/environment", response_model=EnvironmentResponse) +@router.get("/environment", response_model=EnvironmentResponse) @start_as_current_span(TRACER, "runner") def get_environment( runner: WorkerDispatcher = Depends(_runner), @@ -107,7 +117,7 @@ def get_environment( return runner.state -@app.delete("/environment", response_model=EnvironmentResponse) +@router.delete("/environment", response_model=EnvironmentResponse) async def delete_environment( background_tasks: BackgroundTasks, runner: WorkerDispatcher = Depends(_runner), @@ -119,14 +129,14 @@ async def delete_environment( return EnvironmentResponse(initialized=False) -@app.get("/plans", response_model=PlanResponse) +@router.get("/plans", response_model=PlanResponse) @start_as_current_span(TRACER) def get_plans(runner: WorkerDispatcher = Depends(_runner)): """Retrieve information about all available plans.""" return PlanResponse(plans=runner.run(interface.get_plans)) -@app.get( +@router.get( "/plans/{name}", response_model=PlanModel, ) @@ -136,14 +146,14 @@ def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): return runner.run(interface.get_plan, name) -@app.get("/devices", response_model=DeviceResponse) +@router.get("/devices", response_model=DeviceResponse) @start_as_current_span(TRACER) def get_devices(runner: WorkerDispatcher = Depends(_runner)): """Retrieve information about all available devices.""" return DeviceResponse(devices=runner.run(interface.get_devices)) -@app.get( +@router.get( "/devices/{name}", response_model=DeviceModel, ) @@ -156,7 +166,7 @@ def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): example_task = Task(name="count", params={"detectors": ["x"]}) -@app.post( +@router.post( "/tasks", response_model=TaskResponse, status_code=status.HTTP_201_CREATED, @@ -190,7 +200,7 @@ def submit_task( ) from e -@app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) +@router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, @@ -207,7 +217,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: return TaskStatusEnum(v_upper) -@app.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK) +@router.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK) @start_as_current_span(TRACER) def get_tasks( task_status: str | None = None, @@ -234,7 +244,7 @@ def get_tasks( return TasksListResponse(tasks=tasks) -@app.put( +@router.put( "/worker/task", response_model=WorkerTask, responses={status.HTTP_409_CONFLICT: {"worker": "already active"}}, @@ -255,7 +265,7 @@ def set_active_task( return task -@app.get( +@router.get( "/tasks/{task_id}", response_model=TrackableTask, ) @@ -271,7 +281,7 @@ def get_task( return task -@app.get("/worker/task") +@router.get("/worker/task") @start_as_current_span(TRACER) def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask: active = runner.run(interface.get_active_task) @@ -281,7 +291,7 @@ def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask: return WorkerTask(task_id=None) -@app.get("/worker/state") +@router.get("/worker/state") @start_as_current_span(TRACER) def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState: """Get the State of the Worker""" @@ -303,7 +313,7 @@ def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState: } -@app.put( +@router.put( "/worker/state", status_code=status.HTTP_202_ACCEPTED, responses={ @@ -372,6 +382,8 @@ def start(config: ApplicationConfig): "%(asctime)s %(levelprefix)s %(client_addr)s" + " - '%(request_line)s' %(status_code)s" ) + app = get_app() + FastAPIInstrumentor().instrument_app( app, tracer_provider=get_tracer_provider(), @@ -379,22 +391,23 @@ def start(config: ApplicationConfig): http_capture_headers_server_response=[",*"], ) app.state.config = config + uvicorn.run(app, host=config.api.host, port=config.api.port) -@app.middleware("http") -async def add_api_version_header(request: Request, call_next): +async def add_api_version_header( + request: Request, call_next: Callable[[Request], Awaitable[Response]] +): response = await call_next(request) response.headers["X-API-Version"] = REST_API_VERSION return response -@app.middleware("http") async def inject_propagated_observability_context( request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: - """Middleware to extract the any prorpagated observability context from the - HTTP headers and attatch it to the local one. + """Middleware to extract the any propagated observability context from the + HTTP headers and attach it to the local one. """ if CONTEXT_HEADER in request.headers: ctx = get_global_textmap().extract( diff --git a/src/blueapi/service/openapi.py b/src/blueapi/service/openapi.py index 859fae8fd..762a081b9 100644 --- a/src/blueapi/service/openapi.py +++ b/src/blueapi/service/openapi.py @@ -7,12 +7,13 @@ from fastapi.openapi.utils import get_openapi from pyparsing import Any -from blueapi.service.main import app +from blueapi.service.main import get_app DOCS_SCHEMA_LOCATION = Path(__file__).parents[3] / "docs" / "reference" / "openapi.yaml" def generate_schema() -> Mapping[str, Any]: + app = get_app() return get_openapi( title=app.title, version=app.version, diff --git a/tests/unit_tests/service/test_openapi.py b/tests/unit_tests/service/test_openapi.py index edc5c726e..c0f8d7017 100644 --- a/tests/unit_tests/service/test_openapi.py +++ b/tests/unit_tests/service/test_openapi.py @@ -7,9 +7,13 @@ from blueapi.service.openapi import DOCS_SCHEMA_LOCATION, generate_schema -@mock.patch("blueapi.service.openapi.app") -def test_generate_schema(mock_app: Mock) -> None: - from blueapi.service.main import app +@mock.patch("blueapi.service.openapi.get_app") +def test_generate_schema(mock_get_app: Mock) -> None: + mock_app = mock_get_app() + + from blueapi.service.main import get_app + + app = get_app() title = PropertyMock(return_value="title") version = PropertyMock(return_value=app.version) @@ -23,8 +27,6 @@ def test_generate_schema(mock_app: Mock) -> None: type(mock_app).description = description type(mock_app).routes = routes - # from blueapi.service.openapi import generate_schema - assert generate_schema() == { "openapi": openapi_version(), "info": { diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index decf2a483..8a1fe5219 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -25,11 +25,9 @@ @pytest.fixture def client() -> Iterator[TestClient]: - with ( - patch("blueapi.service.interface.worker"), - ): + with patch("blueapi.service.interface.worker"): main.setup_runner(use_subprocess=False) - yield TestClient(main.app) + yield TestClient(main.get_app()) main.teardown_runner()