diff --git a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/otel_middleware.py b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/otel_middleware.py index 1b747fd2c0..cc8fec3c23 100644 --- a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/otel_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/otel_middleware.py @@ -286,9 +286,20 @@ def process_request(self, request): request.META[self._environ_token] = token if _DjangoMiddleware._otel_request_hook: - _DjangoMiddleware._otel_request_hook( # pylint: disable=not-callable - span, request - ) + try: + _DjangoMiddleware._otel_request_hook( # pylint: disable=not-callable + span, request + ) + except Exception as exception: + # process_response() will not be called, so we need to clean up + if token: + detach(token) + activation.__exit__( + type(exception), + exception, + getattr(exception, "__traceback__", None), + ) + raise exception # pylint: disable=unused-argument def process_view(self, request, view_func, *args, **kwargs): @@ -341,6 +352,8 @@ def process_response(self, request, response): ) request_start_time = request.META.pop(self._environ_timer_key, None) + response_hook_exception = None + if activation and span: if is_asgi_request: set_status_code(span, response.status_code) @@ -385,10 +398,19 @@ def process_response(self, request, response): # record any exceptions raised while processing the request exception = request.META.pop(self._environ_exception_key, None) + if _DjangoMiddleware._otel_response_hook: - _DjangoMiddleware._otel_response_hook( # pylint: disable=not-callable - span, request, response - ) + try: + _DjangoMiddleware._otel_response_hook( # pylint: disable=not-callable + span, request, response + ) + except Exception as e: + response_hook_exception = e + if not exception: + exception = e + else: + # original exception takes precedence, so just log this one + span.record_exception(e) if exception: activation.__exit__( @@ -408,5 +430,8 @@ def process_response(self, request, response): if request.META.get(self._environ_token, None) is not None: detach(request.META.get(self._environ_token)) request.META.pop(self._environ_token) + + if response_hook_exception is not None: + raise response_hook_exception return response diff --git a/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py b/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py index 63af1e6b86..171dc65b3a 100644 --- a/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py @@ -391,6 +391,50 @@ def response_hook(span, request, response): self.assertIsInstance(response_hook_args[1], HttpRequest) self.assertIsInstance(response_hook_args[2], HttpResponse) self.assertEqual(response_hook_args[2], response) + + def test_request_hook_exception(self): + + class RequestHookException(Exception): + pass + + def request_hook(span, request): + raise RequestHookException() + + _DjangoMiddleware._otel_request_hook = request_hook + with self.assertRaises(RequestHookException): + Client().get("/span_name/1234/") + _DjangoMiddleware._otel_request_hook = None + + # ensure that span ended + finished_spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(finished_spans), 1) + self.assertEquals(finished_spans[0].status.status_code, StatusCode.ERROR) + + def test_response_hook_exception(self): + + class ResponseHookException(Exception): + pass + + def response_hook(span, request, response): + raise ResponseHookException() + + _DjangoMiddleware._otel_response_hook = response_hook + with self.assertRaises(ResponseHookException): + Client().get("/span_name/1234/") + with self.assertRaises(ResponseHookException): + Client().get("/error/") + _DjangoMiddleware._otel_response_hook = None + + # ensure that span ended + finished_spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(finished_spans), 2) + self.assertEquals(finished_spans[0].status.status_code, StatusCode.ERROR) + self.assertIn("ResponseHookException", finished_spans[0].status.description) + self.assertEquals(finished_spans[1].status.status_code, StatusCode.ERROR) + # view error takes precedence over response hook error + self.assertIn("ValueError", finished_spans[1].status.description) + # ensure an event was added for both the view error and the response hook error + self.assertEquals(len(finished_spans[1].events), 2) def test_trace_parent(self): id_generator = RandomIdGenerator()