Skip to content

Commit

Permalink
Add type hints to Redis
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 13, 2024
1 parent 7660a00 commit 85dbb70
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ def response_hook(span, instance, response):
---
"""

import typing
from typing import Any, Collection
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Collection

import redis
from wrapt import wrap_function_wrapper
Expand All @@ -109,18 +110,43 @@ def response_hook(span, instance, response):
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, StatusCode
from opentelemetry.trace import Span, StatusCode, Tracer

if TYPE_CHECKING:
from typing import Awaitable, TypeAlias, TypeVar

import redis.asyncio.client
import redis.asyncio.cluster
import redis.client
import redis.cluster
import redis.connection

_RequestHookT: TypeAlias = "Callable[[Span, redis.connection.Connection, list[Any], dict[str, Any]], None]"
_ResponseHookT: TypeAlias = (
"Callable[[Span, redis.connection.Connection, Any], None]"
)

AsyncPipelineInstance = TypeVar(
"AsyncPipelineInstance",
redis.asyncio.client.Pipeline,
redis.asyncio.cluster.ClusterPipeline,
)
AsyncRedisInstance = TypeVar(
"AsyncRedisInstance", redis.asyncio.Redis, redis.asyncio.RedisCluster
)
PipelineInstance = TypeVar(
"PipelineInstance",
redis.client.Pipeline,
redis.cluster.ClusterPipeline,
)
RedisInstance = TypeVar(
"RedisInstance", redis.client.Redis, redis.cluster.RedisCluster
)
R = TypeVar("R")


_DEFAULT_SERVICE = "redis"

_RequestHookT = typing.Optional[
typing.Callable[
[Span, redis.connection.Connection, typing.List, typing.Dict], None
]
]
_ResponseHookT = typing.Optional[
typing.Callable[[Span, redis.connection.Connection, Any], None]
]

_REDIS_ASYNCIO_VERSION = (4, 2, 0)
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
Expand All @@ -132,7 +158,9 @@ def response_hook(span, instance, response):
_FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"]


def _set_connection_attributes(span, conn):
def _set_connection_attributes(
span: Span, conn: RedisInstance | AsyncRedisInstance
) -> None:
if not span.is_recording() or not hasattr(conn, "connection_pool"):
return
for key, value in _extract_conn_attributes(
Expand All @@ -141,7 +169,9 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _build_span_name(instance, cmd_args):
def _build_span_name(
instance: RedisInstance | AsyncRedisInstance, cmd_args: tuple[Any, ...]
) -> str:
if len(cmd_args) > 0 and cmd_args[0]:
if cmd_args[0] == "FT.SEARCH":
name = "redis.search"
Expand All @@ -154,7 +184,9 @@ def _build_span_name(instance, cmd_args):
return name


def _build_span_meta_data_for_pipeline(instance):
def _build_span_meta_data_for_pipeline(
instance: PipelineInstance | AsyncPipelineInstance,
) -> tuple[list[Any], str, str]:
try:
command_stack = (
instance.command_stack
Expand Down Expand Up @@ -184,11 +216,16 @@ def _build_span_meta_data_for_pipeline(instance):

# pylint: disable=R0915
def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
tracer: Tracer,
request_hook: _RequestHookT | None = None,
response_hook: _ResponseHookT | None = None,
):
def _traced_execute_command(func, instance, args, kwargs):
def _traced_execute_command(
func: Callable[..., R],
instance: RedisInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> R:
query = _format_command_args(args)
name = _build_span_name(instance, args)
with tracer.start_as_current_span(
Expand All @@ -210,7 +247,12 @@ def _traced_execute_command(func, instance, args, kwargs):
response_hook(span, instance, response)
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
def _traced_execute_pipeline(
func: Callable[..., R],
instance: PipelineInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> R:
(
command_stack,
resource,
Expand Down Expand Up @@ -242,7 +284,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs):

return response

def _add_create_attributes(span, args):
def _add_create_attributes(span: Span, args: tuple[Any, ...]):
_set_span_attribute_if_value(
span, "redis.create_index.index", _value_or_none(args, 1)
)
Expand All @@ -266,7 +308,7 @@ def _add_create_attributes(span, args):
field_attribute,
)

def _add_search_attributes(span, response, args):
def _add_search_attributes(span: Span, response, args):
_set_span_attribute_if_value(
span, "redis.search.index", _value_or_none(args, 1)
)
Expand Down Expand Up @@ -326,7 +368,12 @@ def _add_search_attributes(span, response, args):
_traced_execute_pipeline,
)

async def _async_traced_execute_command(func, instance, args, kwargs):
async def _async_traced_execute_command(
func: Callable[..., Awaitable[R]],
instance: AsyncRedisInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Awaitable[R]:
query = _format_command_args(args)
name = _build_span_name(instance, args)

Expand All @@ -344,7 +391,12 @@ async def _async_traced_execute_command(func, instance, args, kwargs):
response_hook(span, instance, response)
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):
async def _async_traced_execute_pipeline(
func: Callable[..., Awaitable[R]],
instance: AsyncPipelineInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Awaitable[R]:
(
command_stack,
resource,
Expand Down Expand Up @@ -408,14 +460,15 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs):


class RedisInstrumentor(BaseInstrumentor):
"""An instrumentor for Redis
"""An instrumentor for Redis.
See `BaseInstrumentor`
"""

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any):
"""Instruments the redis module
Args:
Expand All @@ -436,7 +489,7 @@ def _instrument(self, **kwargs):
response_hook=kwargs.get("response_hook"),
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
if redis.VERSION < (3, 0, 0):
unwrap(redis.StrictRedis, "execute_command")
unwrap(redis.StrictRedis, "pipeline")
Expand Down

0 comments on commit 85dbb70

Please sign in to comment.