diff --git a/sentry_sdk/integrations/httpx.py b/sentry_sdk/integrations/httpx.py index 6f80b93f4d..2ddd44489f 100644 --- a/sentry_sdk/integrations/httpx.py +++ b/sentry_sdk/integrations/httpx.py @@ -2,7 +2,7 @@ from sentry_sdk.consts import OP, SPANDATA from sentry_sdk.integrations import Integration, DidNotEnable from sentry_sdk.tracing import BAGGAGE_HEADER_NAME -from sentry_sdk.tracing_utils import should_propagate_trace +from sentry_sdk.tracing_utils import Baggage, should_propagate_trace from sentry_sdk.utils import ( SENSITIVE_DATA_SUBSTITUTE, capture_internal_exceptions, @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import MutableMapping from typing import Any @@ -76,11 +77,9 @@ def send(self, request, **kwargs): key=key, value=value, url=request.url ) ) - if key == BAGGAGE_HEADER_NAME and request.headers.get( - BAGGAGE_HEADER_NAME - ): - # do not overwrite any existing baggage, just append to it - request.headers[key] += "," + value + + if key == BAGGAGE_HEADER_NAME: + _add_sentry_baggage_to_headers(request.headers, value) else: request.headers[key] = value @@ -148,3 +147,21 @@ async def send(self, request, **kwargs): return rv AsyncClient.send = send + + +def _add_sentry_baggage_to_headers(headers, sentry_baggage): + # type: (MutableMapping[str, str], str) -> None + """Add the Sentry baggage to the headers. + + This function directly mutates the provided headers. The provided sentry_baggage + is appended to the existing baggage. If the baggage already contains Sentry items, + they are stripped out first. + """ + existing_baggage = headers.get(BAGGAGE_HEADER_NAME, "") + stripped_existing_baggage = Baggage.strip_sentry_baggage(existing_baggage) + + separator = "," if len(stripped_existing_baggage) > 0 else "" + + headers[BAGGAGE_HEADER_NAME] = ( + stripped_existing_baggage + separator + sentry_baggage + ) diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index 150e73661e..0459563776 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -609,6 +609,21 @@ def serialize(self, include_third_party=False): return ",".join(items) + @staticmethod + def strip_sentry_baggage(header): + # type: (str) -> str + """Remove Sentry baggage from the given header. + + Given a Baggage header, return a new Baggage header with all Sentry baggage items removed. + """ + return ",".join( + ( + item + for item in header.split(",") + if not Baggage.SENTRY_PREFIX_REGEX.match(item.strip()) + ) + ) + def should_propagate_trace(client, url): # type: (sentry_sdk.client.BaseClient, str) -> bool diff --git a/tests/test_tracing_utils.py b/tests/test_tracing_utils.py index 239e631156..5c1f70516d 100644 --- a/tests/test_tracing_utils.py +++ b/tests/test_tracing_utils.py @@ -1,7 +1,7 @@ from dataclasses import asdict, dataclass from typing import Optional, List -from sentry_sdk.tracing_utils import _should_be_included +from sentry_sdk.tracing_utils import _should_be_included, Baggage import pytest @@ -94,3 +94,24 @@ def test_should_be_included(test_case, expected): kwargs = asdict(test_case) kwargs.pop("id") assert _should_be_included(**kwargs) == expected + + +@pytest.mark.parametrize( + ("header", "expected"), + ( + ("", ""), + ("foo=bar", "foo=bar"), + (" foo=bar, baz = qux ", " foo=bar, baz = qux "), + ("sentry-trace_id=123", ""), + (" sentry-trace_id = 123 ", ""), + ("sentry-trace_id=123,sentry-public_key=456", ""), + ("foo=bar,sentry-trace_id=123", "foo=bar"), + ("foo=bar,sentry-trace_id=123,baz=qux", "foo=bar,baz=qux"), + ( + "foo=bar,sentry-trace_id=123,baz=qux,sentry-public_key=456", + "foo=bar,baz=qux", + ), + ), +) +def test_strip_sentry_baggage(header, expected): + assert Baggage.strip_sentry_baggage(header) == expected