diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py index 4f9c23659d..c070b99232 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py @@ -21,32 +21,39 @@ Implementation of the service-side open-telemetry interceptor. """ +import copy import logging from contextlib import contextmanager +from typing import Callable, Dict, Iterable, Iterator, Generator, NoReturn, Optional import grpc -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.context import attach, detach +from opentelemetry.instrumentation.grpc._types import Metadata, ProtoMessage, ProtoMessageOrIterator +from opentelemetry.instrumentation.grpc._utilities import _EventMetricRecorder, _MetricKind from opentelemetry.propagate import extract from opentelemetry.semconv.trace import MessageTypeValues, RpcSystemValues, SpanAttributes from opentelemetry.trace.status import Status, StatusCode -from opentelemetry.util._time import _time_ns +from opentelemetry.util.types import Attributes logger = logging.getLogger(__name__) -_MESSAGE = "message" -"""event name of a message.""" - _RPC_USER_AGENT = "rpc.user_agent" """span attribute for RPC user agent.""" # wrap an RPC call # see https://github.com/grpc/grpc/issues/18191 -def _wrap_rpc_behavior(handler, continuation): +def _wrap_rpc_behavior( + handler: Optional[grpc.RpcMethodHandler], + continuation: Callable[ + [ProtoMessageOrIterator, grpc.ServicerContext], + ProtoMessageOrIterator + ] +) -> Optional[grpc.RpcMethodHandler]: if handler is None: return None @@ -72,26 +79,14 @@ def _wrap_rpc_behavior(handler, continuation): ) -def _add_message_event( - span, - message_type, - message_size_by, - 
message_id=1 -): - span.add_event( - _MESSAGE, - { - SpanAttributes.MESSAGE_TYPE: message_type, - SpanAttributes.MESSAGE_ID: message_id, - SpanAttributes.MESSAGE_UNCOMPRESSED_SIZE: message_size_by, - } - ) - - # pylint:disable=abstract-method class _OpenTelemetryServicerContext(grpc.ServicerContext): - def __init__(self, servicer_context, active_span): + def __init__( + self, + servicer_context: grpc.ServicerContext, + active_span: trace.Span + ) -> None: self._servicer_context = servicer_context self._active_span = active_span self._code = grpc.StatusCode.OK @@ -101,13 +96,13 @@ def __init__(self, servicer_context, active_span): def __getattr__(self, attr): return getattr(self._servicer_context, attr) - def is_active(self, *args, **kwargs): - return self._servicer_context.is_active(*args, **kwargs) + # Interface of grpc.RpcContext - def time_remaining(self, *args, **kwargs): - return self._servicer_context.time_remaining(*args, **kwargs) + # pylint: disable=invalid-name + def add_callback(self, fn: Callable[[], None]) -> None: + return self._servicer_context.add_callback(fn) - def cancel(self, *args, **kwargs): + def cancel(self) -> None: self._code = grpc.StatusCode.CANCELLED self._details = grpc.StatusCode.CANCELLED.value[1] self._active_span.set_attribute( @@ -119,42 +114,17 @@ def cancel(self, *args, **kwargs): description=f"{self._code}: {self._details}", ) ) - return self._servicer_context.cancel(*args, **kwargs) - - def add_callback(self, *args, **kwargs): - return self._servicer_context.add_callback(*args, **kwargs) - - def disable_next_message_compression(self): - return self._service_context.disable_next_message_compression() - - def invocation_metadata(self): - return self._servicer_context.invocation_metadata() - - def peer(self): - return self._servicer_context.peer() - - def peer_identities(self): - return self._servicer_context.peer_identities() + return self._servicer_context.cancel() - def peer_identity_key(self): - return 
self._servicer_context.peer_identity_key() - - def auth_context(self): - return self._servicer_context.auth_context() - - def set_compression(self, compression): - return self._servicer_context.set_compression(compression) - - def send_initial_metadata(self, *args, **kwargs): - return self._servicer_context.send_initial_metadata(*args, **kwargs) + def is_active(self) -> bool: + return self._servicer_context.is_active() - def set_trailing_metadata(self, *args, **kwargs): - return self._servicer_context.set_trailing_metadata(*args, **kwargs) + def time_remaining(self) -> Optional[float]: + return self._servicer_context.time_remaining() - def trailing_metadata(self): - return self._servicer_context.trailing_metadata() + # Interface of grpc.ServicerContext - def abort(self, code, details): + def abort(self, code: grpc.StatusCode, details: str) -> NoReturn: if not hasattr(self._servicer_context, "abort"): raise RuntimeError( "abort() is not supported with the installed version of grpcio" @@ -172,21 +142,51 @@ def abort(self, code, details): ) return self._servicer_context.abort(code, details) - def abort_with_status(self, status): + def abort_with_status(self, status: grpc.Status) -> NoReturn: if not hasattr(self._servicer_context, "abort_with_status"): raise RuntimeError( - "abort_with_status() is not supported with the installed version of grpcio" + "abort_with_status() is not supported with the installed " + "version of grpcio" ) return self._servicer_context.abort_with_status(status) - def code(self): + def auth_context(self) -> Dict[str, Iterable[bytes]]: + return self._servicer_context.auth_context() + + def code(self) -> grpc.StatusCode: if not hasattr(self._servicer_context, "code"): raise RuntimeError( "code() is not supported with the installed version of grpcio" ) return self._servicer_context.code() - def set_code(self, code): + def details(self) -> str: + if not hasattr(self._servicer_context, "details"): + raise RuntimeError( + "details() is not supported 
with the installed version of " + "grpcio" + ) + return self._servicer_context.details() + + def disable_next_message_compression(self) -> None: + return self._servicer_context.disable_next_message_compression() + + def invocation_metadata(self) -> Metadata: + return self._servicer_context.invocation_metadata() + + def peer(self) -> str: + return self._servicer_context.peer() + + def peer_identities(self) -> Optional[Iterable[bytes]]: + return self._servicer_context.peer_identities() + + def peer_identity_key(self) -> Optional[str]: + return self._servicer_context.peer_identity_key() + + def send_initial_metadata(self, initial_metadata: Metadata) -> None: + return self._servicer_context.send_initial_metadata(initial_metadata) + + def set_code(self, code: grpc.StatusCode) -> None: self._code = code # use details if we already have it, otherwise the status description details = self._details or code.value[1] @@ -202,15 +202,10 @@ def set_code(self, code): ) return self._servicer_context.set_code(code) - def details(self): - if not hasattr(self._servicer_context, "details"): - raise RuntimeError( - "details() is not supported with the installed version of " - "grpcio" - ) - return self._servicer_context.details() + def set_compression(self, compression: grpc.Compression) -> None: + return self._servicer_context.set_compression(compression) - def set_details(self, details): + def set_details(self, details: str) -> None: self._details = details if self._code != grpc.StatusCode.OK: self._active_span.set_status( @@ -221,11 +216,20 @@ def set_details(self, details): ) return self._servicer_context.set_details(details) + def set_trailing_metadata(self, trailing_metadata: Metadata) -> None: + return self._servicer_context.set_trailing_metadata(trailing_metadata) + + def trailing_metadata(self) -> Metadata: + return self._servicer_context.trailing_metadata() + # pylint:disable=abstract-method # pylint:disable=no-self-use # pylint:disable=unused-argument -class
OpenTelemetryServerInterceptor(grpc.ServerInterceptor): +class OpenTelemetryServerInterceptor( + _EventMetricRecorder, + grpc.ServerInterceptor +): """ A gRPC server interceptor, to add OpenTelemetry. @@ -245,38 +249,19 @@ class OpenTelemetryServerInterceptor(grpc.ServerInterceptor): """ - def __init__(self, meter, tracer): - self._meter = meter + def __init__( + self, + meter: metrics.Meter, + tracer: trace.Tracer + ) -> None: + super().__init__(meter, _MetricKind.SERVER) self._tracer = tracer - self._duration_histogram = self._meter.create_histogram( - name="rpc.server.duration", - unit="ms", - description="measures duration of inbound RPC", - ) - self._request_size_histogram = self._meter.create_histogram( - name="rpc.server.request.size", - unit="By", - description="measures size of RPC request messages (uncompressed)", - ) - self._response_size_histogram = self._meter.create_histogram( - name="rpc.server.response.size", - unit="By", - description="measures size of RPC response messages (uncompressed)", - ) - self._requests_per_rpc_histogram = self._meter.create_histogram( - name="rpc.server.requests_per_rpc", - unit="requests", - description="measures the number of messages received per RPC. Should be 1 for all non-streaming RPCs", - ) - self._responses_per_rpc_histogram = self._meter.create_histogram( - name="rpc.server.responses_per_rpc", - unit="responses", - description="measures the number of messages sent per RPC. 
Should be 1 for all non-streaming RPCs", - ) - @contextmanager - def _set_remote_context(self, context): + def _set_remote_context( + self, + context: grpc.ServicerContext + ) -> Generator[None, None, None]: metadata = context.invocation_metadata() if metadata: md_dict = {md.key: md.value for md in metadata} @@ -289,22 +274,19 @@ def _set_remote_context(self, context): else: yield - def _create_attributes(self, context, full_method): + def _create_attributes( + self, + context: grpc.ServicerContext, + full_method: str + ) -> Attributes: # standard attributes + service, method = full_method.lstrip("/").split("/", 1) attributes = { SpanAttributes.RPC_SYSTEM: RpcSystemValues.GRPC.value, - SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0], + SpanAttributes.RPC_SERVICE: service, + SpanAttributes.RPC_METHOD: method } - # add service and method attributes - service, method = full_method.lstrip("/").split("/", 1) - attributes.update( - { - SpanAttributes.RPC_SERVICE: service, - SpanAttributes.RPC_METHOD: method, - } - ) - # add some attributes from the metadata metadata = dict(context.invocation_metadata()) if "user-agent" in metadata: @@ -336,18 +318,48 @@ def _create_attributes(self, context, full_method): return attributes - def intercept_service(self, continuation, handler_call_details): - def telemetry_wrapper(behavior, request_streaming, response_streaming): - def telemetry_interceptor(request_or_iterator, context): - if response_streaming: - return self._intercept_streaming_response( + def intercept_service( + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Optional[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails + ) -> Optional[grpc.RpcMethodHandler]: + + def telemetry_wrapper( + behavior: Callable[ + [ProtoMessageOrIterator, grpc.ServicerContext], + ProtoMessageOrIterator + ], + request_streaming: bool, + response_streaming: bool + ) -> Callable[ + [ProtoMessageOrIterator, grpc.ServicerContext], + 
ProtoMessageOrIterator + ]: + + def telemetry_interceptor( + request_or_iterator: ProtoMessageOrIterator, + context: grpc.ServicerContext + ) -> ProtoMessageOrIterator: + if request_streaming and response_streaming: + return self.intercept_stream_stream( behavior, request_or_iterator, context, - request_streaming, handler_call_details.method + handler_call_details.method ) - - return self._intercept_unary_response( + if not request_streaming and response_streaming: + return self.intercept_unary_stream( + behavior, request_or_iterator, context, + handler_call_details.method + ) + if request_streaming and not response_streaming: + return self.intercept_stream_unary( + behavior, request_or_iterator, context, + handler_call_details.method + ) + return self.intercept_unary_unary( behavior, request_or_iterator, context, - request_streaming, handler_call_details.method + handler_call_details.method ) return telemetry_interceptor @@ -357,170 +369,281 @@ def telemetry_interceptor(request_or_iterator, context): telemetry_wrapper ) - def _intercept_unary_response(self, behavior, request_or_iterator, context, request_streaming, full_method): + def intercept_unary_unary( + self, + continuation: Callable[ + [ProtoMessage, grpc.ServicerContext], ProtoMessage + ], + request: ProtoMessage, + context: grpc.ServicerContext, + full_method: str + ) -> ProtoMessage: with self._set_remote_context(context): - attributes = self._create_attributes(context, full_method) + metric_attributes = self._create_attributes(context, full_method) + span_attributes = copy.deepcopy(metric_attributes) + span_attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = ( + grpc.StatusCode.OK.value[0] + ) with self._tracer.start_as_current_span( name=full_method, kind=trace.SpanKind.SERVER, - attributes=attributes, + attributes=span_attributes, end_on_exit=True, record_exception=False, set_status_on_exception=False ) as span: + with self._record_duration_manager(metric_attributes, context): + + try: + # wrap the 
context + context = _OpenTelemetryServicerContext(context, span) + + # record the request + self._record_unary_request( + span, + request, + MessageTypeValues.RECEIVED, + metric_attributes + ) - try: - # wrap the context - context = _OpenTelemetryServicerContext(context, span) + # call the actual RPC + response = continuation(request, context) - # wrap / log the request (iterator) - if request_streaming: - request_or_iterator = self._log_streaming_request( - request_or_iterator, span, attributes - ) - else: - self._log_unary_request( - request_or_iterator, span, attributes + # record the response + self._record_unary_response( + span, + response, + MessageTypeValues.SENT, + metric_attributes ) - # call the actual RPC and track the duration - with self._record_duration(attributes, context): - response_or_iterator = behavior(request_or_iterator, context) + return response + + except Exception as exc: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. 
+ # pylint:disable=unidiomatic-typecheck + if type(exc) != Exception: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + grpc.StatusCode.UNKNOWN.value[0] + ) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + raise exc + + def intercept_unary_stream( + self, + continuation: Callable[ + [ProtoMessage, grpc.ServicerContext], Iterator[ProtoMessage] + ], + request: ProtoMessage, + context: grpc.ServicerContext, + full_method: str + ) -> Iterator[ProtoMessage]: + with self._set_remote_context(context): + metric_attributes = self._create_attributes(context, full_method) + span_attributes = copy.deepcopy(metric_attributes) + span_attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = ( + grpc.StatusCode.OK.value[0] + ) - # log the response (iterator) - self._log_unary_response( - response_or_iterator, span, attributes, context - ) + with self._tracer.start_as_current_span( + name=full_method, + kind=trace.SpanKind.SERVER, + attributes=span_attributes, + end_on_exit=True, + record_exception=False, + set_status_on_exception=False + ) as span: - return response_or_iterator - - except Exception as exc: - # Bare exceptions are likely to be gRPC aborts, which - # we handle in our context wrapper. - # Here, we're interested in uncaught exceptions. 
- # pylint:disable=unidiomatic-typecheck - if type(exc) != Exception: - span.set_attribute( - SpanAttributes.RPC_GRPC_STATUS_CODE, - grpc.StatusCode.UNKNOWN.value[0] + with self._record_duration_manager(metric_attributes, context): + try: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # record the request + self._record_unary_request( + span, + request, + MessageTypeValues.RECEIVED, + metric_attributes ) - span.set_status( - Status( - status_code=StatusCode.ERROR, - description=f"{type(exc).__name__}: {exc}", - ) + + # call the actual RPC + response_iterator = continuation(request, context) + + # wrap the response iterator with a recorder + yield from self._record_streaming_response( + span, + response_iterator, + MessageTypeValues.SENT, + metric_attributes ) - span.record_exception(exc) - raise exc - def _intercept_streaming_response(self, behavior, request_or_iterator, context, request_streaming, full_method): + except Exception as exc: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. 
+ # pylint:disable=unidiomatic-typecheck + if type(exc) != Exception: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + grpc.StatusCode.UNKNOWN.value[0] + ) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + raise exc + + def intercept_stream_unary( + self, + continuation: Callable[ + [Iterator[ProtoMessage], grpc.ServicerContext], ProtoMessage + ], + request_iterator: Iterator[ProtoMessage], + context: grpc.ServicerContext, + full_method: str + ) -> ProtoMessage: with self._set_remote_context(context): - attributes = self._create_attributes(context, full_method) + metric_attributes = self._create_attributes(context, full_method) + span_attributes = copy.deepcopy(metric_attributes) + span_attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = ( + grpc.StatusCode.OK.value[0] + ) with self._tracer.start_as_current_span( name=full_method, kind=trace.SpanKind.SERVER, - attributes=attributes, + attributes=span_attributes, end_on_exit=True, record_exception=False, set_status_on_exception=False ) as span: - - try: - # wrap the context - context = _OpenTelemetryServicerContext(context, span) - - # wrap / log the request (iterator) - if request_streaming: - request_or_iterator = self._log_streaming_request( - request_or_iterator, span, attributes - ) - else: - self._log_unary_request( - request_or_iterator, span, attributes + with self._record_duration_manager(metric_attributes, context): + + try: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # wrap the request iterator with a recorder + request_iterator = self._record_streaming_request( + span, + request_iterator, + MessageTypeValues.RECEIVED, + metric_attributes ) - # call the actual RPC and track the duration - with self._record_duration(attributes, context): - response_or_iterator = behavior(request_or_iterator, context) + # call the actual RPC + response = 
continuation(request_iterator, context) - # log the response (iterator) - yield from self._log_streaming_response( - response_or_iterator, span, attributes, context + # record the response + self._record_unary_response( + span, + response, + MessageTypeValues.SENT, + metric_attributes ) - except Exception as exc: - # Bare exceptions are likely to be gRPC aborts, which - # we handle in our context wrapper. - # Here, we're interested in uncaught exceptions. - # pylint:disable=unidiomatic-typecheck - if type(exc) != Exception: - span.set_attribute( - SpanAttributes.RPC_GRPC_STATUS_CODE, - grpc.StatusCode.UNKNOWN.value[0] - ) - span.set_status( - Status( - status_code=StatusCode.ERROR, - description=f"{type(exc).__name__}: {exc}", + return response + + except Exception as exc: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. + # pylint:disable=unidiomatic-typecheck + if type(exc) != Exception: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + grpc.StatusCode.UNKNOWN.value[0] ) - ) - span.record_exception(exc) - raise exc + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + raise exc + + def intercept_stream_stream( + self, + continuation: Callable[ + [Iterator[ProtoMessage], grpc.ServicerContext], + Iterator[ProtoMessage] + ], + request_iterator: Iterator[ProtoMessage], + context: grpc.ServicerContext, + full_method: str + ) -> Iterator[ProtoMessage]: + with self._set_remote_context(context): + metric_attributes = self._create_attributes(context, full_method) + span_attributes = copy.deepcopy(metric_attributes) + span_attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = ( + grpc.StatusCode.OK.value[0] + ) - def _log_unary_request(self, request, active_span, attributes): - message_size_by = request.ByteSize() - _add_message_event( - active_span, 
MessageTypeValues.RECEIVED.value, message_size_by - ) - self._request_size_histogram.record(message_size_by, attributes) - self._requests_per_rpc_histogram.record(1, attributes) + with self._tracer.start_as_current_span( + name=full_method, + kind=trace.SpanKind.SERVER, + attributes=span_attributes, + end_on_exit=True, + record_exception=False, + set_status_on_exception=False + ) as span: - def _log_unary_response(self, response, active_span, attributes, context): - message_size_by = response.ByteSize() - _add_message_event( - active_span, MessageTypeValues.SENT.value, message_size_by - ) - if context._code != grpc.StatusCode.OK: - attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0] - self._response_size_histogram.record(message_size_by, attributes) - self._responses_per_rpc_histogram.record(1, attributes) - - def _log_streaming_request(self, request_iterator, active_span, attributes): - req_id = 1 - for req_id, msg in enumerate(request_iterator, start=1): - message_size_by = msg.ByteSize() - _add_message_event( - active_span, MessageTypeValues.RECEIVED.value, message_size_by, message_id=req_id - ) - self._request_size_histogram.record(message_size_by, attributes) - yield msg + with self._record_duration_manager(metric_attributes, context): + try: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # wrap the request iterator with a recorder + request_iterator = self._record_streaming_request( + span, + request_iterator, + MessageTypeValues.RECEIVED, + metric_attributes + ) - self._requests_per_rpc_histogram.record(req_id, attributes) + # call the actual RPC + response_iterator = continuation( + request_iterator, context + ) - def _log_streaming_response(self, response_iterator, active_span, attributes, context): - try: - res_id = 1 - for res_id, msg in enumerate(response_iterator, start=1): - message_size_by = msg.ByteSize() - _add_message_event( - active_span, MessageTypeValues.SENT.value, message_size_by, 
message_id=res_id - ) - self._response_size_histogram.record(message_size_by, attributes) - yield msg - finally: - if context._code != grpc.StatusCode.OK: - attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0] - self._responses_per_rpc_histogram.record(res_id, attributes) + # wrap the response iterator with a recorder + yield from self._record_streaming_response( + span, + response_iterator, + MessageTypeValues.SENT, + metric_attributes + ) - @contextmanager - def _record_duration(self, attributes, context): - start = _time_ns() - try: - yield - finally: - duration = max(round((_time_ns() - start) / 1000), 0) - if context._code != grpc.StatusCode.OK: - attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0] - self._duration_histogram.record(duration, attributes) + except Exception as exc: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. + # pylint:disable=unidiomatic-typecheck + if type(exc) != Exception: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + grpc.StatusCode.UNKNOWN.value[0] + ) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + raise exc