diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a1039f83a..2c47e8b53a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#3037](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3037)) - `opentelemetry-instrumentation-sqlalchemy`: Fix a remaining memory leak in EngineTracer ([#3053](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3053)) +- `opentelemetry-instrumentation-fastapi`: instrument unhandled exceptions + ([#3012](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3012)) ### Breaking changes diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py index 7e4d0aac07..e7ec2ba104 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py @@ -179,10 +179,14 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from __future__ import annotations import logging +import types from typing import Collection, Literal import fastapi +from starlette.applications import Starlette +from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import Match +from starlette.types import ASGIApp from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -199,9 +203,9 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from opentelemetry.instrumentation.fastapi.package import _instruments from opentelemetry.instrumentation.fastapi.version import __version__ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor -from opentelemetry.metrics import get_meter +from opentelemetry.metrics import MeterProvider, get_meter from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import get_tracer +from opentelemetry.trace import TracerProvider, get_tracer from opentelemetry.util.http import ( get_excluded_urls, parse_excluded_urls, @@ -226,9 +230,9 @@ def instrument_app( server_request_hook: ServerRequestHook = None, client_request_hook: ClientRequestHook = None, client_response_hook: ClientResponseHook = None, - tracer_provider=None, - meter_provider=None, - excluded_urls=None, + tracer_provider: TracerProvider | None = None, + meter_provider: MeterProvider | None = None, + excluded_urls: str | None = None, http_capture_headers_server_request: list[str] | None = None, http_capture_headers_server_response: list[str] | None = None, http_capture_headers_sanitize_fields: list[str] | None = None, @@ -280,21 +284,40 @@ def instrument_app( schema_url=_get_schema_url(sem_conv_opt_in_mode), ) - app.add_middleware( - OpenTelemetryMiddleware, - excluded_urls=excluded_urls, - default_span_details=_get_default_span_details, - server_request_hook=server_request_hook, - client_request_hook=client_request_hook, - client_response_hook=client_response_hook, - # Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation - tracer=tracer, - meter=meter, - http_capture_headers_server_request=http_capture_headers_server_request, - http_capture_headers_server_response=http_capture_headers_server_response, - http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, - exclude_spans=exclude_spans, + # Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware + # as the outermost middleware. + # Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able + # to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do. + + def build_middleware_stack(self: Starlette) -> ASGIApp: + app = type(self).build_middleware_stack(self) + app = OpenTelemetryMiddleware( + app, + excluded_urls=excluded_urls, + default_span_details=_get_default_span_details, + server_request_hook=server_request_hook, + client_request_hook=client_request_hook, + client_response_hook=client_response_hook, + # Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation + tracer=tracer, + meter=meter, + http_capture_headers_server_request=http_capture_headers_server_request, + http_capture_headers_server_response=http_capture_headers_server_response, + http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, + exclude_spans=exclude_spans, + ) + # Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware + # are handled. + # This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that + # to impact the user's application just because we wrapped the middlewares in this order. + app = ServerErrorMiddleware(app) + return app + + app._original_build_middleware_stack = app.build_middleware_stack + app.build_middleware_stack = types.MethodType( + build_middleware_stack, app ) + app._is_instrumented_by_opentelemetry = True if app not in _InstrumentedFastAPI._instrumented_fastapi_apps: _InstrumentedFastAPI._instrumented_fastapi_apps.add(app) @@ -305,11 +328,12 @@ def instrument_app( @staticmethod def uninstrument_app(app: fastapi.FastAPI): - app.user_middleware = [ - x - for x in app.user_middleware - if x.cls is not OpenTelemetryMiddleware - ] + original_build_middleware_stack = getattr( + app, "_original_build_middleware_stack", None + ) + if original_build_middleware_stack: + app.build_middleware_stack = original_build_middleware_stack + del app._original_build_middleware_stack app.middleware_stack = app.build_middleware_stack() app._is_instrumented_by_opentelemetry = False @@ -337,12 +361,7 @@ def _instrument(self, **kwargs): _InstrumentedFastAPI._http_capture_headers_sanitize_fields = ( kwargs.get("http_capture_headers_sanitize_fields") ) - _excluded_urls = kwargs.get("excluded_urls") - _InstrumentedFastAPI._excluded_urls = ( - _excluded_urls_from_env - if _excluded_urls is None - else parse_excluded_urls(_excluded_urls) - ) + _InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls") _InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider") _InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans") fastapi.FastAPI = _InstrumentedFastAPI @@ -361,43 +380,29 @@ class _InstrumentedFastAPI(fastapi.FastAPI): _server_request_hook: ServerRequestHook = None _client_request_hook: ClientRequestHook = None _client_response_hook: ClientResponseHook = None + _http_capture_headers_server_request: list[str] | None = None + _http_capture_headers_server_response: list[str] | None = None + _http_capture_headers_sanitize_fields: list[str] | None = None + _exclude_spans: list[Literal["receive", "send"]] | None = None + _instrumented_fastapi_apps = set() _sem_conv_opt_in_mode = _HTTPStabilityMode.DEFAULT def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - tracer = get_tracer( - __name__, - __version__, - _InstrumentedFastAPI._tracer_provider, - schema_url=_get_schema_url( - _InstrumentedFastAPI._sem_conv_opt_in_mode - ), - ) - meter = get_meter( - __name__, - __version__, - _InstrumentedFastAPI._meter_provider, - schema_url=_get_schema_url( - _InstrumentedFastAPI._sem_conv_opt_in_mode - ), - ) - self.add_middleware( - OpenTelemetryMiddleware, - excluded_urls=_InstrumentedFastAPI._excluded_urls, - default_span_details=_get_default_span_details, - server_request_hook=_InstrumentedFastAPI._server_request_hook, - client_request_hook=_InstrumentedFastAPI._client_request_hook, - client_response_hook=_InstrumentedFastAPI._client_response_hook, - # Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation - tracer=tracer, - meter=meter, - http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request, - http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response, - http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields, - exclude_spans=_InstrumentedFastAPI._exclude_spans, + FastAPIInstrumentor.instrument_app( + self, + server_request_hook=self._server_request_hook, + client_request_hook=self._client_request_hook, + client_response_hook=self._client_response_hook, + tracer_provider=self._tracer_provider, + meter_provider=self._meter_provider, + excluded_urls=self._excluded_urls, + http_capture_headers_server_request=self._http_capture_headers_server_request, + http_capture_headers_server_response=self._http_capture_headers_server_response, + http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields, + exclude_spans=self._exclude_spans, ) - self._is_instrumented_by_opentelemetry = True _InstrumentedFastAPI._instrumented_fastapi_apps.add(self) def __del__(self): diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py index fdbad4effb..3a84bc067c 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py @@ -15,6 +15,7 @@ # pylint: disable=too-many-lines import unittest +from contextlib import ExitStack from timeit import default_timer from unittest.mock import Mock, patch @@ -170,9 +171,14 @@ def setUp(self): self._instrumentor = otel_fastapi.FastAPIInstrumentor() self._app = self._create_app() self._app.add_middleware(HTTPSRedirectMiddleware) - self._client = TestClient(self._app) + self._client = TestClient(self._app, base_url="https://testserver:443") + # run the lifespan, initialize the middleware stack + # this is more in-line with what happens in a real application when the server starts up + self._exit_stack = ExitStack() + self._exit_stack.enter_context(self._client) def tearDown(self): + self._exit_stack.close() super().tearDown() self.env_patch.stop() self.exclude_patch.stop() @@ -205,11 +211,19 @@ async def _(param: str): async def _(): return {"message": "ok"} + @app.get("/error") + async def _(): + raise UnhandledException("This is an unhandled exception") + app.mount("/sub", app=sub_app) return app +class UnhandledException(Exception): + pass + + class TestBaseManualFastAPI(TestBaseFastAPI): @classmethod def setUpClass(cls): @@ -220,6 +234,27 @@ def setUpClass(cls): super(TestBaseManualFastAPI, cls).setUpClass() + def test_fastapi_unhandled_exception(self): + """If the application has an unhandled error the instrumentation should capture that a 500 response is returned.""" + try: + resp = self._client.get("/error") + assert ( + resp.status_code == 500 + ), resp.content # pragma: no cover, for debugging this test if an exception is _not_ raised + except UnhandledException: + pass + else: + self.fail("Expected UnhandledException") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 3) + span = spans[0] + assert span.name == "GET /error http send" + assert span.attributes[SpanAttributes.HTTP_STATUS_CODE] == 500 + span = spans[2] + assert span.name == "GET /error" + assert span.attributes[SpanAttributes.HTTP_TARGET] == "/error" + def test_sub_app_fastapi_call(self): """ This test is to ensure that a span in case of a sub app targeted contains the correct server url @@ -976,6 +1011,10 @@ async def _(param: str): async def _(): return {"message": "ok"} + @app.get("/error") + async def _(): + raise UnhandledException("This is an unhandled exception") + app.mount("/sub", app=sub_app) return app @@ -1124,9 +1163,11 @@ def test_request(self): def test_mulitple_way_instrumentation(self): self._instrumentor.instrument_app(self._app) count = 0 - for middleware in self._app.user_middleware: - if middleware.cls is OpenTelemetryMiddleware: + app = self._app.middleware_stack + while app is not None: + if isinstance(app, OpenTelemetryMiddleware): count += 1 + app = getattr(app, "app", None) self.assertEqual(count, 1) def test_uninstrument_after_instrument(self):