diff --git a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py index f2a18a2770..37b1559bb7 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py @@ -731,7 +731,9 @@ def _instrument(self, **kwargs): self._original_async_client = httpx.AsyncClient request_hook = kwargs.get("request_hook") response_hook = kwargs.get("response_hook") - async_request_hook = kwargs.get("async_request_hook", request_hook) + async_request_hook = kwargs.get( + "async_request_hook", self._wrap_async_request_hook(request_hook) + ) async_response_hook = kwargs.get("async_response_hook", response_hook) if callable(request_hook): _InstrumentedClient._request_hook = request_hook @@ -749,6 +751,16 @@ def _instrument(self, **kwargs): httpx.Client = httpx._api.Client = _InstrumentedClient httpx.AsyncClient = _InstrumentedAsyncClient + # Wrap a given request hook function and ensure it is asynchronous + def _wrap_async_request_hook(self, request_hook_func): + if request_hook_func is None: + return None + + async def async_request_hook(span, req): + return request_hook_func(span, req) + + return async_request_hook + def _uninstrument(self, **kwargs): httpx.Client = httpx._api.Client = self._original_client httpx.AsyncClient = self._original_async_client