From ea0a4a36d70895f8c40196d0b42233e89e1a8372 Mon Sep 17 00:00:00 2001
From: Corvin Lasogga <corvin.lasogga@ilt-extern.fraunhofer.de>
Date: Fri, 22 Jul 2022 17:11:16 +0200
Subject: [PATCH] bugfix of contextmanager of generator

---
 .../instrumentation/grpc/_server.py           | 244 ++++++++++--------
 1 file changed, 139 insertions(+), 105 deletions(-)

diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py
index 1bcaeb776d..4f9c23659d 100644
--- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py
+++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py
@@ -252,7 +252,7 @@ def __init__(self, meter, tracer):
         self._duration_histogram = self._meter.create_histogram(
             name="rpc.server.duration",
             unit="ms",
-            description="measures the duration of the inbound rpc",
+            description="measures duration of inbound RPC",
         )
         self._request_size_histogram = self._meter.create_histogram(
             name="rpc.server.request.size",
@@ -339,70 +339,16 @@ def _create_attributes(self, context, full_method):
     def intercept_service(self, continuation, handler_call_details):
         def telemetry_wrapper(behavior, request_streaming, response_streaming):
             def telemetry_interceptor(request_or_iterator, context):
-                with self._set_remote_context(context):
-                    attributes = self._create_attributes(context, handler_call_details.method)
-
-                    with self._tracer.start_as_current_span(
-                        name=handler_call_details.method,
-                        kind=trace.SpanKind.SERVER,
-                        attributes=attributes,
-                        end_on_exit=False,
-                        record_exception=False,
-                        set_status_on_exception=False
-                    ) as span:
-
-                        try:
-                            # wrap the context
-                            context = _OpenTelemetryServicerContext(context, span)
-
-                            # wrap / log the request (iterator)
-                            if request_streaming:
-                                request_or_iterator = self._log_stream_requests(
-                                    request_or_iterator, span, attributes
-                                )
-                            else:
-                                self._log_unary_request(
-                                    request_or_iterator, span, attributes
-                                )
-
-                            # call the actual RPC and track the duration
-                            with self._record_duration(attributes, context):
-                                response_or_iterator = behavior(request_or_iterator, context)
-
-                            # wrap / log the response (iterator)
-                            if response_streaming:
-                                response_or_iterator = self._log_stream_responses(
-                                    response_or_iterator, span, attributes, context
-                                )
-                            else:
-                                self._log_unary_response(
-                                    response_or_iterator, span, attributes, context
-                                )
-
-                            return response_or_iterator
-
-                        except Exception as exc:
-                            # Bare exceptions are likely to be gRPC aborts, which
-                            # we handle in our context wrapper.
-                            # Here, we're interested in uncaught exceptions.
-                            # pylint:disable=unidiomatic-typecheck
-                            if type(exc) != Exception:
-                                span.set_attribute(
-                                    SpanAttributes.RPC_GRPC_STATUS_CODE,
-                                    grpc.StatusCode.UNKNOWN.value[0]
-                                )
-                                span.set_status(
-                                    Status(
-                                        status_code=StatusCode.ERROR,
-                                        description=f"{type(exc).__name__}: {exc}",
-                                    )
-                                )
-                                span.record_exception(exc)
-                            raise exc
-
-                        finally:
-                            if not response_streaming:
-                                span.end()
+                if response_streaming:
+                    return self._intercept_streaming_response(
+                        behavior, request_or_iterator, context,
+                        request_streaming, handler_call_details.method
+                    )
+
+                return self._intercept_unary_response(
+                    behavior, request_or_iterator, context,
+                    request_streaming, handler_call_details.method
+                )
 
             return telemetry_interceptor
 
@@ -411,6 +357,118 @@ def telemetry_interceptor(request_or_iterator, context):
             telemetry_wrapper
         )
 
+    def _intercept_unary_response(self, behavior, request_or_iterator, context, request_streaming, full_method):
+        with self._set_remote_context(context):
+            attributes = self._create_attributes(context, full_method)
+
+            with self._tracer.start_as_current_span(
+                name=full_method,
+                kind=trace.SpanKind.SERVER,
+                attributes=attributes,
+                end_on_exit=True,
+                record_exception=False,
+                set_status_on_exception=False
+            ) as span:
+
+                try:
+                    # wrap the context
+                    context = _OpenTelemetryServicerContext(context, span)
+
+                    # wrap / log the request (iterator)
+                    if request_streaming:
+                        request_or_iterator = self._log_streaming_request(
+                            request_or_iterator, span, attributes
+                        )
+                    else:
+                        self._log_unary_request(
+                            request_or_iterator, span, attributes
+                        )
+
+                    # call the actual RPC and track the duration
+                    with self._record_duration(attributes, context):
+                        response_or_iterator = behavior(request_or_iterator, context)
+
+                    # log the response (iterator)
+                    self._log_unary_response(
+                        response_or_iterator, span, attributes, context
+                    )
+
+                    return response_or_iterator
+
+                except Exception as exc:
+                    # Bare exceptions are likely to be gRPC aborts, which
+                    # we handle in our context wrapper.
+                    # Here, we're interested in uncaught exceptions.
+                    # pylint:disable=unidiomatic-typecheck
+                    if type(exc) != Exception:
+                        span.set_attribute(
+                            SpanAttributes.RPC_GRPC_STATUS_CODE,
+                            grpc.StatusCode.UNKNOWN.value[0]
+                        )
+                        span.set_status(
+                            Status(
+                                status_code=StatusCode.ERROR,
+                                description=f"{type(exc).__name__}: {exc}",
+                            )
+                        )
+                        span.record_exception(exc)
+                    raise exc
+
+    def _intercept_streaming_response(self, behavior, request_or_iterator, context, request_streaming, full_method):
+        with self._set_remote_context(context):
+            attributes = self._create_attributes(context, full_method)
+
+            with self._tracer.start_as_current_span(
+                name=full_method,
+                kind=trace.SpanKind.SERVER,
+                attributes=attributes,
+                end_on_exit=True,
+                record_exception=False,
+                set_status_on_exception=False
+            ) as span:
+
+                try:
+                    # wrap the context
+                    context = _OpenTelemetryServicerContext(context, span)
+
+                    # wrap / log the request (iterator)
+                    if request_streaming:
+                        request_or_iterator = self._log_streaming_request(
+                            request_or_iterator, span, attributes
+                        )
+                    else:
+                        self._log_unary_request(
+                            request_or_iterator, span, attributes
+                        )
+
+                    # call the actual RPC and track the duration
+                    with self._record_duration(attributes, context):
+                        response_or_iterator = behavior(request_or_iterator, context)
+
+                        # log the response (iterator)
+                        yield from self._log_streaming_response(
+                            response_or_iterator, span, attributes, context
+                        )
+
+                except Exception as exc:
+                    # Bare exceptions are likely to be gRPC aborts, which
+                    # we handle in our context wrapper.
+                    # Here, we're interested in uncaught exceptions.
+                    # pylint:disable=unidiomatic-typecheck
+                    if type(exc) != Exception:
+                        span.set_attribute(
+                            SpanAttributes.RPC_GRPC_STATUS_CODE,
+                            grpc.StatusCode.UNKNOWN.value[0]
+                        )
+                        span.set_status(
+                            Status(
+                                status_code=StatusCode.ERROR,
+                                description=f"{type(exc).__name__}: {exc}",
+                            )
+                        )
+                        span.record_exception(exc)
+                    raise exc
+
     def _log_unary_request(self, request, active_span, attributes):
         message_size_by = request.ByteSize()
         _add_message_event(
@@ -429,7 +487,7 @@ def _log_unary_response(self, response, active_span, attributes, context):
         self._response_size_histogram.record(message_size_by, attributes)
         self._responses_per_rpc_histogram.record(1, attributes)
 
-    def _log_stream_requests(self, request_iterator, active_span, attributes):
+    def _log_streaming_request(self, request_iterator, active_span, attributes):
         req_id = 1
         for req_id, msg in enumerate(request_iterator, start=1):
             message_size_by = msg.ByteSize()
@@ -441,44 +499,20 @@ def _log_stream_requests(self, request_iterator, active_span, attributes):
 
         self._requests_per_rpc_histogram.record(req_id, attributes)
 
-    def _log_stream_responses(self, response_iterator, active_span, attributes, context):
-        with trace.use_span(
-            active_span,
-            end_on_exit=True,
-            record_exception=False,
-            set_status_on_exception=False
-        ):
-            try:
-                res_id = 1
-                for res_id, msg in enumerate(response_iterator, start=1):
-                    message_size_by = msg.ByteSize()
-                    _add_message_event(
-                        active_span, MessageTypeValues.SENT.value, message_size_by, message_id=res_id
-                    )
-                    self._response_size_histogram.record(message_size_by, attributes)
-                    yield msg
-            except Exception as exc:
-                # Bare exceptions are likely to be gRPC aborts, which
-                # we handle in our context wrapper.
-                # Here, we're interested in uncaught exceptions.
-                # pylint:disable=unidiomatic-typecheck
-                if type(exc) != Exception:
-                    active_span.set_attribute(
-                        SpanAttributes.RPC_GRPC_STATUS_CODE,
-                        grpc.StatusCode.UNKNOWN.value[0]
-                    )
-                    active_span.set_status(
-                        Status(
-                            status_code=StatusCode.ERROR,
-                            description=f"{type(exc).__name__}: {exc}",
-                        )
-                    )
-                    active_span.record_exception(exc)
-                raise exc
-            finally:
-                if context._code != grpc.StatusCode.OK:
-                    attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0]
-                self._responses_per_rpc_histogram.record(res_id, attributes)
+    def _log_streaming_response(self, response_iterator, active_span, attributes, context):
+        try:
+            res_id = 1
+            for res_id, msg in enumerate(response_iterator, start=1):
+                message_size_by = msg.ByteSize()
+                _add_message_event(
+                    active_span, MessageTypeValues.SENT.value, message_size_by, message_id=res_id
+                )
+                self._response_size_histogram.record(message_size_by, attributes)
+                yield msg
+        finally:
+            if context._code != grpc.StatusCode.OK:
+                attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0]
+            self._responses_per_rpc_histogram.record(res_id, attributes)
 
     @contextmanager
     def _record_duration(self, attributes, context):
@@ -486,7 +520,7 @@ def _record_duration(self, attributes, context):
         try:
             yield
         finally:
-            duration = max(round((_time_ns() - start) * 1000), 0)
+            duration = max(round((_time_ns() - start) / 1000), 0)
             if context._code != grpc.StatusCode.OK:
                 attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0]
             self._duration_histogram.record(duration, attributes)