diff --git a/sentry_sdk/integrations/_asgi_common.py b/sentry_sdk/integrations/_asgi_common.py new file mode 100644 index 0000000000..3d14393b03 --- /dev/null +++ b/sentry_sdk/integrations/_asgi_common.py @@ -0,0 +1,104 @@ +import urllib + +from sentry_sdk.hub import _should_send_default_pii +from sentry_sdk.integrations._wsgi_common import _filter_headers +from sentry_sdk._types import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + from typing import Dict + from typing import Optional + from typing_extensions import Literal + + +def _get_headers(asgi_scope): + # type: (Any) -> Dict[str, str] + """ + Extract headers from the ASGI scope, in the format that the Sentry protocol expects. + """ + headers = {} # type: Dict[str, str] + for raw_key, raw_value in asgi_scope["headers"]: + key = raw_key.decode("latin-1") + value = raw_value.decode("latin-1") + if key in headers: + headers[key] = headers[key] + ", " + value + else: + headers[key] = value + + return headers + + +def _get_url(asgi_scope, default_scheme, host): + # type: (Dict[str, Any], Literal["ws", "http"], Optional[str]) -> str + """ + Extract URL from the ASGI scope, without also including the querystring. + """ + scheme = asgi_scope.get("scheme", default_scheme) + + server = asgi_scope.get("server", None) + path = asgi_scope.get("root_path", "") + asgi_scope.get("path", "") + + if host: + return "%s://%s%s" % (scheme, host, path) + + if server is not None: + host, port = server + default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] + if port != default_port: + return "%s://%s:%s%s" % (scheme, host, port, path) + return "%s://%s%s" % (scheme, host, path) + return path + + +def _get_query(asgi_scope): + # type: (Any) -> Any + """ + Extract querystring from the ASGI scope, in the format that the Sentry protocol expects. + """ + qs = asgi_scope.get("query_string") + if not qs: + return None + return urllib.parse.unquote(qs.decode("latin-1")) + + +def _get_ip(asgi_scope): + # type: (Any) -> str + """ + Extract IP Address from the ASGI scope based on request headers with fallback to scope client. + """ + headers = _get_headers(asgi_scope) + try: + return headers["x-forwarded-for"].split(",")[0].strip() + except (KeyError, IndexError): + pass + + try: + return headers["x-real-ip"] + except KeyError: + pass + + return asgi_scope.get("client")[0] + + +def _get_request_data(asgi_scope): + # type: (Any) -> Dict[str, Any] + """ + Returns data related to the HTTP request from the ASGI scope. + """ + request_data = {} # type: Dict[str, Any] + ty = asgi_scope["type"] + if ty in ("http", "websocket"): + request_data["method"] = asgi_scope.get("method") + + request_data["headers"] = headers = _filter_headers(_get_headers(asgi_scope)) + request_data["query_string"] = _get_query(asgi_scope) + + request_data["url"] = _get_url( + asgi_scope, "http" if ty == "http" else "ws", headers.get("host") + ) + + client = asgi_scope.get("client") + if client and _should_send_default_pii(): + request_data["env"] = {"REMOTE_ADDR": _get_ip(asgi_scope)} + + return request_data diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py index 25846cfc6e..dedebfcddb 100644 --- a/sentry_sdk/integrations/asgi.py +++ b/sentry_sdk/integrations/asgi.py @@ -6,20 +6,25 @@ import asyncio import inspect -import urllib +import logging from copy import deepcopy from sentry_sdk._functools import partial from sentry_sdk._types import TYPE_CHECKING from sentry_sdk.api import continue_trace from sentry_sdk.consts import OP -from sentry_sdk.hub import Hub, _should_send_default_pii -from sentry_sdk.integrations._wsgi_common import _filter_headers +from sentry_sdk.hub import Hub +from sentry_sdk.integrations._asgi_common import ( + _get_headers, + _get_request_data, + _get_url, +) from sentry_sdk.integrations.modules import _get_installed_modules from sentry_sdk.sessions import auto_session_tracking from sentry_sdk.tracing import ( SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE, + TRANSACTION_SOURCE_URL, ) from sentry_sdk.utils import ( ContextVar, @@ -37,8 +42,6 @@ from typing import Optional from typing import Callable - from typing_extensions import Literal - from sentry_sdk._types import Event, Hint @@ -168,17 +171,35 @@ async def _run_app(self, scope, receive, send, asgi_version): ty = scope["type"] if ty in ("http", "websocket"): + ( + transaction_name, + transaction_source, + ) = self._get_transaction_name_and_source( + self.transaction_style, scope + ) transaction = continue_trace( - self._get_headers(scope), + _get_headers(scope), op="{}.server".format(ty), + name=transaction_name, + source=transaction_source, + ) + logging.warning( + "[ASGI] Created Transaction %s, %s, %s", + transaction, + transaction.sampled, + transaction_name, ) else: transaction = Transaction(op=OP.HTTP_SERVER) - transaction.name = _DEFAULT_TRANSACTION_NAME - transaction.source = TRANSACTION_SOURCE_ROUTE transaction.set_tag("asgi.type", ty) + logging.warning( + "[ASGI] Starting Transaction %s, %s, %s", + transaction, + transaction.sampled, + transaction.name, + ) with hub.start_transaction( transaction, custom_sampling_context={"asgi_scope": scope} ): @@ -214,46 +235,22 @@ async def _sentry_wrapped_send(event): def event_processor(self, event, hint, asgi_scope): # type: (Event, Hint, Any) -> Optional[Event] - request_info = event.get("request", {}) - - ty = asgi_scope["type"] - if ty in ("http", "websocket"): - request_info["method"] = asgi_scope.get("method") - request_info["headers"] = headers = _filter_headers( - self._get_headers(asgi_scope) - ) - request_info["query_string"] = self._get_query(asgi_scope) - - request_info["url"] = self._get_url( - asgi_scope, "http" if ty == "http" else "ws", headers.get("host") - ) - - client = asgi_scope.get("client") - if client and _should_send_default_pii(): - request_info["env"] = {"REMOTE_ADDR": self._get_ip(asgi_scope)} - - self._set_transaction_name_and_source(event, self.transaction_style, asgi_scope) + request_data = event.get("request", {}) + request_data.update(_get_request_data(asgi_scope)) + event["request"] = deepcopy(request_data) - event["request"] = deepcopy(request_info) + transaction_name, transaction_source = self._get_transaction_name_and_source( + self.transaction_style, asgi_scope + ) + event["transaction"] = transaction_name + event["transaction_info"] = {"source": transaction_source} return event - # Helper functions for extracting request data. - # - # Note: Those functions are not public API. If you want to mutate request - # data to your liking it's recommended to use the `before_send` callback - # for that. - - def _set_transaction_name_and_source(self, event, transaction_style, asgi_scope): + def _get_transaction_name_and_source(self, transaction_style, asgi_scope): # type: (Event, str, Any) -> None - transaction_name_already_set = ( - event.get("transaction", _DEFAULT_TRANSACTION_NAME) - != _DEFAULT_TRANSACTION_NAME - ) - if transaction_name_already_set: - return - - name = "" + name = None + source = None if transaction_style == "endpoint": endpoint = asgi_scope.get("endpoint") @@ -262,84 +259,36 @@ def _set_transaction_name_and_source(self, event, transaction_style, asgi_scope) # an endpoint, overwrite our generic transaction name. if endpoint: name = transaction_from_function(endpoint) or "" + source = SOURCE_FOR_STYLE[transaction_style] + else: + ty = asgi_scope.get("type") + if ty in ("http", "websocket"): + name = _get_url( + asgi_scope, "http" if ty == "http" else "ws", host=None + ) + source = TRANSACTION_SOURCE_URL + else: + name = _DEFAULT_TRANSACTION_NAME + source = TRANSACTION_SOURCE_ROUTE elif transaction_style == "url": - # FastAPI includes the route object in the scope to let Sentry extract the - # path from it for the transaction name route = asgi_scope.get("route") if route: - path = getattr(route, "path", None) - if path is not None: - name = path - - if not name: - event["transaction"] = _DEFAULT_TRANSACTION_NAME - event["transaction_info"] = {"source": TRANSACTION_SOURCE_ROUTE} - return - - event["transaction"] = name - event["transaction_info"] = {"source": SOURCE_FOR_STYLE[transaction_style]} - - def _get_url(self, scope, default_scheme, host): - # type: (Dict[str, Any], Literal["ws", "http"], Optional[str]) -> str - """ - Extract URL from the ASGI scope, without also including the querystring. - """ - scheme = scope.get("scheme", default_scheme) - - server = scope.get("server", None) - path = scope.get("root_path", "") + scope.get("path", "") - - if host: - return "%s://%s%s" % (scheme, host, path) - - if server is not None: - host, port = server - default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] - if port != default_port: - return "%s://%s:%s%s" % (scheme, host, port, path) - return "%s://%s%s" % (scheme, host, path) - return path - - def _get_query(self, scope): - # type: (Any) -> Any - """ - Extract querystring from the ASGI scope, in the format that the Sentry protocol expects. - """ - qs = scope.get("query_string") - if not qs: - return None - return urllib.parse.unquote(qs.decode("latin-1")) - - def _get_ip(self, scope): - # type: (Any) -> str - """ - Extract IP Address from the ASGI scope based on request headers with fallback to scope client. - """ - headers = self._get_headers(scope) - try: - return headers["x-forwarded-for"].split(",")[0].strip() - except (KeyError, IndexError): - pass - - try: - return headers["x-real-ip"] - except KeyError: - pass + name = route + source = SOURCE_FOR_STYLE[transaction_style] + else: + ty = asgi_scope.get("type") + if ty in ("http", "websocket"): + name = _get_url( + asgi_scope, "http" if ty == "http" else "ws", host=None + ) + source = TRANSACTION_SOURCE_URL + else: + name = _DEFAULT_TRANSACTION_NAME + source = TRANSACTION_SOURCE_ROUTE - return scope.get("client")[0] + if name is None: + name = _DEFAULT_TRANSACTION_NAME + source = TRANSACTION_SOURCE_ROUTE - def _get_headers(self, scope): - # type: (Any) -> Dict[str, str] - """ - Extract headers from the ASGI scope, in the format that the Sentry protocol expects. - """ - headers = {} # type: Dict[str, str] - for raw_key, raw_value in scope["headers"]: - key = raw_key.decode("latin-1") - value = raw_value.decode("latin-1") - if key in headers: - headers[key] = headers[key] + ", " + value - else: - headers[key] = value - return headers + return name, source diff --git a/sentry_sdk/integrations/fastapi.py b/sentry_sdk/integrations/fastapi.py index 17e0576c18..43e40aaa7f 100644 --- a/sentry_sdk/integrations/fastapi.py +++ b/sentry_sdk/integrations/fastapi.py @@ -37,7 +37,7 @@ def setup_once(): patch_get_request_handler() -def _set_transaction_name_and_source(scope, transaction_style, request): +def _get_transaction_name_and_source(scope, transaction_style, request): # type: (Scope, str, Any) -> None name = "" @@ -59,7 +59,7 @@ def _set_transaction_name_and_source(scope, transaction_style, request): else: source = SOURCE_FOR_STYLE[transaction_style] - scope.set_transaction_name(name, source=source) + return name, source def patch_get_request_handler(): @@ -97,10 +97,12 @@ async def _sentry_app(*args, **kwargs): with hub.configure_scope() as sentry_scope: request = args[0] - - _set_transaction_name_and_source( + transaction_name, transaction_source = _get_transaction_name_and_source( sentry_scope, integration.transaction_style, request ) + sentry_scope.set_transaction_name( + transaction_name, source=transaction_source + ) extractor = StarletteRequestExtractor(request) info = await extractor.extract_request_info() diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index b44e8f10b7..23529e089f 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -3,18 +3,25 @@ import asyncio import functools from copy import deepcopy +import logging from sentry_sdk._compat import iteritems from sentry_sdk._types import TYPE_CHECKING +from sentry_sdk.api import continue_trace from sentry_sdk.consts import OP from sentry_sdk.hub import Hub, _should_send_default_pii from sentry_sdk.integrations import DidNotEnable, Integration +from sentry_sdk.integrations._asgi_common import _get_headers, _get_request_data from sentry_sdk.integrations._wsgi_common import ( _is_json_content_type, request_body_within_bounds, ) -from sentry_sdk.integrations.asgi import SentryAsgiMiddleware -from sentry_sdk.tracing import SOURCE_FOR_STYLE, TRANSACTION_SOURCE_ROUTE + +from sentry_sdk.tracing import ( + SOURCE_FOR_STYLE, + TRANSACTION_SOURCE_ROUTE, + TRANSACTION_SOURCE_URL, +) from sentry_sdk.utils import ( AnnotatedValue, capture_internal_exceptions, @@ -26,8 +33,6 @@ if TYPE_CHECKING: from typing import Any, Awaitable, Callable, Dict, Optional - from sentry_sdk.scope import Scope as SentryScope - try: import starlette # type: ignore from starlette import __version__ as STARLETTE_VERSION @@ -39,7 +44,7 @@ ) from starlette.requests import Request # type: ignore from starlette.routing import Match # type: ignore - from starlette.types import ASGIApp, Receive, Scope as StarletteScope, Send # type: ignore + from starlette.types import ASGIApp # type: ignore except ImportError: raise DidNotEnable("Starlette is not installed") @@ -86,14 +91,67 @@ def setup_once(): "Unparsable Starlette version: {}".format(STARLETTE_VERSION) ) + patch_starlette_application() patch_middlewares() - patch_asgi_app() patch_request_response() if version >= (0, 24): patch_templates() +def patch_starlette_application(): + # type: () -> None + def wrap_call_function(func): + # type: (Callable[..., Any]) -> Callable[..., Any] + @functools.wraps(func) + async def __sentry_call__(self, scope, receive, send): + # type: (Starlette, Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None + hub = Hub.current + + integration = hub.get_integration(StarletteIntegration) + if integration is None: + await func(self, scope, receive, send) + + request = Request(scope, receive) + + transaction_name, transaction_source = _get_transaction_name_and_source( + integration.transaction_style, request + ) + + ty = scope["type"] + transaction = continue_trace( + _get_headers(scope), + op="{}.server".format(ty), + name=transaction_name, + source=transaction_source, + ) + logging.warning( + "[Starlette] Created Transaction %s, %s, %s", + transaction, + transaction.sampled, + transaction_name, + ) + + logging.warning( + "[Starlette] Starting Transaction %s, %s, %s", + transaction, + transaction.sampled, + transaction_name, + ) + with hub.start_transaction( + transaction, custom_sampling_context={"asgi_scope": scope} + ): + try: + await func(self, scope, receive, send) + except Exception as exc: + _capture_exception(exc, handled=False) + raise exc from None + + return __sentry_call__ + + Starlette.__call__ = wrap_call_function(Starlette.__call__) + + def _enable_span_for_middleware(middleware_class): # type: (Any) -> type old_call = middleware_class.__call__ @@ -312,9 +370,8 @@ def patch_middlewares(): def _sentry_middleware_init(self, cls, **options): # type: (Any, Any, Any) -> None - if cls == SentryAsgiMiddleware: - return old_middleware_init(self, cls, **options) + # TODO: check if we should return the return value of old_middleware_init at the end of this function? span_enabled_cls = _enable_span_for_middleware(cls) old_middleware_init(self, span_enabled_cls, **options) @@ -327,29 +384,6 @@ def _sentry_middleware_init(self, cls, **options): Middleware.__init__ = _sentry_middleware_init -def patch_asgi_app(): - # type: () -> None - """ - Instrument Starlette ASGI app using the SentryAsgiMiddleware. - """ - old_app = Starlette.__call__ - - async def _sentry_patched_asgi_app(self, scope, receive, send): - # type: (Starlette, StarletteScope, Receive, Send) -> None - if Hub.current.get_integration(StarletteIntegration) is None: - return await old_app(self, scope, receive, send) - - middleware = SentryAsgiMiddleware( - lambda *a, **kw: old_app(self, *a, **kw), - mechanism_type=StarletteIntegration.identifier, - ) - - middleware.__call__ = middleware._run_asgi3 - return await middleware(scope, receive, send) - - Starlette.__call__ = _sentry_patched_asgi_app - - # This was vendored in from Starlette to support Starlette 0.19.1 because # this function was only introduced in 0.20.x def _is_async_callable(obj): @@ -382,11 +416,22 @@ async def _sentry_async_func(*args, **kwargs): with hub.configure_scope() as sentry_scope: request = args[0] + ( + transaction_name, + transaction_source, + ) = _get_transaction_name_and_source( + integration.transaction_style, request + ) + logging.warning( + "[Starlette] Setting Transaction Name %s", transaction_name + ) - _set_transaction_name_and_source( - sentry_scope, integration.transaction_style, request + sentry_scope.set_transaction_name( + name=transaction_name, + source=transaction_source, ) + asgi_request_data = _get_request_data(request.scope) extractor = StarletteRequestExtractor(request) info = await extractor.extract_request_info() @@ -397,6 +442,7 @@ def event_processor(event, hint): # Add info from request to event request_info = event.get("request", {}) + request_info.update(asgi_request_data) if info: if "cookies" in info: request_info["cookies"] = info["cookies"] @@ -430,11 +476,7 @@ def _sentry_sync_func(*args, **kwargs): sentry_scope.profile.update_active_thread_id() request = args[0] - - _set_transaction_name_and_source( - sentry_scope, integration.transaction_style, request - ) - + asgi_request_data = _get_request_data(request.scope) extractor = StarletteRequestExtractor(request) cookies = extractor.extract_cookies_from_request() @@ -445,6 +487,7 @@ def event_processor(event, hint): # Extract information from request request_info = event.get("request", {}) + request_info.update(asgi_request_data) if cookies: request_info["cookies"] = cookies @@ -619,32 +662,41 @@ async def json(self): return await self.request.json() -def _set_transaction_name_and_source(scope, transaction_style, request): - # type: (SentryScope, str, Any) -> None - name = "" +def _get_transaction_name_and_source(transaction_style, request): + # type: (str, Any) -> None + name = None + source = None if transaction_style == "endpoint": endpoint = request.scope.get("endpoint") if endpoint: name = transaction_from_function(endpoint) or "" + source = SOURCE_FOR_STYLE[transaction_style] + else: + name = request.scope.get("raw_path") + source = TRANSACTION_SOURCE_URL elif transaction_style == "url": - router = request.scope["router"] - for route in router.routes: - match = route.matches(request.scope) - - if match[0] == Match.FULL: - if transaction_style == "endpoint": - name = transaction_from_function(match[1]["endpoint"]) or "" - break - elif transaction_style == "url": - name = route.path - break - - if not name: + router = request.scope.get("router") + if router: + for route in router.routes: + match = route.matches(request.scope) + + if match[0] == Match.FULL: + if transaction_style == "endpoint": + name = transaction_from_function(match[1]["endpoint"]) or "" + source = SOURCE_FOR_STYLE[transaction_style] + break + elif transaction_style == "url": + name = route.path + source = SOURCE_FOR_STYLE[transaction_style] + break + else: + name = request.scope.get("raw_path") + source = TRANSACTION_SOURCE_URL + + if name is None: name = _DEFAULT_TRANSACTION_NAME source = TRANSACTION_SOURCE_ROUTE - else: - source = SOURCE_FOR_STYLE[transaction_style] - scope.set_transaction_name(name, source=source) + return name, source diff --git a/tests/integrations/asgi/test_asgi.py b/tests/integrations/asgi/test_asgi.py index dcd770ac37..e4fc90cdc2 100644 --- a/tests/integrations/asgi/test_asgi.py +++ b/tests/integrations/asgi/test_asgi.py @@ -6,6 +6,7 @@ import sentry_sdk from sentry_sdk import capture_message from sentry_sdk.integrations.asgi import SentryAsgiMiddleware, _looks_like_asgi3 +from sentry_sdk.integrations._asgi_common import _get_ip, _get_headers async_asgi_testclient = pytest.importorskip("async_asgi_testclient") from async_asgi_testclient import TestClient @@ -19,7 +20,17 @@ @pytest.fixture def asgi3_app(): async def app(scope, receive, send): - if ( + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + ... # Do some startup here! + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + ... # Do some shutdown here! + await send({"type": "lifespan.shutdown.complete"}) + return + elif ( scope["type"] == "http" and "route" in scope and scope["route"] == "/trigger/error" @@ -52,21 +63,32 @@ async def send_with_error(event): 1 / 0 async def app(scope, receive, send): - await send_with_error( - { - "type": "http.response.start", - "status": 200, - "headers": [ - [b"content-type", b"text/plain"], - ], - } - ) - await send_with_error( - { - "type": "http.response.body", - "body": b"Hello, world!", - } - ) + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + ... # Do some startup here! + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + ... # Do some shutdown here! + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await send_with_error( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"content-type", b"text/plain"], + ], + } + ) + await send_with_error( + { + "type": "http.response.body", + "body": b"Hello, world!", + } + ) return app @@ -139,10 +161,11 @@ async def test_capture_transaction( events = capture_events() await client.get("/?somevalue=123") - (transaction_event,) = events + (transaction_event, lifespan_transaction_event) = events assert transaction_event["type"] == "transaction" - assert transaction_event["transaction"] == "generic ASGI request" + assert transaction_event["transaction"] == "/" + assert transaction_event["transaction_info"] == {"source": "url"} assert transaction_event["contexts"]["trace"]["op"] == "http.server" assert transaction_event["request"] == { "headers": { @@ -172,15 +195,17 @@ async def test_capture_transaction_with_error( async with TestClient(app) as client: await client.get("/") - (error_event, transaction_event) = events + (error_event, transaction_event, lifespan_transaction_event) = events - assert error_event["transaction"] == "generic ASGI request" + assert error_event["transaction"] == "/" assert error_event["contexts"]["trace"]["op"] == "http.server" assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError" assert error_event["exception"]["values"][0]["value"] == "division by zero" assert error_event["exception"]["values"][0]["mechanism"]["handled"] is False assert error_event["exception"]["values"][0]["mechanism"]["type"] == "asgi" + assert transaction_event["transaction"] == "/" + assert transaction_event["transaction_info"] == {"source": "url"} assert transaction_event["type"] == "transaction" assert transaction_event["contexts"]["trace"] == DictionaryContaining( error_event["contexts"]["trace"] @@ -389,7 +414,7 @@ async def test_auto_session_tracking_with_aggregates( ( "/message", "url", - "generic ASGI request", + "/message", "route", ), ( @@ -423,7 +448,7 @@ async def test_transaction_style( events = capture_events() await client.get(url) - (transaction_event,) = events + (transaction_event, lifespan_transaction_event) = events assert transaction_event["transaction"] == expected_transaction assert transaction_event["transaction_info"] == {"source": expected_source} @@ -472,8 +497,7 @@ def test_get_ip_x_forwarded_for(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" # x-forwarded-for overrides x-real-ip @@ -485,8 +509,7 @@ def test_get_ip_x_forwarded_for(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" # when multiple x-forwarded-for headers are, the first is taken @@ -499,8 +522,7 @@ def test_get_ip_x_forwarded_for(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "5.5.5.5" @@ -513,8 +535,7 @@ def test_get_ip_x_real_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "10.10.10.10" # x-forwarded-for overrides x-real-ip @@ -526,8 +547,7 @@ def test_get_ip_x_real_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" @@ -539,8 +559,7 @@ def test_get_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "127.0.0.1" # x-forwarded-for header overides the ip from client @@ -551,8 +570,7 @@ def test_get_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "8.8.8.8" # x-real-for header overides the ip from client @@ -563,8 +581,7 @@ def test_get_ip(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - ip = middleware._get_ip(scope) + ip = _get_ip(scope) assert ip == "10.10.10.10" @@ -579,8 +596,7 @@ def test_get_headers(): "client": ("127.0.0.1", 60457), "headers": headers, } - middleware = SentryAsgiMiddleware({}) - headers = middleware._get_headers(scope) + headers = _get_headers(scope) assert headers == { "x-real-ip": "10.10.10.10", "some_header": "123, abc", diff --git a/tests/integrations/fastapi/test_fastapi.py b/tests/integrations/fastapi/test_fastapi.py index 5a770a70af..8967878899 100644 --- a/tests/integrations/fastapi/test_fastapi.py +++ b/tests/integrations/fastapi/test_fastapi.py @@ -322,3 +322,32 @@ def test_response_status_code_not_found_in_transaction_context( "response" in transaction["contexts"].keys() ), "Response context not found in transaction" assert transaction["contexts"]["response"]["status_code"] == 404 + + +def test_transaction_name(): + """ + Tests that the transaction name is something meaningful. + """ + # this: sampling_context["transaction_context"]["name"] + # for transaction_style "endpoint" and "url" + assert False + + +def test_transaction_name_in_traces_sampler(): + """ + Tests that a custom traces_sampler has a meaningful the transaction name + In this case the URL or endpoint, because we do not have the route yet. + """ + # this: sampling_context["transaction_context"]["name"] + # for transaction_style "endpoint" and "url" + assert False + + +def test_transaction_name_in_middleware(): + """ + Tests that the transaction name in the middleware + (like CORSMiddleware) is something meaningful. + In this case the URL or endpoint, because we do not have the route yet. + """ + # for transaction_style "endpoint" and "url" + assert False diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py index cc4d8cf3ba..d9e8f8742c 100644 --- a/tests/integrations/starlette/test_starlette.py +++ b/tests/integrations/starlette/test_starlette.py @@ -700,7 +700,9 @@ def test_middleware_callback_spans(sentry_init, capture_events): }, { "op": "middleware.starlette.send", - "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send", + "description": "_ASGIAdapter.send..send" + if STARLETTE_VERSION < (0, 21) + else "_TestClientTransport.handle_request..send", "tags": {"starlette.middleware_name": "ServerErrorMiddleware"}, }, { @@ -715,7 +717,9 @@ def test_middleware_callback_spans(sentry_init, capture_events): }, { "op": "middleware.starlette.send", - "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send", + "description": "_ASGIAdapter.send..send" + if STARLETTE_VERSION < (0, 21) + else "_TestClientTransport.handle_request..send", "tags": {"starlette.middleware_name": "ServerErrorMiddleware"}, }, ] @@ -789,7 +793,9 @@ def test_middleware_partial_receive_send(sentry_init, capture_events): }, { "op": "middleware.starlette.send", - "description": "SentryAsgiMiddleware._run_app.._sentry_wrapped_send", + "description": "_ASGIAdapter.send..send" + if STARLETTE_VERSION < (0, 21) + else "_TestClientTransport.handle_request..send", "tags": {"starlette.middleware_name": "ServerErrorMiddleware"}, }, { @@ -830,7 +836,7 @@ def handler(request, exc): app = starlette_app_factory(debug=False) app.add_exception_handler(500, handler) - client = TestClient(SentryAsgiMiddleware(app), raise_server_exceptions=False) + client = TestClient(app, raise_server_exceptions=False) response = client.get("/custom_error") assert response.status_code == 500 @@ -858,7 +864,32 @@ def test_legacy_setup( client.get("/message/123456") (event,) = events - assert event["transaction"] == "/message/{message_id}" + assert ( + event["transaction"] + == "tests.integrations.starlette.test_starlette.starlette_app_factory.._message_with_id" + ) + + +def test_legacy_setup_transaction_style_url( + sentry_init, + capture_events, +): + # Check that behaviour does not change + # if the user just adds the new Integration + # and forgets to remove SentryAsgiMiddleware + sentry_init() + app = starlette_app_factory() + asgi_app = SentryAsgiMiddleware(app, transaction_style="url") + + events = capture_events() + + client = TestClient(asgi_app) + client.get("/message/123456") + + (event,) = events + assert ( + event["transaction"] == "http://testserver/message/123456" + ) # the url from AsgiMiddleware (because it does not know about routes) @pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"]) @@ -869,11 +900,10 @@ def test_active_thread_id(sentry_init, capture_envelopes, teardown_profiling, en _experiments={"profiles_sample_rate": 1.0}, ) app = starlette_app_factory() - asgi_app = SentryAsgiMiddleware(app) envelopes = capture_envelopes() - client = TestClient(asgi_app) + client = TestClient(app) response = client.get(endpoint) assert response.status_code == 200 @@ -949,3 +979,97 @@ def test_template_tracing_meta(sentry_init, capture_events): # Python 2 does not preserve sort order rendered_baggage = match.group(2) assert sorted(rendered_baggage.split(",")) == sorted(baggage.split(",")) + + +@pytest.mark.parametrize( + "transaction_style,expected_transaction_name,expected_transaction_source", + [ + ( + "endpoint", + "tests.integrations.starlette.test_starlette.starlette_app_factory.._message_with_id", + "component", + ), + ( + "url", + "/message/{message_id}", + "route", + ), + ], +) +def test_transaction_name( + sentry_init, + transaction_style, + capture_envelopes, + expected_transaction_name, + expected_transaction_source, +): + """ + Tests that the transaction name is something meaningful. + """ + sentry_init( + auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request. + integrations=[StarletteIntegration(transaction_style=transaction_style)], + traces_sample_rate=1.0, + ) + + envelopes = capture_envelopes() + + app = starlette_app_factory() + client = TestClient(app) + client.get("/message/123456") + + (_, transaction_envelope) = envelopes + transaction_event = transaction_envelope.get_transaction_event() + + assert transaction_event["transaction"] == expected_transaction_name + assert ( + transaction_event["transaction_info"]["source"] == expected_transaction_source + ) + + +@pytest.mark.parametrize( + "transaction_style,expected_transaction_name,expected_transaction_source", + [ + ("endpoint", b"/message/123456", "url"), + ("url", b"/message/123456", "url"), + ], +) +def test_transaction_name_in_traces_sampler( + sentry_init, + transaction_style, + expected_transaction_name, + expected_transaction_source, +): + """ + Tests that a custom traces_sampler has a meaningful transaction name. + In this case the URL or endpoint, because we do not have the route yet. + """ + + def dummy_traces_sampler(sampling_context): + assert ( + sampling_context["transaction_context"]["name"] == expected_transaction_name + ) + assert ( + sampling_context["transaction_context"]["source"] + == expected_transaction_source + ) + + sentry_init( + auto_enabling_integrations=False, # Make sure that httpx integration is not added, because it adds tracing information to the starlette test clients request. + integrations=[StarletteIntegration(transaction_style=transaction_style)], + traces_sampler=dummy_traces_sampler, + traces_sample_rate=1.0, + ) + + app = starlette_app_factory() + client = TestClient(app) + client.get("/message/123456") + + +def test_transaction_name_in_middleware(): + """ + Tests that the transaction name in the middleware (like CORSMiddleware) is something meaningful. + In this case the URL or endpoint, because we do not have the route yet. + """ + # for transaction_style "endpoint" and "url" + assert False