diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py index 1bcaeb776d..4f9c23659d 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py @@ -252,7 +252,7 @@ def __init__(self, meter, tracer): self._duration_histogram = self._meter.create_histogram( name="rpc.server.duration", unit="ms", - description="measures the duration of the inbound rpc", + description="measures duration of inbound RPC", ) self._request_size_histogram = self._meter.create_histogram( name="rpc.server.request.size", @@ -339,70 +339,16 @@ def _create_attributes(self, context, full_method): def intercept_service(self, continuation, handler_call_details): def telemetry_wrapper(behavior, request_streaming, response_streaming): def telemetry_interceptor(request_or_iterator, context): - with self._set_remote_context(context): - attributes = self._create_attributes(context, handler_call_details.method) - - with self._tracer.start_as_current_span( - name=handler_call_details.method, - kind=trace.SpanKind.SERVER, - attributes=attributes, - end_on_exit=False, - record_exception=False, - set_status_on_exception=False - ) as span: - - try: - # wrap the context - context = _OpenTelemetryServicerContext(context, span) - - # wrap / log the request (iterator) - if request_streaming: - request_or_iterator = self._log_stream_requests( - request_or_iterator, span, attributes - ) - else: - self._log_unary_request( - request_or_iterator, span, attributes - ) - - # call the actual RPC and track the duration - with self._record_duration(attributes, context): - response_or_iterator = behavior(request_or_iterator, context) - - # wrap / log the response (iterator) - if response_streaming: - response_or_iterator = self._log_stream_responses( - response_or_iterator, span, attributes, context - ) - else: - self._log_unary_response( - response_or_iterator, span, attributes, context - ) - - return response_or_iterator - - except Exception as exc: - # Bare exceptions are likely to be gRPC aborts, which - # we handle in our context wrapper. - # Here, we're interested in uncaught exceptions. - # pylint:disable=unidiomatic-typecheck - if type(exc) != Exception: - span.set_attribute( - SpanAttributes.RPC_GRPC_STATUS_CODE, - grpc.StatusCode.UNKNOWN.value[0] - ) - span.set_status( - Status( - status_code=StatusCode.ERROR, - description=f"{type(exc).__name__}: {exc}", - ) - ) - span.record_exception(exc) - raise exc - - finally: - if not response_streaming: - span.end() + if response_streaming: + return self._intercept_streaming_response( + behavior, request_or_iterator, context, + request_streaming, handler_call_details.method + ) + + return self._intercept_unary_response( + behavior, request_or_iterator, context, + request_streaming, handler_call_details.method + ) return telemetry_interceptor @@ -411,6 +357,118 @@ def telemetry_interceptor(request_or_iterator, context): telemetry_wrapper ) + def _intercept_unary_response(self, behavior, request_or_iterator, context, request_streaming, full_method): + with self._set_remote_context(context): + attributes = self._create_attributes(context, full_method) + + with self._tracer.start_as_current_span( + name=full_method, + kind=trace.SpanKind.SERVER, + attributes=attributes, + end_on_exit=True, + record_exception=False, + set_status_on_exception=False + ) as span: + + try: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # wrap / log the request (iterator) + if request_streaming: + request_or_iterator = self._log_streaming_request( + request_or_iterator, span, attributes + ) + else: + self._log_unary_request( + request_or_iterator, span, attributes + ) + + # call the actual RPC and track the duration + with self._record_duration(attributes, context): + response_or_iterator = behavior(request_or_iterator, context) + + # log the response (iterator) + self._log_unary_response( + response_or_iterator, span, attributes, context + ) + + return response_or_iterator + + except Exception as exc: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. + # pylint:disable=unidiomatic-typecheck + if type(exc) != Exception: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + grpc.StatusCode.UNKNOWN.value[0] + ) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + raise exc + + def _intercept_streaming_response(self, behavior, request_or_iterator, context, request_streaming, full_method): + with self._set_remote_context(context): + attributes = self._create_attributes(context, full_method) + + with self._tracer.start_as_current_span( + name=full_method, + kind=trace.SpanKind.SERVER, + attributes=attributes, + end_on_exit=True, + record_exception=False, + set_status_on_exception=False + ) as span: + + try: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # wrap / log the request (iterator) + if request_streaming: + request_or_iterator = self._log_streaming_request( + request_or_iterator, span, attributes + ) + else: + self._log_unary_request( + request_or_iterator, span, attributes + ) + + # call the actual RPC and track the duration + with self._record_duration(attributes, context): + response_or_iterator = behavior(request_or_iterator, context) + + # log the response (iterator) + yield from self._log_streaming_response( + response_or_iterator, span, attributes, context + ) + + except Exception as exc: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. + # pylint:disable=unidiomatic-typecheck + if type(exc) != Exception: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + grpc.StatusCode.UNKNOWN.value[0] + ) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + raise exc + def _log_unary_request(self, request, active_span, attributes): message_size_by = request.ByteSize() _add_message_event( @@ -429,7 +487,7 @@ def _log_unary_response(self, response, active_span, attributes, context): self._response_size_histogram.record(message_size_by, attributes) self._responses_per_rpc_histogram.record(1, attributes) - def _log_stream_requests(self, request_iterator, active_span, attributes): + def _log_streaming_request(self, request_iterator, active_span, attributes): req_id = 1 for req_id, msg in enumerate(request_iterator, start=1): message_size_by = msg.ByteSize() @@ -441,44 +499,20 @@ def _log_stream_requests(self, request_iterator, active_span, attributes): self._requests_per_rpc_histogram.record(req_id, attributes) - def _log_stream_responses(self, response_iterator, active_span, attributes, context): - with trace.use_span( - active_span, - end_on_exit=True, - record_exception=False, - set_status_on_exception=False - ): - try: - res_id = 1 - for res_id, msg in enumerate(response_iterator, start=1): - message_size_by = msg.ByteSize() - _add_message_event( - active_span, MessageTypeValues.SENT.value, message_size_by, message_id=res_id - ) - self._response_size_histogram.record(message_size_by, attributes) - yield msg - except Exception as exc: - # Bare exceptions are likely to be gRPC aborts, which - # we handle in our context wrapper. - # Here, we're interested in uncaught exceptions. - # pylint:disable=unidiomatic-typecheck - if type(exc) != Exception: - active_span.set_attribute( - SpanAttributes.RPC_GRPC_STATUS_CODE, - grpc.StatusCode.UNKNOWN.value[0] - ) - active_span.set_status( - Status( - status_code=StatusCode.ERROR, - description=f"{type(exc).__name__}: {exc}", - ) - ) - active_span.record_exception(exc) - raise exc - finally: - if context._code != grpc.StatusCode.OK: - attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0] - self._responses_per_rpc_histogram.record(res_id, attributes) + def _log_streaming_response(self, response_iterator, active_span, attributes, context): + try: + res_id = 1 + for res_id, msg in enumerate(response_iterator, start=1): + message_size_by = msg.ByteSize() + _add_message_event( + active_span, MessageTypeValues.SENT.value, message_size_by, message_id=res_id + ) + self._response_size_histogram.record(message_size_by, attributes) + yield msg + finally: + if context._code != grpc.StatusCode.OK: + attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0] + self._responses_per_rpc_histogram.record(res_id, attributes) @contextmanager def _record_duration(self, attributes, context): @@ -486,7 +520,7 @@ def _record_duration(self, attributes, context): try: yield finally: - duration = max(round((_time_ns() - start) * 1000), 0) + duration = max(round((_time_ns() - start) / 1000), 0) if context._code != grpc.StatusCode.OK: attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0] self._duration_histogram.record(duration, attributes)