From f6b68d0c024cf40d15c08062a18bf70ea73847e6 Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Tue, 29 Oct 2024 21:33:35 +0100 Subject: [PATCH] httpx: rewrite patching to use wrapt instead of subclassing client (#2909) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit httpx: rewrote patching to use wrapt instead of subclassing client Porting of httpx instrumentation to patch async transport methods instead of substituting the client. That is because the current approach will instrument httpx by instantianting another client with a custom transport class and this will race with code already subclassing. This one uses wrapt to patch the default httpx transport classes. --------- Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com> --- CHANGELOG.md | 2 + .../pyproject.toml | 1 + .../instrumentation/httpx/__init__.py | 378 +++++++++++++----- .../tests/test_httpx_integration.py | 99 +++-- 4 files changed, 350 insertions(+), 130 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed4671d559..7597e60641 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2871](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2871)) - `opentelemetry-instrumentation` Don't fail distro loading if instrumentor raises ImportError, instead skip them ([#2923](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2923)) +- `opentelemetry-instrumentation-httpx` Rewrote instrumentation to use wrapt instead of subclassing + ([#2909](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2909)) ## Version 1.27.0/0.48b0 (2024-08-28) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml b/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml index 599091716b..c986fac4a1 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml +++ b/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "opentelemetry-instrumentation == 0.49b0.dev", "opentelemetry-semantic-conventions == 0.49b0.dev", "opentelemetry-util-http == 0.49b0.dev", + "wrapt >= 1.0.0, < 2.0.0", ] [project.optional-dependencies] diff --git a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py index b9b9a31d3e..d3a2cecfe6 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=too-many-lines """ Usage ----- @@ -194,9 +195,11 @@ async def async_response_hook(span, request, response): import logging import typing from asyncio import iscoroutinefunction +from functools import partial from types import TracebackType import httpx +from wrapt import wrap_function_wrapper from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -217,6 +220,7 @@ async def async_response_hook(span, request, response): from opentelemetry.instrumentation.utils import ( http_status_to_status_code, is_http_instrumentation_enabled, + unwrap, ) from opentelemetry.propagate import inject from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE @@ -731,44 +735,211 @@ def _instrument(self, **kwargs): ``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient`` ``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient`` """ - self._original_client = httpx.Client - self._original_async_client = httpx.AsyncClient + tracer_provider = kwargs.get("tracer_provider") request_hook = kwargs.get("request_hook") response_hook = kwargs.get("response_hook") async_request_hook = kwargs.get("async_request_hook") - async_response_hook = kwargs.get("async_response_hook") - if callable(request_hook): - _InstrumentedClient._request_hook = request_hook - if callable(async_request_hook) and iscoroutinefunction( + async_request_hook = ( async_request_hook - ): - _InstrumentedAsyncClient._request_hook = async_request_hook - if callable(response_hook): - _InstrumentedClient._response_hook = response_hook - if callable(async_response_hook) and iscoroutinefunction( + if iscoroutinefunction(async_request_hook) + else None + ) + async_response_hook = kwargs.get("async_response_hook") + async_response_hook = ( async_response_hook - ): - _InstrumentedAsyncClient._response_hook = async_response_hook - tracer_provider = kwargs.get("tracer_provider") - _InstrumentedClient._tracer_provider = tracer_provider - _InstrumentedAsyncClient._tracer_provider = tracer_provider - # Intentionally using a private attribute here, see: - # https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719 - httpx.Client = httpx._api.Client = _InstrumentedClient - httpx.AsyncClient = _InstrumentedAsyncClient + if iscoroutinefunction(async_response_hook) + else None + ) + + _OpenTelemetrySemanticConventionStability._initialize() + sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(sem_conv_opt_in_mode), + ) + + wrap_function_wrapper( + "httpx", + "HTTPTransport.handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), + ) + wrap_function_wrapper( + "httpx", + "AsyncHTTPTransport.handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) def _uninstrument(self, **kwargs): - httpx.Client = httpx._api.Client = self._original_client - httpx.AsyncClient = self._original_async_client - _InstrumentedClient._tracer_provider = None - _InstrumentedClient._request_hook = None - _InstrumentedClient._response_hook = None - _InstrumentedAsyncClient._tracer_provider = None - _InstrumentedAsyncClient._request_hook = None - _InstrumentedAsyncClient._response_hook = None + unwrap(httpx.HTTPTransport, "handle_request") + unwrap(httpx.AsyncHTTPTransport, "handle_async_request") @staticmethod + def _handle_request_wrapper( # pylint: disable=too-many-locals + wrapped, + instance, + args, + kwargs, + tracer, + sem_conv_opt_in_mode, + request_hook, + response_hook, + ): + if not is_http_instrumentation_enabled(): + return wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(request_hook): + request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + sem_conv_opt_in_mode, + ) + if callable(response_hook): + response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new(sem_conv_opt_in_mode): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + + @staticmethod + async def _handle_async_request_wrapper( # pylint: disable=too-many-locals + wrapped, + instance, + args, + kwargs, + tracer, + sem_conv_opt_in_mode, + async_request_hook, + async_response_hook, + ): + if not is_http_instrumentation_enabled(): + return await wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(async_request_hook): + await async_request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = await wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + sem_conv_opt_in_mode, + ) + + if callable(async_response_hook): + await async_response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new(sem_conv_opt_in_mode): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + def instrument_client( + self, client: typing.Union[httpx.Client, httpx.AsyncClient], tracer_provider: TracerProvider = None, request_hook: typing.Union[ @@ -788,67 +959,88 @@ def instrument_client( response_hook: A hook that receives the span, request, and response that is called right before the span ends """ - # pylint: disable=protected-access - if not hasattr(client, "_is_instrumented_by_opentelemetry"): - client._is_instrumented_by_opentelemetry = False - if not client._is_instrumented_by_opentelemetry: - if isinstance(client, httpx.Client): - client._original_transport = client._transport - client._original_mounts = client._mounts.copy() - transport = client._transport or httpx.HTTPTransport() - client._transport = SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) + if getattr(client, "_is_instrumented_by_opentelemetry", False): + _logger.warning( + "Attempting to instrument Httpx client while already instrumented" + ) + return - if isinstance(client, httpx.AsyncClient): - transport = client._transport or httpx.AsyncHTTPTransport() - client._original_mounts = client._mounts.copy() - client._transport = AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, + _OpenTelemetrySemanticConventionStability._initialize() + sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(sem_conv_opt_in_mode), + ) + + if iscoroutinefunction(request_hook): + async_request_hook = request_hook + request_hook = None + else: + # request_hook already set + async_request_hook = None + + if iscoroutinefunction(response_hook): + async_response_hook = response_hook + response_hook = None + else: + # response_hook already set + async_response_hook = None + + if hasattr(client._transport, "handle_request"): + wrap_function_wrapper( + client._transport, + "handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, request_hook=request_hook, response_hook=response_hook, + ), + ) + for transport in client._mounts.values(): + wrap_function_wrapper( + transport, + "handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) - else: - _logger.warning( - "Attempting to instrument Httpx client while already instrumented" + client._is_instrumented_by_opentelemetry = True + if hasattr(client._transport, "handle_async_request"): + wrap_function_wrapper( + client._transport, + "handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), ) + for transport in client._mounts.values(): + wrap_function_wrapper( + transport, + "handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) + client._is_instrumented_by_opentelemetry = True @staticmethod def uninstrument_client( @@ -859,15 +1051,13 @@ def uninstrument_client( Args: client: The httpx Client or AsyncClient instance """ - if hasattr(client, "_original_transport"): - client._transport = client._original_transport - del client._original_transport + if hasattr(client._transport, "handle_request"): + unwrap(client._transport, "handle_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_request") + client._is_instrumented_by_opentelemetry = False + elif hasattr(client._transport, "handle_async_request"): + unwrap(client._transport, "handle_async_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_async_request") client._is_instrumented_by_opentelemetry = False - if hasattr(client, "_original_mounts"): - client._mounts = client._original_mounts.copy() - del client._original_mounts - else: - _logger.warning( - "Attempting to uninstrument Httpx " - "client while already uninstrumented" - ) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py index 0d055515e0..07699700c4 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py @@ -21,6 +21,7 @@ import httpx import respx +from wrapt import ObjectProxy import opentelemetry.instrumentation.httpx from opentelemetry import trace @@ -171,6 +172,7 @@ def tearDown(self): super().tearDown() self.env_patch.stop() respx.stop() + HTTPXClientInstrumentor().uninstrument() def assert_span( self, exporter: "SpanExporter" = None, num_spans: int = 1 @@ -204,7 +206,7 @@ def test_basic(self): self.assertEqual(span.name, "GET") self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -228,7 +230,7 @@ def test_nonstandard_http_method(self): self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertEqual(span.name, "HTTP") self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "_OTHER", SpanAttributes.HTTP_URL: self.URL, @@ -252,7 +254,7 @@ def test_nonstandard_http_method_new_semconv(self): self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertEqual(span.name, "HTTP") self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "_OTHER", URL_FULL: self.URL, @@ -292,7 +294,7 @@ def test_basic_new_semconv(self): SpanAttributes.SCHEMA_URL, ) self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "GET", URL_FULL: url, @@ -327,7 +329,7 @@ def test_basic_both_semconv(self): ) self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", HTTP_REQUEST_METHOD: "GET", @@ -454,7 +456,7 @@ def test_requests_500_error(self): span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -510,7 +512,7 @@ def test_requests_timeout_exception_new_semconv(self): span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "GET", URL_FULL: url, @@ -531,7 +533,7 @@ def test_requests_timeout_exception_both_semconv(self): span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", HTTP_REQUEST_METHOD: "GET", @@ -632,7 +634,7 @@ def test_response_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -741,8 +743,10 @@ def create_proxy_transport(self, url: str): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() self.client = self.create_client() + HTTPXClientInstrumentor().instrument_client(self.client) + + def tearDown(self): HTTPXClientInstrumentor().uninstrument() def create_proxy_mounts(self): @@ -755,14 +759,25 @@ def create_proxy_mounts(self): ), } - def assert_proxy_mounts(self, mounts, num_mounts, transport_type): + def assert_proxy_mounts(self, mounts, num_mounts, transport_type=None): self.assertEqual(len(mounts), num_mounts) for transport in mounts: with self.subTest(transport): - self.assertIsInstance( - transport, - transport_type, - ) + if transport_type: + self.assertIsInstance( + transport, + transport_type, + ) + else: + handler = getattr(transport, "handle_request", None) + if not handler: + handler = getattr( + transport, "handle_async_request" + ) + self.assertTrue( + isinstance(handler, ObjectProxy) + and getattr(handler, "__wrapped__") + ) def test_custom_tracer_provider(self): resource = resources.Resource.create({}) @@ -778,7 +793,6 @@ def test_custom_tracer_provider(self): self.assertEqual(result.text, "Hello!") span = self.assert_span(exporter=exporter) self.assertIs(span.resource, resource) - HTTPXClientInstrumentor().uninstrument() def test_response_hook(self): response_hook_key = ( @@ -797,7 +811,7 @@ def test_response_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -805,7 +819,6 @@ def test_response_hook(self): HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_response_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().instrument( @@ -819,7 +832,7 @@ def test_response_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -827,7 +840,6 @@ def test_response_hook_sync_async_kwargs(self): HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_request_hook(self): request_hook_key = ( @@ -846,7 +858,6 @@ def test_request_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().instrument( @@ -860,7 +871,6 @@ def test_request_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_no_span_update(self): HTTPXClientInstrumentor().instrument( @@ -873,7 +883,6 @@ def test_request_hook_no_span_update(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() def test_not_recording(self): with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span: @@ -891,7 +900,6 @@ def test_not_recording(self): self.assertTrue(mock_span.is_recording.called) self.assertFalse(mock_span.set_attribute.called) self.assertFalse(mock_span.set_status.called) - HTTPXClientInstrumentor().uninstrument() def test_suppress_instrumentation_new_client(self): HTTPXClientInstrumentor().instrument() @@ -901,7 +909,6 @@ def test_suppress_instrumentation_new_client(self): self.assertEqual(result.text, "Hello!") self.assert_span(num_spans=0) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client(self): client = self.create_client() @@ -929,8 +936,6 @@ def test_instrumentation_without_client(self): self.URL, ) - HTTPXClientInstrumentor().uninstrument() - def test_uninstrument(self): HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().uninstrument() @@ -980,9 +985,7 @@ def test_instrument_proxy(self): self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client_with_proxy(self): proxy_mounts = self.create_proxy_mounts() @@ -999,7 +1002,6 @@ def test_instrument_client_with_proxy(self): self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) HTTPXClientInstrumentor().uninstrument_client(client) @@ -1010,7 +1012,6 @@ def test_uninstrument_client_with_proxy(self): self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) HTTPXClientInstrumentor().uninstrument_client(client) @@ -1180,6 +1181,21 @@ def perform_request( def create_proxy_transport(self, url): return httpx.HTTPTransport(proxy=httpx.Proxy(url)) + def test_can_instrument_subclassed_client(self): + class CustomClient(httpx.Client): + pass + + client = CustomClient() + self.assertFalse( + isinstance(client._transport.handle_request, ObjectProxy) + ) + + HTTPXClientInstrumentor().instrument() + + self.assertTrue( + isinstance(client._transport.handle_request, ObjectProxy) + ) + class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): response_hook = staticmethod(_async_response_hook) @@ -1188,10 +1204,8 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() - self.client = self.create_client() self.client2 = self.create_client() - HTTPXClientInstrumentor().uninstrument() + HTTPXClientInstrumentor().instrument_client(self.client2) def create_client( self, @@ -1245,7 +1259,6 @@ def test_async_response_hook_does_nothing_if_not_coroutine(self): SpanAttributes.HTTP_STATUS_CODE: 200, }, ) - HTTPXClientInstrumentor().uninstrument() def test_async_request_hook_does_nothing_if_not_coroutine(self): HTTPXClientInstrumentor().instrument( @@ -1258,4 +1271,18 @@ def test_async_request_hook_does_nothing_if_not_coroutine(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() + + def test_can_instrument_subclassed_async_client(self): + class CustomAsyncClient(httpx.AsyncClient): + pass + + client = CustomAsyncClient() + self.assertFalse( + isinstance(client._transport.handle_async_request, ObjectProxy) + ) + + HTTPXClientInstrumentor().instrument() + + self.assertTrue( + isinstance(client._transport.handle_async_request, ObjectProxy) + )