diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index e887e7323f..7ef5aa2e6d 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -268,7 +268,7 @@ def _traced_aggregate(func, instance, args, kwargs): _set_span_attribute( span, "redis.commands.aggregate.query", - query.query_string(), + query._query, ) response = func(*args, **kwargs) _set_span_attribute( diff --git a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py index 481b8d21c8..96510797cc 100644 --- a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py +++ b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py @@ -18,6 +18,15 @@ import redis import redis.asyncio +from redis.exceptions import ResponseError +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.aggregation import AggregateRequest +from redis.commands.search.query import Query +from redis.commands.search.field import ( + TextField, + VectorField, +) + from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor from opentelemetry.semconv.trace import SpanAttributes @@ -614,3 +623,72 @@ def test_get(self): self.assertEqual( span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?" ) + + +class TestRedisearchInstrument(TestBase): + def setUp(self): + super().setUp() + self.redis_client = redis.Redis(port=6379) + self.redis_client.flushall() + self.embedding_dim = 256 + RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) + self.prepare_data() + self.create_index() + + def tearDown(self): + RedisInstrumentor().uninstrument() + super().tearDown() + + def prepare_data(self): + try: + self.redis_client.ft("idx:test_vss").dropindex(True) + except ResponseError: + print("No such index") + item = {"name": "test", + "value": "test_value", + "embeddings": [0.1] * 256} + pipeline = self.redis_client.pipeline() + pipeline.json().set(f"test:001", "$", item) + res = pipeline.execute() + assert False not in res + + def create_index(self): + schema = ( + TextField("$.name", no_stem=True, as_name="name"), + TextField("$.value", no_stem=True, as_name="value"), + VectorField("$.embeddings", + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": self.embedding_dim, + "DISTANCE_METRIC": "COSINE", + }, + as_name="vector",), + ) + definition = IndexDefinition(prefix=["test:"], index_type=IndexType.JSON) + res = self.redis_client.ft("idx:test_vss").create_index(fields=schema, definition=definition) + assert "OK" in str(res) + + def test_redis_create_index(self): + spans = self.memory_exporter.get_finished_spans() + span = next(span for span in spans if span.name == "redis.create_index") + assert "redis.create_index.definition" in span.attributes + assert "redis.create_index.fields" in span.attributes + + def test_redis_aggregate(self): + query = "*" + self.redis_client.ft("idx:test_vss").aggregate(AggregateRequest(query).load()) + spans = self.memory_exporter.get_finished_spans() + span = next(span for span in spans if span.name == "redis.aggregate") + assert span.attributes.get("redis.commands.aggregate.query") == query + assert "redis.commands.aggregate.results" in span.attributes + + def test_redis_query(self): + query = "@name:test" + res = self.redis_client.ft("idx:test_vss").search(Query(query)) + + spans = self.memory_exporter.get_finished_spans() + span = next(span for span in spans if span.name == "redis.search") + + assert span.attributes.get("redis.commands.search.query") == query + assert span.attributes.get("redis.commands.search.total") == 1