diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bd207b86c..8c97910c6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- `opentelemetry-instrumentation-redis` Add `sanitize_query` config option to allow query sanitization. Enabled by default. + ([#1572](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1572)) + ## Fixed - Fix aiopg instrumentation to work with aiopg < 2.0.0 diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index b85c2336b0..e49ee8f727 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -64,6 +64,8 @@ async def redis_get(): response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None +sanitize_query True(Default) or False - enable the Redis query sanitization + for example: .. 
code: python @@ -117,6 +119,9 @@ def response_hook(span, instance, response): _ResponseHookT = typing.Optional[ typing.Callable[[Span, redis.connection.Connection, Any], None] ] +_DbStatementSerializerT = typing.Optional[ + typing.Callable[[Any], str] +] _REDIS_ASYNCIO_VERSION = (4, 2, 0) if redis.VERSION >= _REDIS_ASYNCIO_VERSION: @@ -139,9 +144,11 @@ def _instrument( tracer, request_hook: _RequestHookT = None, response_hook: _ResponseHookT = None, + sanitize_query: bool = True, ): def _traced_execute_command(func, instance, args, kwargs): - query = _format_command_args(args) + query = _format_command_args(args, sanitize_query) + if len(args) > 0 and args[0]: name = args[0] else: @@ -169,7 +176,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs): ) cmds = [ - _format_command_args(c.args if hasattr(c, "args") else c[0]) + _format_command_args(c.args if hasattr(c, "args") else c[0], sanitize_query) for c in command_stack ] resource = "\n".join(cmds) @@ -281,6 +288,7 @@ def _instrument(self, **kwargs): tracer, request_hook=kwargs.get("request_hook"), response_hook=kwargs.get("response_hook"), + sanitize_query=kwargs.get("sanitize_query", True), ) def _uninstrument(self, **kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py index fdc5cb5fd6..0cfe284371 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py @@ -48,25 +48,29 @@ def _extract_conn_attributes(conn_kwargs): return attributes -def _format_command_args(args): +def _format_command_args(args, sanitize_query): """Format command arguments and trim them as needed""" - value_max_len = 100 - value_too_long_mark = "..." 
- cmd_max_len = 1000 - length = 0 - out = [] - for arg in args: - cmd = str(arg) + if sanitize_query: + # Sanitized query format: "COMMAND ? ?" + out = [str(args[0])] + ["?"] * (len(args) - 1) + else: + value_max_len = 100 + value_too_long_mark = "..." + cmd_max_len = 1000 + length = 0 + out = [] + for arg in args: + cmd = str(arg) - if len(cmd) > value_max_len: - cmd = cmd[:value_max_len] + value_too_long_mark + if len(cmd) > value_max_len: + cmd = cmd[:value_max_len] + value_too_long_mark - if length + len(cmd) > cmd_max_len: - prefix = cmd[: cmd_max_len - length] - out.append(f"{prefix}{value_too_long_mark}") - break + if length + len(cmd) > cmd_max_len: + prefix = cmd[: cmd_max_len - length] + out.append(f"{prefix}{value_too_long_mark}") + break - out.append(cmd) - length += len(cmd) + out.append(cmd) + length += len(cmd) return " ".join(out) diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 3abcb516ff..2816fbaba2 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -148,6 +148,45 @@ def request_hook(span, conn, args, kwargs): span = spans[0] self.assertEqual(span.attributes.get(custom_attribute_name), "GET") + def test_query_sanitizer_enabled(self): + redis_client = redis.Redis() + connection = redis.connection.Connection() + redis_client.connection = connection + + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider + ) + + with mock.patch.object(redis_client, "connection"): + redis_client.get("key") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual(span.attributes.get("db.statement"), "GET ?") + + def test_query_sanitizer_disabled(self): + redis_client = redis.Redis() + connection = 
redis.connection.Connection() + redis_client.connection = connection + + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, + sanitize_query=False, + ) + + with mock.patch.object(redis_client, "connection"): + redis_client.get("key") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual(span.attributes.get("db.statement"), "GET key") + def test_no_op_tracer_provider(self): RedisInstrumentor().uninstrument() tracer_provider = trace.NoOpTracerProvider