diff --git a/CHANGELOG.md b/CHANGELOG.md index db6bb0cf8d..649ba8a224 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#1920](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1920)) - Consolidate instrumentation suppression mechanisms and fix bug in httpx instrumentation ([#2061](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2061)) +- `opentelemetry-instrument-grpc` Fix arity of context.abort for AIO RPCs + ([#2066](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2066)) ### Fixed diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py index 7c20de0cc0..0db4c36edf 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py @@ -12,13 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +import grpc import grpc.aio - -from ._server import ( - OpenTelemetryServerInterceptor, - _OpenTelemetryServicerContext, - _wrap_rpc_behavior, -) +import wrapt + +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace.status import Status, StatusCode + +from ._server import OpenTelemetryServerInterceptor, _wrap_rpc_behavior + + +# pylint:disable=abstract-method +class _OpenTelemetryAioServicerContext(wrapt.ObjectProxy): + def __init__(self, servicer_context, active_span): + super().__init__(servicer_context) + self._self_active_span = active_span + self._self_code = grpc.StatusCode.OK + self._self_details = None + + async def abort(self, code, details="", trailing_metadata=tuple()): + self._self_code = code + self._self_details = details + self._self_active_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0] + ) + self._self_active_span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{code}:{details}", + ) + ) + return await self.__wrapped__.abort(code, details, trailing_metadata) + + def set_code(self, code): + self._self_code = code + details = self._self_details or code.value[1] + self._self_active_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0] + ) + if code != grpc.StatusCode.OK: + self._self_active_span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{code}:{details}", + ) + ) + return self.__wrapped__.set_code(code) + + def set_details(self, details): + self._self_details = details + if self._self_code != grpc.StatusCode.OK: + self._self_active_span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{self._self_code}:{details}", + ) + ) + return self.__wrapped__.set_details(details) class OpenTelemetryAioServerInterceptor( @@ -66,7 +116,7 @@ async def _unary_interceptor(request_or_iterator, context): set_status_on_exception=False, ) as span: # wrap the context - context = _OpenTelemetryServicerContext(context, span) + context = _OpenTelemetryAioServicerContext(context, span) # And now we run the actual RPC. try: @@ -91,7 +141,7 @@ async def _stream_interceptor(request_or_iterator, context): context, set_status_on_exception=False, ) as span: - context = _OpenTelemetryServicerContext(context, span) + context = _OpenTelemetryAioServicerContext(context, span) try: async for response in behavior( diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py index 52391124b7..242295c08c 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py @@ -88,8 +88,11 @@ async def run_with_test_server( channel = grpc.aio.insecure_channel(f"localhost:{port:d}") await server.start() - resp = await runnable(channel) - await server.stop(1000) + + try: + resp = await runnable(channel) + finally: + await server.stop(1000) return resp @@ -514,9 +517,79 @@ async def request(channel): request = Request(client_id=1, request_data=failure_message) msg = request.SerializeToString() - with testcase.assertRaises(Exception): + with testcase.assertRaises(grpc.RpcError) as cm: + await channel.unary_unary(rpc_call)(msg) + + self.assertEqual( + cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION + ) + self.assertEqual(cm.exception.details(), failure_message) + + await run_with_test_server(request, servicer=AbortServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # make sure this span errored, with the right status and detail + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual( + span.status.description, + f"{grpc.StatusCode.FAILED_PRECONDITION}:{failure_message}", + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.FAILED_PRECONDITION.value[ + 0 + ], + }, + ) + + async def test_abort_with_trailing_metadata(self): + """Check that we can catch an abort properly when trailing_metadata provided""" + rpc_call = "/GRPCTestServer/SimpleMethod" + failure_message = "failure message" + + class AbortServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + metadata = (("meta", "data"),) + await context.abort( + grpc.StatusCode.FAILED_PRECONDITION, + failure_message, + trailing_metadata=metadata, + ) + + testcase = self + + async def request(channel): + request = Request(client_id=1, request_data=failure_message) + msg = request.SerializeToString() + + with testcase.assertRaises(grpc.RpcError) as cm: await channel.unary_unary(rpc_call)(msg) + self.assertEqual( + cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION + ) + self.assertEqual(cm.exception.details(), failure_message) + await run_with_test_server(request, servicer=AbortServicer()) spans_list = self.memory_exporter.get_finished_spans()