From 8bf9211b97a4e7f7998db152ad431e9b3c282a8d Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Wed, 17 Jul 2024 19:01:56 -0400 Subject: [PATCH] Re-raise `redis.WatchError`s when they occur In #2668, we started to avoid having the `redis.WatchError` mark the span as having failed, but we were inadvertently suppressing that exception from the calling code. `redis.WatchError` is used for flow-of-control concurrency in many applications, and applications depend on catching the error to handle concurrent changes to keys during redis pipelines. This re-raises the `WatchError` to keep the instrumentation transparent to the application. Fixes #2639 --- .../instrumentation/redis/__init__.py | 21 +++++++++++++++---- .../tests/test_redis.py | 8 ++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index c5f19fc736..08337c2d4a 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -203,6 +203,8 @@ def _traced_execute_pipeline(func, instance, args, kwargs): span_name, ) = _build_span_meta_data_for_pipeline(instance) + exception = None + with tracer.start_as_current_span( span_name, kind=trace.SpanKind.CLIENT ) as span: @@ -216,13 +218,17 @@ def _traced_execute_pipeline(func, instance, args, kwargs): response = None try: response = func(*args, **kwargs) - except redis.WatchError: + except redis.WatchError as watch_exception: span.set_status(StatusCode.UNSET) + exception = watch_exception if callable(response_hook): response_hook(span, instance, response) - return response + if exception: + raise exception + + return response pipeline_class = ( "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" @@ -279,6 +285,8 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs): span_name, ) = _build_span_meta_data_for_pipeline(instance) + exception = None + with tracer.start_as_current_span( span_name, kind=trace.SpanKind.CLIENT ) as span: @@ -292,12 +300,17 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs): response = None try: response = await func(*args, **kwargs) - except redis.WatchError: + except redis.WatchError as watch_exception: span.set_status(StatusCode.UNSET) + exception = watch_exception if callable(response_hook): response_hook(span, instance, response) - return response + + if exception: + raise exception + + return response if redis.VERSION >= _REDIS_ASYNCIO_VERSION: wrap_function_wrapper( diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 23d21b6e5a..c436589adb 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -359,7 +359,7 @@ def test_response_error(self): def test_watch_error_sync(self): def redis_operations(): - try: + with pytest.raises(WatchError): redis_client = fakeredis.FakeStrictRedis() pipe = redis_client.pipeline(transaction=True) pipe.watch("a") @@ -367,8 +367,6 @@ def redis_operations(): pipe.multi() pipe.set("a", "1") pipe.execute() - except WatchError: - pass redis_operations() @@ -400,7 +398,7 @@ def tearDown(self): @pytest.mark.asyncio async def test_watch_error_async(self): async def redis_operations(): - try: + with pytest.raises(WatchError): redis_client = FakeRedis() async with redis_client.pipeline(transaction=False) as pipe: await pipe.watch("a") @@ -408,8 +406,6 @@ async def redis_operations(): pipe.multi() await pipe.set("a", "1") await pipe.execute() - except WatchError: - pass await redis_operations()