From d91b2412199fc12995ff2b007db80e70cdfd9c86 Mon Sep 17 00:00:00 2001 From: Anton Pirker Date: Fri, 1 Sep 2023 15:31:57 +0200 Subject: [PATCH] Cleanup ASGI integration (#2335) This does not change behaviour/functionality. Some smaller refactoring to make it easier to work on ASGI (and probably Starlette) integration --- sentry_sdk/integrations/_asgi_common.py | 104 ++++++++++++++++++++ sentry_sdk/integrations/asgi.py | 124 +++++++----------------- sentry_sdk/integrations/fastapi.py | 5 +- sentry_sdk/integrations/starlette.py | 4 + tests/integrations/asgi/test_asgi.py | 87 ++++++++++------- 5 files changed, 196 insertions(+), 128 deletions(-) create mode 100644 sentry_sdk/integrations/_asgi_common.py diff --git a/sentry_sdk/integrations/_asgi_common.py b/sentry_sdk/integrations/_asgi_common.py new file mode 100644 index 0000000000..3d14393b03 --- /dev/null +++ b/sentry_sdk/integrations/_asgi_common.py @@ -0,0 +1,104 @@ +import urllib + +from sentry_sdk.hub import _should_send_default_pii +from sentry_sdk.integrations._wsgi_common import _filter_headers +from sentry_sdk._types import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + from typing import Dict + from typing import Optional + from typing_extensions import Literal + + +def _get_headers(asgi_scope): + # type: (Any) -> Dict[str, str] + """ + Extract headers from the ASGI scope, in the format that the Sentry protocol expects. + """ + headers = {} # type: Dict[str, str] + for raw_key, raw_value in asgi_scope["headers"]: + key = raw_key.decode("latin-1") + value = raw_value.decode("latin-1") + if key in headers: + headers[key] = headers[key] + ", " + value + else: + headers[key] = value + + return headers + + +def _get_url(asgi_scope, default_scheme, host): + # type: (Dict[str, Any], Literal["ws", "http"], Optional[str]) -> str + """ + Extract URL from the ASGI scope, without also including the querystring. + """ + scheme = asgi_scope.get("scheme", default_scheme) + + server = asgi_scope.get("server", None) + path = asgi_scope.get("root_path", "") + asgi_scope.get("path", "") + + if host: + return "%s://%s%s" % (scheme, host, path) + + if server is not None: + host, port = server + default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] + if port != default_port: + return "%s://%s:%s%s" % (scheme, host, port, path) + return "%s://%s%s" % (scheme, host, path) + return path + + +def _get_query(asgi_scope): + # type: (Any) -> Any + """ + Extract querystring from the ASGI scope, in the format that the Sentry protocol expects. + """ + qs = asgi_scope.get("query_string") + if not qs: + return None + return urllib.parse.unquote(qs.decode("latin-1")) + + +def _get_ip(asgi_scope): + # type: (Any) -> str + """ + Extract IP Address from the ASGI scope based on request headers with fallback to scope client. + """ + headers = _get_headers(asgi_scope) + try: + return headers["x-forwarded-for"].split(",")[0].strip() + except (KeyError, IndexError): + pass + + try: + return headers["x-real-ip"] + except KeyError: + pass + + return asgi_scope.get("client")[0] + + +def _get_request_data(asgi_scope): + # type: (Any) -> Dict[str, Any] + """ + Returns data related to the HTTP request from the ASGI scope. + """ + request_data = {} # type: Dict[str, Any] + ty = asgi_scope["type"] + if ty in ("http", "websocket"): + request_data["method"] = asgi_scope.get("method") + + request_data["headers"] = headers = _filter_headers(_get_headers(asgi_scope)) + request_data["query_string"] = _get_query(asgi_scope) + + request_data["url"] = _get_url( + asgi_scope, "http" if ty == "http" else "ws", headers.get("host") + ) + + client = asgi_scope.get("client") + if client and _should_send_default_pii(): + request_data["env"] = {"REMOTE_ADDR": _get_ip(asgi_scope)} + + return request_data diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py index 25846cfc6e..b5170d3ab7 100644 --- a/sentry_sdk/integrations/asgi.py +++ b/sentry_sdk/integrations/asgi.py @@ -6,15 +6,18 @@ import asyncio import inspect -import urllib from copy import deepcopy from sentry_sdk._functools import partial from sentry_sdk._types import TYPE_CHECKING from sentry_sdk.api import continue_trace from sentry_sdk.consts import OP -from sentry_sdk.hub import Hub, _should_send_default_pii -from sentry_sdk.integrations._wsgi_common import _filter_headers +from sentry_sdk.hub import Hub + +from sentry_sdk.integrations._asgi_common import ( + _get_headers, + _get_request_data, +) from sentry_sdk.integrations.modules import _get_installed_modules from sentry_sdk.sessions import auto_session_tracking from sentry_sdk.tracing import ( @@ -37,8 +40,6 @@ from typing import Optional from typing import Callable - from typing_extensions import Literal - from sentry_sdk._types import Event, Hint @@ -169,19 +170,32 @@ async def _run_app(self, scope, receive, send, asgi_version): if ty in ("http", "websocket"): transaction = continue_trace( - self._get_headers(scope), + _get_headers(scope), op="{}.server".format(ty), ) + logger.debug( + "[ASGI] Created transaction (continuing trace): %s", + transaction, + ) else: transaction = Transaction(op=OP.HTTP_SERVER) + logger.debug( + "[ASGI] Created transaction (new): %s", transaction + ) transaction.name = _DEFAULT_TRANSACTION_NAME transaction.source = TRANSACTION_SOURCE_ROUTE transaction.set_tag("asgi.type", ty) + logger.debug( + "[ASGI] Set transaction name and source on transaction: '%s' / '%s'", + transaction.name, + transaction.source, + ) with hub.start_transaction( transaction, custom_sampling_context={"asgi_scope": scope} ): + logger.debug("[ASGI] Started transaction: %s", transaction) try: async def _sentry_wrapped_send(event): @@ -214,31 +228,15 @@ async def _sentry_wrapped_send(event): def event_processor(self, event, hint, asgi_scope): # type: (Event, Hint, Any) -> Optional[Event] - request_info = event.get("request", {}) - - ty = asgi_scope["type"] - if ty in ("http", "websocket"): - request_info["method"] = asgi_scope.get("method") - request_info["headers"] = headers = _filter_headers( - self._get_headers(asgi_scope) - ) - request_info["query_string"] = self._get_query(asgi_scope) - - request_info["url"] = self._get_url( - asgi_scope, "http" if ty == "http" else "ws", headers.get("host") - ) - - client = asgi_scope.get("client") - if client and _should_send_default_pii(): - request_info["env"] = {"REMOTE_ADDR": self._get_ip(asgi_scope)} + request_data = event.get("request", {}) + request_data.update(_get_request_data(asgi_scope)) + event["request"] = deepcopy(request_data) self._set_transaction_name_and_source(event, self.transaction_style, asgi_scope) - event["request"] = deepcopy(request_info) - return event - # Helper functions for extracting request data. + # Helper functions. # # Note: Those functions are not public API. If you want to mutate request # data to your liking it's recommended to use the `before_send` callback @@ -275,71 +273,17 @@ def _set_transaction_name_and_source(self, event, transaction_style, asgi_scope) if not name: event["transaction"] = _DEFAULT_TRANSACTION_NAME event["transaction_info"] = {"source": TRANSACTION_SOURCE_ROUTE} + logger.debug( + "[ASGI] Set default transaction name and source on event: '%s' / '%s'", + event["transaction"], + event["transaction_info"]["source"], + ) return event["transaction"] = name event["transaction_info"] = {"source": SOURCE_FOR_STYLE[transaction_style]} - - def _get_url(self, scope, default_scheme, host): - # type: (Dict[str, Any], Literal["ws", "http"], Optional[str]) -> str - """ - Extract URL from the ASGI scope, without also including the querystring. - """ - scheme = scope.get("scheme", default_scheme) - - server = scope.get("server", None) - path = scope.get("root_path", "") + scope.get("path", "") - - if host: - return "%s://%s%s" % (scheme, host, path) - - if server is not None: - host, port = server - default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] - if port != default_port: - return "%s://%s:%s%s" % (scheme, host, port, path) - return "%s://%s%s" % (scheme, host, path) - return path - - def _get_query(self, scope): - # type: (Any) -> Any - """ - Extract querystring from the ASGI scope, in the format that the Sentry protocol expects. - """ - qs = scope.get("query_string") - if not qs: - return None - return urllib.parse.unquote(qs.decode("latin-1")) - - def _get_ip(self, scope): - # type: (Any) -> str - """ - Extract IP Address from the ASGI scope based on request headers with fallback to scope client. - """ - headers = self._get_headers(scope) - try: - return headers["x-forwarded-for"].split(",")[0].strip() - except (KeyError, IndexError): - pass - - try: - return headers["x-real-ip"] - except KeyError: - pass - - return scope.get("client")[0] - - def _get_headers(self, scope): - # type: (Any) -> Dict[str, str] - """ - Extract headers from the ASGI scope, in the format that the Sentry protocol expects. - """ - headers = {} # type: Dict[str, str] - for raw_key, raw_value in scope["headers"]: - key = raw_key.decode("latin-1") - value = raw_value.decode("latin-1") - if key in headers: - headers[key] = headers[key] + ", " + value - else: - headers[key] = value - return headers + logger.debug( + "[ASGI] Set transaction name and source on event: '%s' / '%s'", + event["transaction"], + event["transaction_info"]["source"], + ) diff --git a/sentry_sdk/integrations/fastapi.py b/sentry_sdk/integrations/fastapi.py index 17e0576c18..11c9bdcf51 100644 --- a/sentry_sdk/integrations/fastapi.py +++ b/sentry_sdk/integrations/fastapi.py @@ -5,7 +5,7 @@ from sentry_sdk.hub import Hub, _should_send_default_pii from sentry_sdk.integrations import DidNotEnable from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE -from sentry_sdk.utils import transaction_from_function +from sentry_sdk.utils import transaction_from_function, logger if TYPE_CHECKING: from typing import Any, Callable, Dict @@ -60,6 +60,9 @@ def _set_transaction_name_and_source(scope, transaction_style, request): source = SOURCE_FOR_STYLE[transaction_style] scope.set_transaction_name(name, source=source) + logger.debug( + "[FastAPI] Set transaction name and source on scope: %s / %s", name, source + ) def patch_get_request_handler(): diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index b44e8f10b7..1e3944aff3 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -19,6 +19,7 @@ AnnotatedValue, capture_internal_exceptions, event_from_exception, + logger, parse_version, transaction_from_function, ) @@ -648,3 +649,6 @@ def _set_transaction_name_and_source(scope, transaction_style, request): source = SOURCE_FOR_STYLE[transaction_style] scope.set_transaction_name(name, source=source) + logger.debug( + "[Starlette] Set transaction name and source on scope: %s / %s", name, source + ) diff --git a/tests/integrations/asgi/test_asgi.py b/tests/integrations/asgi/test_asgi.py index dcd770ac37..29aab5783a 100644 --- a/tests/integrations/asgi/test_asgi.py +++ b/tests/integrations/asgi/test_asgi.py @@ -5,6 +5,7 @@ import pytest import sentry_sdk from sentry_sdk import capture_message +from sentry_sdk.integrations._asgi_common import _get_ip, _get_headers from sentry_sdk.integrations.asgi import SentryAsgiMiddleware, _looks_like_asgi3 async_asgi_testclient = pytest.importorskip("async_asgi_testclient") @@ -19,7 +20,15 @@ @pytest.fixture def asgi3_app(): async def app(scope, receive, send): - if ( + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + elif ( scope["type"] == "http" and "route" in scope and scope["route"] == "/trigger/error" @@ -52,21 +61,32 @@ async def send_with_error(event): 1 / 0 async def app(scope, receive, send): - await send_with_error( - { - "type": "http.response.start", - "status": 200, - "headers": [ - [b"content-type", b"text/plain"], - ], - } - ) - await send_with_error( - { - "type": "http.response.body", - "body": b"Hello, world!", - } - ) + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + ... # Do some startup here! + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + ... # Do some shutdown here! + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await send_with_error( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain"], + ], + } + ) + await send_with_error( + { + "type": "http.response.body", + "body": b"Hello, world!", + } + ) return app @@ -139,10 +159,11 @@ async def test_capture_transaction( events = capture_events() await client.get("/?somevalue=123") - (transaction_event,) = events + (transaction_event, lifespan_transaction_event) = events assert transaction_event["type"] == "transaction" assert transaction_event["transaction"] == "generic ASGI request" + assert transaction_event["transaction_info"] == {"source": "route"} assert transaction_event["contexts"]["trace"]["op"] == "http.server" assert transaction_event["request"] == { "headers": { @@ -172,9 +193,10 @@ async def test_capture_transaction_with_error( async with TestClient(app) as client: await client.get("/") - (error_event, transaction_event) = events + (error_event, transaction_event, lifespan_transaction_event) = events assert error_event["transaction"] == "generic ASGI request" + assert error_event["transaction_info"] == {"source": "route"} assert error_event["contexts"]["trace"]["op"] == "http.server" assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError" assert error_event["exception"]["values"][0]["value"] == "division by zero" @@ -423,7 +445,7 @@ async def test_transaction_style( events = capture_events() await client.get(url) - (transaction_event,) = events + (transaction_event, lifespan_transaction_event) = events assert transaction_event["transaction"] == expected_transaction assert transaction_event["transaction_info"] == {"source": expected_source} @@ -472,8 +494,7 @@ def test_get_ip_x_forwarded_for(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" # x-forwarded-for overrides x-real-ip @@ -485,8 +506,7 @@ def test_get_ip_x_forwarded_for(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" # when multiple x-forwarded-for headers are, the first is taken @@ -499,8 +519,7 @@ def test_get_ip_x_forwarded_for(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "5.5.5.5" @@ -513,8 +532,7 @@ def test_get_ip_x_real_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "10.10.10.10" # x-forwarded-for overrides x-real-ip @@ -526,8 +544,7 @@ def test_get_ip_x_real_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" @@ -539,8 +556,7 @@ def test_get_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "127.0.0.1" # x-forwarded-for header overides the ip from client @@ -551,8 +567,7 @@ def test_get_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" # x-real-for header overides the ip from client @@ -563,8 +578,7 @@ def test_get_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "10.10.10.10" @@ -579,8 +593,7 @@ def test_get_headers(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - headers = middleware._get_headers(scope) + headers = _get_headers(scope) assert headers == { "x-real-ip": "10.10.10.10", "some_header": "123, abc",