-
Notifications
You must be signed in to change notification settings - Fork 626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fastapi: fix wrapping of middlewares #3012
base: main
Are you sure you want to change the base?
Changes from 20 commits
707b192
dcfab56
a78b1c6
e25e143
54d4482
81f1e83
96e66b9
46e1809
daa5a1b
de9e795
f9d348b
7bc89b3
f629932
437891b
1340ac7
003d7af
be26393
ec545b7
f825c76
d909633
0de3e37
d57b673
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -179,10 +179,14 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A | |
from __future__ import annotations | ||
|
||
import logging | ||
import types | ||
from typing import Collection, Literal | ||
|
||
import fastapi | ||
from starlette.applications import Starlette | ||
from starlette.middleware.errors import ServerErrorMiddleware | ||
from starlette.routing import Match | ||
from starlette.types import ASGIApp | ||
|
||
from opentelemetry.instrumentation._semconv import ( | ||
_get_schema_url, | ||
|
@@ -199,9 +203,9 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A | |
from opentelemetry.instrumentation.fastapi.package import _instruments | ||
from opentelemetry.instrumentation.fastapi.version import __version__ | ||
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor | ||
from opentelemetry.metrics import get_meter | ||
from opentelemetry.metrics import MeterProvider, get_meter | ||
from opentelemetry.semconv.trace import SpanAttributes | ||
from opentelemetry.trace import get_tracer | ||
from opentelemetry.trace import TracerProvider, get_tracer | ||
from opentelemetry.util.http import ( | ||
get_excluded_urls, | ||
parse_excluded_urls, | ||
|
@@ -226,9 +230,9 @@ def instrument_app( | |
server_request_hook: ServerRequestHook = None, | ||
client_request_hook: ClientRequestHook = None, | ||
client_response_hook: ClientResponseHook = None, | ||
tracer_provider=None, | ||
meter_provider=None, | ||
excluded_urls=None, | ||
tracer_provider: TracerProvider | None = None, | ||
meter_provider: MeterProvider | None = None, | ||
excluded_urls: str | None = None, | ||
http_capture_headers_server_request: list[str] | None = None, | ||
http_capture_headers_server_response: list[str] | None = None, | ||
http_capture_headers_sanitize_fields: list[str] | None = None, | ||
|
@@ -280,21 +284,40 @@ def instrument_app( | |
schema_url=_get_schema_url(sem_conv_opt_in_mode), | ||
) | ||
|
||
app.add_middleware( | ||
OpenTelemetryMiddleware, | ||
excluded_urls=excluded_urls, | ||
default_span_details=_get_default_span_details, | ||
server_request_hook=server_request_hook, | ||
client_request_hook=client_request_hook, | ||
client_response_hook=client_response_hook, | ||
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation | ||
tracer=tracer, | ||
meter=meter, | ||
http_capture_headers_server_request=http_capture_headers_server_request, | ||
http_capture_headers_server_response=http_capture_headers_server_response, | ||
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, | ||
exclude_spans=exclude_spans, | ||
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware | ||
# as the outermost middleware. | ||
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able | ||
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do. | ||
|
||
def build_middleware_stack(self: Starlette) -> ASGIApp: | ||
app = type(self).build_middleware_stack(self) | ||
app = OpenTelemetryMiddleware( | ||
app, | ||
excluded_urls=excluded_urls, | ||
default_span_details=_get_default_span_details, | ||
server_request_hook=server_request_hook, | ||
client_request_hook=client_request_hook, | ||
client_response_hook=client_response_hook, | ||
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation | ||
tracer=tracer, | ||
meter=meter, | ||
http_capture_headers_server_request=http_capture_headers_server_request, | ||
http_capture_headers_server_response=http_capture_headers_server_response, | ||
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, | ||
exclude_spans=exclude_spans, | ||
) | ||
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware | ||
# are handled. | ||
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that | ||
# to impact the user's application just because we wrapped the middlewares in this order. | ||
app = ServerErrorMiddleware(app) | ||
return app | ||
|
||
app._original_build_middleware_stack = app.build_middleware_stack | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: do you think it would be possible to use wrapt for monkeypatching instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather not. This is simpler and works just fine. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xrmx genuine question: what would be the benefit of that in this case? |
||
app.build_middleware_stack = types.MethodType( | ||
build_middleware_stack, app | ||
) | ||
|
||
app._is_instrumented_by_opentelemetry = True | ||
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps: | ||
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app) | ||
|
@@ -305,11 +328,12 @@ def instrument_app( | |
|
||
@staticmethod | ||
def uninstrument_app(app: fastapi.FastAPI): | ||
app.user_middleware = [ | ||
x | ||
for x in app.user_middleware | ||
if x.cls is not OpenTelemetryMiddleware | ||
] | ||
original_build_middleware_stack = getattr( | ||
app, "_original_build_middleware_stack", None | ||
) | ||
if original_build_middleware_stack: | ||
app.build_middleware_stack = original_build_middleware_stack | ||
del app._original_build_middleware_stack | ||
app.middleware_stack = app.build_middleware_stack() | ||
app._is_instrumented_by_opentelemetry = False | ||
|
||
|
@@ -337,12 +361,7 @@ def _instrument(self, **kwargs): | |
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = ( | ||
kwargs.get("http_capture_headers_sanitize_fields") | ||
) | ||
_excluded_urls = kwargs.get("excluded_urls") | ||
_InstrumentedFastAPI._excluded_urls = ( | ||
_excluded_urls_from_env | ||
if _excluded_urls is None | ||
else parse_excluded_urls(_excluded_urls) | ||
) | ||
_InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls") | ||
_InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider") | ||
_InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans") | ||
fastapi.FastAPI = _InstrumentedFastAPI | ||
|
@@ -361,43 +380,29 @@ class _InstrumentedFastAPI(fastapi.FastAPI): | |
_server_request_hook: ServerRequestHook = None | ||
_client_request_hook: ClientRequestHook = None | ||
_client_response_hook: ClientResponseHook = None | ||
_http_capture_headers_server_request: list[str] | None = None | ||
_http_capture_headers_server_response: list[str] | None = None | ||
_http_capture_headers_sanitize_fields: list[str] | None = None | ||
_exclude_spans: list[Literal["receive", "send"]] | None = None | ||
|
||
_instrumented_fastapi_apps = set() | ||
_sem_conv_opt_in_mode = _HTTPStabilityMode.DEFAULT | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
tracer = get_tracer( | ||
__name__, | ||
__version__, | ||
_InstrumentedFastAPI._tracer_provider, | ||
schema_url=_get_schema_url( | ||
_InstrumentedFastAPI._sem_conv_opt_in_mode | ||
), | ||
) | ||
meter = get_meter( | ||
__name__, | ||
__version__, | ||
_InstrumentedFastAPI._meter_provider, | ||
schema_url=_get_schema_url( | ||
_InstrumentedFastAPI._sem_conv_opt_in_mode | ||
), | ||
) | ||
self.add_middleware( | ||
OpenTelemetryMiddleware, | ||
excluded_urls=_InstrumentedFastAPI._excluded_urls, | ||
default_span_details=_get_default_span_details, | ||
server_request_hook=_InstrumentedFastAPI._server_request_hook, | ||
client_request_hook=_InstrumentedFastAPI._client_request_hook, | ||
client_response_hook=_InstrumentedFastAPI._client_response_hook, | ||
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation | ||
tracer=tracer, | ||
meter=meter, | ||
http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request, | ||
http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response, | ||
http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields, | ||
exclude_spans=_InstrumentedFastAPI._exclude_spans, | ||
FastAPIInstrumentor.instrument_app( | ||
self, | ||
server_request_hook=self._server_request_hook, | ||
client_request_hook=self._client_request_hook, | ||
client_response_hook=self._client_response_hook, | ||
tracer_provider=self._tracer_provider, | ||
meter_provider=self._meter_provider, | ||
excluded_urls=self._excluded_urls, | ||
http_capture_headers_server_request=self._http_capture_headers_server_request, | ||
http_capture_headers_server_response=self._http_capture_headers_server_response, | ||
http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields, | ||
exclude_spans=self._exclude_spans, | ||
) | ||
self._is_instrumented_by_opentelemetry = True | ||
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self) | ||
|
||
def __del__(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
# pylint: disable=too-many-lines | ||
|
||
import unittest | ||
from contextlib import ExitStack | ||
from timeit import default_timer | ||
from unittest.mock import Mock, patch | ||
|
||
|
@@ -170,9 +171,14 @@ def setUp(self): | |
self._instrumentor = otel_fastapi.FastAPIInstrumentor() | ||
self._app = self._create_app() | ||
self._app.add_middleware(HTTPSRedirectMiddleware) | ||
self._client = TestClient(self._app) | ||
self._client = TestClient(self._app, base_url="https://testserver:443") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was the root cause of tests failures. It turns out every request was being made twice because it was getting redirected to https. Before this PR that wasn't being instrumented correctly, so this was not being caught! I think that's another major bug this PR is fixing. |
||
# run the lifespan, initialize the middleware stack | ||
# this is more in-line with what happens in a real application when the server starts up | ||
self._exit_stack = ExitStack() | ||
self._exit_stack.enter_context(self._client) | ||
Comment on lines
+175
to
+178
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
def tearDown(self): | ||
self._exit_stack.close() | ||
super().tearDown() | ||
self.env_patch.stop() | ||
self.exclude_patch.stop() | ||
|
@@ -205,11 +211,19 @@ async def _(param: str): | |
async def _(): | ||
return {"message": "ok"} | ||
|
||
@app.get("/error") | ||
async def _(): | ||
raise UnhandledException("This is an unhandled exception") | ||
|
||
app.mount("/sub", app=sub_app) | ||
|
||
return app | ||
|
||
|
||
class UnhandledException(Exception): | ||
pass | ||
|
||
|
||
class TestBaseManualFastAPI(TestBaseFastAPI): | ||
@classmethod | ||
def setUpClass(cls): | ||
|
@@ -398,6 +412,26 @@ def test_fastapi_excluded_urls(self): | |
spans = self.memory_exporter.get_finished_spans() | ||
self.assertEqual(len(spans), 0) | ||
|
||
def test_fastapi_unhandled_exception(self): | ||
"""If the application has an unhandled error the instrumentation should capture that a 500 response is returned.""" | ||
try: | ||
self._client.get("/error") | ||
except UnhandledException: | ||
pass | ||
else: | ||
self.fail("Expected UnhandledException") | ||
|
||
spans = self.memory_exporter.get_finished_spans() | ||
self.assertEqual(len(spans), 3) | ||
for span in spans: | ||
self.assertIn("GET /error", span.name) | ||
self.assertEqual( | ||
span.attributes[SpanAttributes.HTTP_ROUTE], "/error" | ||
) | ||
self.assertEqual( | ||
span.attributes[SpanAttributes.HTTP_STATUS_CODE], 500 | ||
) | ||
|
||
def test_fastapi_excluded_urls_not_env(self): | ||
"""Ensure that given fastapi routes are excluded when passed explicitly (not in the environment)""" | ||
app = self._create_app_explicit_excluded_urls() | ||
|
@@ -1124,9 +1158,11 @@ def test_request(self): | |
def test_mulitple_way_instrumentation(self): | ||
self._instrumentor.instrument_app(self._app) | ||
count = 0 | ||
for middleware in self._app.user_middleware: | ||
if middleware.cls is OpenTelemetryMiddleware: | ||
app = self._app.middleware_stack | ||
while app is not None: | ||
if isinstance(app, OpenTelemetryMiddleware): | ||
count += 1 | ||
app = getattr(app, "app", None) | ||
self.assertEqual(count, 1) | ||
|
||
def test_uninstrument_after_instrument(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Could you use
typing.Optional
to be consistent with most other places in the code?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other places are out dated.
What should happen is to start switching the other places.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hard agreement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For context the
from __future__ import annotations
let us use this syntax