From 1521bb1205128a85b515b127def1f0decabb1998 Mon Sep 17 00:00:00 2001 From: aditya_tewary Date: Wed, 31 Jul 2024 12:33:43 +0530 Subject: [PATCH] Add support to instrument a single MySQL conn --- docs/integrations/mysql.md | 4 + logfire-api/logfire_api/__init__.pyi | 2 + .../_internal/integrations/mysql.pyi | 34 ++ logfire-api/logfire_api/_internal/main.pyi | 17 + logfire/_internal/integrations/mysql.py | 26 +- logfire/_internal/main.py | 351 ++++++------------ pyproject.toml | 1 + requirements-dev.lock | 1 + tests/otel_integrations/test_mysql.py | 141 +++++-- 9 files changed, 309 insertions(+), 268 deletions(-) create mode 100644 logfire-api/logfire_api/_internal/integrations/mysql.pyi diff --git a/docs/integrations/mysql.md b/docs/integrations/mysql.md index 0688413ec..53d21e800 100644 --- a/docs/integrations/mysql.md +++ b/docs/integrations/mysql.md @@ -48,6 +48,7 @@ import mysql.connector logfire.configure() +# To instrument the whole module: logfire.instrument_mysql() connection = mysql.connector.connect( @@ -59,6 +60,9 @@ connection = mysql.connector.connect( use_pure=True, ) +# Or instrument just the connection: +# connection = logfire.instrument_mysql(connection) + with logfire.span('Create table and insert data'), connection.cursor() as cursor: cursor.execute( 'CREATE TABLE IF NOT EXISTS test (id INT AUTO_INCREMENT PRIMARY KEY, num integer, data varchar(255));' ) diff --git a/logfire-api/logfire_api/__init__.pyi b/logfire-api/logfire_api/__init__.pyi index d3fa6c56b..1c8fb5ecd 100644 --- a/logfire-api/logfire_api/__init__.pyi +++ b/logfire-api/logfire_api/__init__.pyi @@ -51,6 +51,7 @@ __all__ = [ 'instrument_sqlalchemy', 'instrument_redis', 'instrument_pymongo', + 'instrument_mysql', 'AutoTraceModule', 'with_tags', 'with_settings', @@ -88,6 +89,7 @@ instrument_aiohttp_client = DEFAULT_LOGFIRE_INSTANCE.instrument_aiohttp_client instrument_sqlalchemy = DEFAULT_LOGFIRE_INSTANCE.instrument_sqlalchemy instrument_redis = DEFAULT_LOGFIRE_INSTANCE.instrument_redis 
instrument_pymongo = DEFAULT_LOGFIRE_INSTANCE.instrument_pymongo +instrument_mysql = DEFAULT_LOGFIRE_INSTANCE.instrument_mysql shutdown = DEFAULT_LOGFIRE_INSTANCE.shutdown with_tags = DEFAULT_LOGFIRE_INSTANCE.with_tags with_settings = DEFAULT_LOGFIRE_INSTANCE.with_settings diff --git a/logfire-api/logfire_api/_internal/integrations/mysql.pyi b/logfire-api/logfire_api/_internal/integrations/mysql.pyi new file mode 100644 index 000000000..eb7440e22 --- /dev/null +++ b/logfire-api/logfire_api/_internal/integrations/mysql.pyi @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from opentelemetry.instrumentation.mysql import MySQLInstrumentor + +if TYPE_CHECKING: + from mysql.connector.abstracts import MySQLConnectionAbstract + from mysql.connector.pooling import PooledMySQLConnection + from typing_extensions import TypedDict, TypeVar, Unpack + + MySQLConnection = TypeVar('MySQLConnection', PooledMySQLConnection, MySQLConnectionAbstract, None) + + class MySQLInstrumentKwargs(TypedDict, total=False): + skip_dep_check: bool + + +def instrument_mysql( + conn: MySQLConnection = None, + **kwargs: Unpack[MySQLInstrumentKwargs], +) -> MySQLConnection: + """Instrument the `mysql` module or a specific MySQL connection so that spans are automatically created for each operation. + + This function uses the OpenTelemetry MySQL Instrumentation library to instrument either the entire `mysql` module or a specific MySQL connection. + + Args: + conn: The MySQL connection to instrument. If None, the entire `mysql` module is instrumented. + **kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods. + + Returns: + If a connection is provided, returns the instrumented connection. If no connection is provided, returns None. + + See the `Logfire.instrument_mysql` method for details. 
+ """ diff --git a/logfire-api/logfire_api/_internal/main.pyi b/logfire-api/logfire_api/_internal/main.pyi index 6b8447a5c..64fff9680 100644 --- a/logfire-api/logfire_api/_internal/main.pyi +++ b/logfire-api/logfire_api/_internal/main.pyi @@ -18,6 +18,7 @@ from .integrations.pymongo import PymongoInstrumentKwargs as PymongoInstrumentKw from .integrations.redis import RedisInstrumentKwargs as RedisInstrumentKwargs from .integrations.sqlalchemy import SQLAlchemyInstrumentKwargs as SQLAlchemyInstrumentKwargs from .integrations.starlette import StarletteInstrumentKwargs as StarletteInstrumentKwargs +from .integrations.mysql import MySQLConnection as MySQLConnection, MySQLInstrumentKwargs as MySQLInstrumentKwargs from .json_encoder import logfire_json_dumps as logfire_json_dumps from .json_schema import JsonSchemaProperties as JsonSchemaProperties, attributes_json_schema as attributes_json_schema, attributes_json_schema_properties as attributes_json_schema_properties, create_json_schema as create_json_schema from .metrics import ProxyMeterProvider as ProxyMeterProvider @@ -618,6 +619,22 @@ class Logfire: [OpenTelemetry pymongo Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/pymongo/pymongo.html) library, specifically `PymongoInstrumentor().instrument()`, to which it passes `**kwargs`. """ + def instrument_mysql(self, conn: MySQLConnection, **kwargs: Unpack[MySQLInstrumentKwargs], + ) -> MySQLConnection: + """Instrument the `mysql` module or a specific MySQL connection so that spans are automatically created for each operation. + + Uses the + [OpenTelemetry MySQL Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/mysql/mysql.html) + library. + + Args: + conn: The `mysql` connection to instrument, or `None` to instrument all connections. + **kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods. 
+ + Returns: + If a connection is provided, returns the instrumented connection. If no connection is provided, returns None. + + """ def instrument_redis(self, **kwargs: Unpack[RedisInstrumentKwargs]) -> None: """Instrument the `redis` module so that spans are automatically created for each operation. diff --git a/logfire/_internal/integrations/mysql.py b/logfire/_internal/integrations/mysql.py index 07791da2a..aaee38962 100644 --- a/logfire/_internal/integrations/mysql.py +++ b/logfire/_internal/integrations/mysql.py @@ -5,15 +5,33 @@ from opentelemetry.instrumentation.mysql import MySQLInstrumentor if TYPE_CHECKING: - from typing_extensions import TypedDict, Unpack + from mysql.connector.abstracts import MySQLConnectionAbstract + from mysql.connector.pooling import PooledMySQLConnection + from typing_extensions import TypedDict, TypeVar, Unpack + + MySQLConnection = TypeVar('MySQLConnection', PooledMySQLConnection, MySQLConnectionAbstract, None) class MySQLInstrumentKwargs(TypedDict, total=False): skip_dep_check: bool -def instrument_mysql(**kwargs: Unpack[MySQLInstrumentKwargs]) -> None: - """Instrument the `mysql` module so that spans are automatically created for each operation. +def instrument_mysql( + conn: MySQLConnection = None, + **kwargs: Unpack[MySQLInstrumentKwargs], +) -> MySQLConnection: + """Instrument the `mysql` module or a specific MySQL connection so that spans are automatically created for each operation. + + This function uses the OpenTelemetry MySQL Instrumentation library to instrument either the entire `mysql` module or a specific MySQL connection. + + Args: + conn: The MySQL connection to instrument. If None, the entire `mysql` module is instrumented. + **kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods. + + Returns: + If a connection is provided, returns the instrumented connection. If no connection is provided, returns None. See the `Logfire.instrument_mysql` method for details. 
""" - MySQLInstrumentor().instrument(**kwargs) # type: ignore[reportUnknownMemberType] + if conn is not None: + return MySQLInstrumentor().instrument_connection(conn) # type: ignore[reportUnknownMemberType] + return MySQLInstrumentor().instrument(**kwargs) # type: ignore[reportUnknownMemberType] diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 2bca5cb10..14518a2b7 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -9,18 +9,7 @@ from functools import cached_property, partial from time import time from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ContextManager, - Iterable, - Literal, - Sequence, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Literal, Sequence, TypeVar, Union, cast import opentelemetry.context as context_api import opentelemetry.trace as trace_api @@ -78,8 +67,8 @@ from .integrations.asyncpg import AsyncPGInstrumentKwargs from .integrations.celery import CeleryInstrumentKwargs from .integrations.flask import FlaskInstrumentKwargs - from .integrations.mysql import MySQLInstrumentKwargs from .integrations.httpx import HTTPXInstrumentKwargs + from .integrations.mysql import MySQLConnection, MySQLInstrumentKwargs from .integrations.psycopg import PsycopgInstrumentKwargs from .integrations.pymongo import PymongoInstrumentKwargs from .integrations.redis import RedisInstrumentKwargs @@ -98,8 +87,8 @@ # 2. It mirrors the exc_info argument of the stdlib logging methods # 3. The argument name exc_info is very suggestive of the sys function. 
ExcInfo: typing.TypeAlias = Union[ - "tuple[type[BaseException], BaseException, TracebackType | None]", - "tuple[None, None, None]", + 'tuple[type[BaseException], BaseException, TracebackType | None]', + 'tuple[None, None, None]', BaseException, bool, None, @@ -116,7 +105,7 @@ def __init__( sample_rate: float | None = None, tags: Sequence[str] = (), console_log: bool = True, - otel_scope: str = "logfire", + otel_scope: str = 'logfire', ) -> None: self._tags = tuple(tags) self._config = config @@ -144,9 +133,7 @@ def _logs_tracer(self) -> Tracer: def _spans_tracer(self) -> Tracer: return self._get_tracer(is_span_tracer=True) - def _get_tracer( - self, *, is_span_tracer: bool, otel_scope: str | None = None - ) -> Tracer: # pragma: no cover + def _get_tracer(self, *, is_span_tracer: bool, otel_scope: str | None = None) -> Tracer: # pragma: no cover return self._tracer_provider.get_tracer( self._otel_scope if otel_scope is None else otel_scope, VERSION, @@ -186,9 +173,7 @@ def _span( otlp_attributes = user_attributes(merged_attributes) if json_schema_properties := attributes_json_schema_properties(attributes): - otlp_attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema( - json_schema_properties - ) + otlp_attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema(json_schema_properties) tags = (self._tags or ()) + tuple(_tags or ()) if tags: @@ -215,9 +200,7 @@ def _span( log_internal_error() return NoopSpan() # type: ignore - def _fast_span( - self, name: str, attributes: otel_types.Attributes - ) -> FastLogfireSpan: + def _fast_span(self, name: str, attributes: otel_types.Attributes) -> FastLogfireSpan: """A simple version of `_span` optimized for auto-tracing that doesn't support message formatting. Returns a similarly simplified version of `LogfireSpan` which must immediately be used as a context manager. 
@@ -230,10 +213,7 @@ def _fast_span( return NoopSpan() # type: ignore def _instrument_span_with_args( - self, - name: str, - attributes: dict[str, otel_types.AttributeValue], - function_args: dict[str, Any], + self, name: str, attributes: dict[str, otel_types.AttributeValue], function_args: dict[str, Any] ) -> FastLogfireSpan: """A version of `_span` used by `@instrument` with `extract_args=True`. @@ -242,15 +222,9 @@ def _instrument_span_with_args( """ try: msg_template: str = attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore - attributes[ATTRIBUTES_MESSAGE_KEY] = logfire_format( - msg_template, function_args, self._config.scrubber - ) - if json_schema_properties := attributes_json_schema_properties( - function_args - ): - attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema( - json_schema_properties - ) + attributes[ATTRIBUTES_MESSAGE_KEY] = logfire_format(msg_template, function_args, self._config.scrubber) + if json_schema_properties := attributes_json_schema_properties(function_args): + attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema(json_schema_properties) attributes.update(user_attributes(function_args)) return self._fast_span(name, attributes) except Exception: # pragma: no cover @@ -285,9 +259,9 @@ def trace( Set to `True` to use the currently handled exception. """ - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("trace", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('trace', msg_template, attributes, tags=_tags, exc_info=_exc_info) def debug( self, @@ -317,9 +291,9 @@ def debug( Set to `True` to use the currently handled exception. 
""" - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("debug", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('debug', msg_template, attributes, tags=_tags, exc_info=_exc_info) def info( self, @@ -349,9 +323,9 @@ def info( Set to `True` to use the currently handled exception. """ - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("info", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('info', msg_template, attributes, tags=_tags, exc_info=_exc_info) def notice( self, @@ -381,9 +355,9 @@ def notice( Set to `True` to use the currently handled exception. """ - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("notice", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('notice', msg_template, attributes, tags=_tags, exc_info=_exc_info) def warn( self, @@ -413,9 +387,9 @@ def warn( Set to `True` to use the currently handled exception. """ - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("warn", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('warn', msg_template, attributes, tags=_tags, exc_info=_exc_info) def error( self, @@ -445,9 +419,9 @@ def error( Set to `True` to use the currently handled exception. 
""" - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("error", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('error', msg_template, attributes, tags=_tags, exc_info=_exc_info) def fatal( self, @@ -477,9 +451,9 @@ def fatal( Set to `True` to use the currently handled exception. """ - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("fatal", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('fatal', msg_template, attributes, tags=_tags, exc_info=_exc_info) def exception( self, @@ -501,9 +475,9 @@ def exception( _exc_info: Set to an exception or a tuple as returned by [`sys.exc_info()`][sys.exc_info] to record a traceback with the log message. """ - if any(k.startswith("_") for k in attributes): # pragma: no cover - raise ValueError("Attribute keys cannot start with an underscore.") - self.log("error", msg_template, attributes, tags=_tags, exc_info=_exc_info) + if any(k.startswith('_') for k in attributes): # pragma: no cover + raise ValueError('Attribute keys cannot start with an underscore.') + self.log('error', msg_template, attributes, tags=_tags, exc_info=_exc_info) def span( self, @@ -534,8 +508,8 @@ def span( attributes: The arguments to include in the span and format the message template with. Attributes starting with an underscore are not allowed. 
""" - if any(k.startswith("_") for k in attributes): - raise ValueError("Attribute keys cannot start with an underscore.") + if any(k.startswith('_') for k in attributes): + raise ValueError('Attribute keys cannot start with an underscore.') return self._span( msg_template, attributes, @@ -573,9 +547,7 @@ def my_function(a: int): span_name: The span name. If not provided, the `msg_template` will be used. extract_args: Whether to extract arguments from the function signature and log them as span attributes. """ - args = LogfireArgs( - tuple(self._tags), self._sample_rate, msg_template, span_name, extract_args - ) + args = LogfireArgs(tuple(self._tags), self._sample_rate, msg_template, span_name, extract_args) return instrument(self, args) def log( @@ -648,16 +620,14 @@ def log( otlp_attributes = user_attributes(merged_attributes) otlp_attributes = { - ATTRIBUTES_SPAN_TYPE_KEY: "log", + ATTRIBUTES_SPAN_TYPE_KEY: 'log', **log_level_attributes(level), ATTRIBUTES_MESSAGE_TEMPLATE_KEY: msg_template, ATTRIBUTES_MESSAGE_KEY: msg, **otlp_attributes, } if json_schema_properties := attributes_json_schema_properties(attributes): - otlp_attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema( - json_schema_properties - ) + otlp_attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema(json_schema_properties) tags = self._tags + tuple(tags or ()) if tags: @@ -676,9 +646,7 @@ def log( start_time = self._config.ns_timestamp_generator() if custom_scope_suffix: - tracer = self._get_tracer( - is_span_tracer=False, otel_scope=f"logfire.{custom_scope_suffix}" - ) + tracer = self._get_tracer(is_span_tracer=False, otel_scope=f'logfire.{custom_scope_suffix}') else: tracer = self._logs_tracer @@ -696,9 +664,7 @@ def log( if isinstance(exc_info, BaseException): _record_exception(span, exc_info) elif exc_info is not None: # pragma: no cover - raise TypeError( - f"Invalid type for exc_info: {exc_info.__class__.__name__}" - ) + raise TypeError(f'Invalid type for exc_info: 
{exc_info.__class__.__name__}') span.end(start_time) @@ -735,7 +701,7 @@ def with_trace_sample_rate(self, sample_rate: float) -> Logfire: # pragma: no c A new Logfire instance with the sampling ratio applied. """ if sample_rate > 1 or sample_rate < 0: - raise ValueError("sample_rate must be between 0 and 1") + raise ValueError('sample_rate must be between 0 and 1') return Logfire( config=self._config, tags=self._tags, @@ -774,11 +740,7 @@ def with_settings( tags=self._tags + tuple(tags), sample_rate=self._sample_rate, console_log=self._console_log if console_log is None else console_log, - otel_scope=( - self._otel_scope - if custom_scope_suffix is None - else f"logfire.{custom_scope_suffix}" - ), + otel_scope=self._otel_scope if custom_scope_suffix is None else f'logfire.{custom_scope_suffix}', ) def force_flush(self, timeout_millis: int = 3_000) -> bool: # pragma: no cover @@ -792,9 +754,7 @@ def force_flush(self, timeout_millis: int = 3_000) -> bool: # pragma: no cover """ return self._config.force_flush(timeout_millis) - def log_slow_async_callbacks( - self, slow_duration: float = 0.1 - ) -> ContextManager[None]: + def log_slow_async_callbacks(self, slow_duration: float = 0.1) -> ContextManager[None]: """Log a warning whenever a function running in the asyncio event loop blocks for too long. This works by patching the `asyncio.events.Handle._run` method. @@ -815,7 +775,7 @@ def install_auto_tracing( self, modules: Sequence[str] | Callable[[AutoTraceModule], bool], *, - check_imported_modules: Literal["error", "warn", "ignore"] = "error", + check_imported_modules: Literal['error', 'warn', 'ignore'] = 'error', min_duration: float = 0, ) -> None: """Install automatic tracing. @@ -844,31 +804,24 @@ def install_auto_tracing( Otherwise, the first time(s) each function is called, it will be timed but not traced. Only after the function has run for at least `min_duration` will it be traced in subsequent calls. 
""" - install_auto_tracing( - self, - modules, - check_imported_modules=check_imported_modules, - min_duration=min_duration, - ) + install_auto_tracing(self, modules, check_imported_modules=check_imported_modules, min_duration=min_duration) def _warn_if_not_initialized_for_instrumentation(self): - self.config.warn_if_not_initialized("Instrumentation will have no effect") + self.config.warn_if_not_initialized('Instrumentation will have no effect') def instrument_fastapi( self, app: FastAPI, *, capture_headers: bool = False, - request_attributes_mapper: ( - Callable[ - [ - Request | WebSocket, - dict[str, Any], - ], - dict[str, Any] | None, - ] - | None - ) = None, + request_attributes_mapper: Callable[ + [ + Request | WebSocket, + dict[str, Any], + ], + dict[str, Any] | None, + ] + | None = None, use_opentelemetry_instrumentation: bool = True, excluded_urls: str | Iterable[str] | None = None, **opentelemetry_kwargs: Any, @@ -926,13 +879,11 @@ def instrument_fastapi( def instrument_openai( self, - openai_client: ( - openai.OpenAI - | openai.AsyncOpenAI - | type[openai.OpenAI] - | type[openai.AsyncOpenAI] - | None - ) = None, + openai_client: openai.OpenAI + | openai.AsyncOpenAI + | type[openai.OpenAI] + | type[openai.AsyncOpenAI] + | None = None, *, suppress_other_instrumentation: bool = True, ) -> ContextManager[None]: @@ -987,18 +938,14 @@ def instrument_openai( import openai from .integrations.llm_providers.llm_provider import instrument_llm_provider - from .integrations.llm_providers.openai import ( - get_endpoint_config, - is_async_client, - on_response, - ) + from .integrations.llm_providers.openai import get_endpoint_config, is_async_client, on_response self._warn_if_not_initialized_for_instrumentation() return instrument_llm_provider( self, openai_client or (openai.OpenAI, openai.AsyncOpenAI), suppress_other_instrumentation, - "OpenAI", + 'OpenAI', get_endpoint_config, on_response, is_async_client, @@ -1006,13 +953,11 @@ def instrument_openai( def 
instrument_anthropic( self, - anthropic_client: ( - anthropic.Anthropic - | anthropic.AsyncAnthropic - | type[anthropic.Anthropic] - | type[anthropic.AsyncAnthropic] - | None - ) = None, + anthropic_client: anthropic.Anthropic + | anthropic.AsyncAnthropic + | type[anthropic.Anthropic] + | type[anthropic.AsyncAnthropic] + | None = None, *, suppress_other_instrumentation: bool = True, ) -> ContextManager[None]: @@ -1066,11 +1011,7 @@ def instrument_anthropic( """ import anthropic - from .integrations.llm_providers.anthropic import ( - get_endpoint_config, - is_async_client, - on_response, - ) + from .integrations.llm_providers.anthropic import get_endpoint_config, is_async_client, on_response from .integrations.llm_providers.llm_provider import instrument_llm_provider self._warn_if_not_initialized_for_instrumentation() @@ -1078,7 +1019,7 @@ def instrument_anthropic( self, anthropic_client or (anthropic.Anthropic, anthropic.AsyncAnthropic), suppress_other_instrumentation, - "Anthropic", + 'Anthropic', get_endpoint_config, on_response, is_async_client, @@ -1179,9 +1120,7 @@ def instrument_requests(self, excluded_urls: str | None = None, **kwargs: Any): self._warn_if_not_initialized_for_instrumentation() return instrument_requests(excluded_urls=excluded_urls, **kwargs) - def instrument_psycopg( - self, conn_or_module: Any = None, **kwargs: Unpack[PsycopgInstrumentKwargs] - ) -> None: + def instrument_psycopg(self, conn_or_module: Any = None, **kwargs: Unpack[PsycopgInstrumentKwargs]) -> None: """Instrument a `psycopg` connection or module so that spans are automatically created for each query. 
Uses the OpenTelemetry instrumentation libraries for @@ -1206,11 +1145,7 @@ def instrument_psycopg( return instrument_psycopg(conn_or_module, **kwargs) def instrument_flask( - self, - app: Flask, - *, - capture_headers: bool = False, - **kwargs: Unpack[FlaskInstrumentKwargs], + self, app: Flask, *, capture_headers: bool = False, **kwargs: Unpack[FlaskInstrumentKwargs] ) -> None: """Instrument `app` so that spans are automatically created for each request. @@ -1226,11 +1161,7 @@ def instrument_flask( return instrument_flask(app, capture_headers=capture_headers, **kwargs) def instrument_starlette( - self, - app: Starlette, - *, - capture_headers: bool = False, - **kwargs: Unpack[StarletteInstrumentKwargs], + self, app: Starlette, *, capture_headers: bool = False, **kwargs: Unpack[StarletteInstrumentKwargs] ) -> None: """Instrument `app` so that spans are automatically created for each request. @@ -1257,9 +1188,7 @@ def instrument_aiohttp_client(self, **kwargs: Any): self._warn_if_not_initialized_for_instrumentation() return instrument_aiohttp_client(**kwargs) - def instrument_sqlalchemy( - self, **kwargs: Unpack[SQLAlchemyInstrumentKwargs] - ) -> None: + def instrument_sqlalchemy(self, **kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None: """Instrument the `sqlalchemy` module so that spans are automatically created for each query. Uses the @@ -1295,21 +1224,31 @@ def instrument_redis(self, **kwargs: Unpack[RedisInstrumentKwargs]) -> None: self._warn_if_not_initialized_for_instrumentation() return instrument_redis(**kwargs) - def instrument_mysql(self, **kwargs: Unpack[MySQLInstrumentKwargs]) -> None: - """Instrument the `mysql` module so that spans are automatically created for each operation. + def instrument_mysql( + self, + conn: MySQLConnection = None, + **kwargs: Unpack[MySQLInstrumentKwargs], + ) -> MySQLConnection: + """Instrument the `mysql` module or a specific MySQL connection so that spans are automatically created for each operation. 
Uses the [OpenTelemetry MySQL Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/mysql/mysql.html) - library, specifically `MySQLInstrumentor().instrument()`, to which it passes `**kwargs`. + library. + + Args: + conn: The `mysql` connection to instrument, or `None` to instrument all connections. + **kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods. + + Returns: + If a connection is provided, returns the instrumented connection. If no connection is provided, returns None. + """ from .integrations.mysql import instrument_mysql self._warn_if_not_initialized_for_instrumentation() - return instrument_mysql(**kwargs) + return instrument_mysql(conn, **kwargs) - def metric_counter( - self, name: str, *, unit: str = "", description: str = "" - ) -> Counter: + def metric_counter(self, name: str, *, unit: str = '', description: str = '') -> Counter: """Create a counter metric. A counter is a cumulative metric that represents a single numerical value that only ever goes up. @@ -1339,9 +1278,7 @@ def metric_counter( """ return self._config.meter.create_counter(name, unit, description) - def metric_histogram( - self, name: str, *, unit: str = "", description: str = "" - ) -> Histogram: + def metric_histogram(self, name: str, *, unit: str = '', description: str = '') -> Histogram: """Create a histogram metric. A histogram is a metric that samples observations (usually things like request durations or response sizes). @@ -1369,9 +1306,7 @@ def transfer(amount: int): """ return self._config.meter.create_histogram(name, unit, description) - def metric_gauge( - self, name: str, *, unit: str = "", description: str = "" - ) -> Gauge: + def metric_gauge(self, name: str, *, unit: str = '', description: str = '') -> Gauge: """Create a gauge metric. Gauge is a synchronous instrument which can be used to record non-additive measurements. 
@@ -1399,9 +1334,7 @@ def update_cpu_usage(cpu_percent): """ return self._config.meter.create_gauge(name, unit, description) - def metric_up_down_counter( - self, name: str, *, unit: str = "", description: str = "" - ) -> UpDownCounter: + def metric_up_down_counter(self, name: str, *, unit: str = '', description: str = '') -> UpDownCounter: """Create an up-down counter metric. An up-down counter is a cumulative metric that represents a single numerical value that can be adjusted up or @@ -1440,8 +1373,8 @@ def metric_counter_callback( name: str, *, callbacks: Sequence[CallbackT], - unit: str = "", - description: str = "", + unit: str = '', + description: str = '', ) -> None: """Create a counter metric that uses a callback to collect observations. @@ -1483,12 +1416,7 @@ def cpu_usage_callback(options: CallbackOptions): self._config.meter.create_observable_counter(name, callbacks, unit, description) def metric_gauge_callback( - self, - name: str, - callbacks: Sequence[CallbackT], - *, - unit: str = "", - description: str = "", + self, name: str, callbacks: Sequence[CallbackT], *, unit: str = '', description: str = '' ) -> None: """Create a gauge metric that uses a callback to collect observations. @@ -1528,12 +1456,7 @@ def thread_count_callback(options: CallbackOptions): self._config.meter.create_observable_gauge(name, callbacks, unit, description) def metric_up_down_counter_callback( - self, - name: str, - callbacks: Sequence[CallbackT], - *, - unit: str = "", - description: str = "", + self, name: str, callbacks: Sequence[CallbackT], *, unit: str = '', description: str = '' ) -> None: """Create an up-down counter metric that uses a callback to collect observations. @@ -1570,13 +1493,9 @@ def inventory_callback(options: CallbackOptions): unit: The unit of the metric. description: The description of the metric. 
""" - self._config.meter.create_observable_up_down_counter( - name, callbacks, unit, description - ) + self._config.meter.create_observable_up_down_counter(name, callbacks, unit, description) - def shutdown( - self, timeout_millis: int = 30_000, flush: bool = True - ) -> bool: # pragma: no cover + def shutdown(self, timeout_millis: int = 30_000, flush: bool = True) -> bool: # pragma: no cover """Shut down all tracers and meters. This will clean up any resources used by the tracers and meters and flush any remaining spans and metrics. @@ -1611,7 +1530,7 @@ def shutdown( class FastLogfireSpan: """A simple version of `LogfireSpan` optimized for auto-tracing.""" - __slots__ = ("_span", "_token", "_atexit") + __slots__ = ('_span', '_token', '_atexit') def __init__(self, span: trace_api.Span) -> None: self._span = span @@ -1623,12 +1542,7 @@ def __enter__(self) -> FastLogfireSpan: return self @handle_internal_errors() - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: Any, - ) -> None: + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: atexit.unregister(self._atexit) context_api.detach(self._token) _exit_span(self._span, exc_value) @@ -1670,9 +1584,7 @@ def __enter__(self) -> LogfireSpan: attributes=self._otlp_attributes, ) if self._token is None: # pragma: no branch - self._token = context_api.attach( - trace_api.set_span_in_context(self._span) - ) + self._token = context_api.attach(trace_api.set_span_in_context(self._span)) self._atexit = partial(self.__exit__, None, None, None) atexit.register(self._atexit) @@ -1680,12 +1592,7 @@ def __enter__(self) -> LogfireSpan: return self @handle_internal_errors() - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: Any, - ) -> None: + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> 
None: if self._token is None: # pragma: no cover return @@ -1733,13 +1640,12 @@ def end(self) -> None: exits. """ if self._span is None: # pragma: no cover - raise RuntimeError("Span has not been started") + raise RuntimeError('Span has not been started') if self._span.is_recording(): with handle_internal_errors(): if self._added_attributes: self._span.set_attribute( - ATTRIBUTES_JSON_SCHEMA_KEY, - attributes_json_schema(self._json_schema_properties), + ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties) ) self._span.end() @@ -1776,7 +1682,7 @@ def record_exception( Delegates to the OpenTelemetry SDK `Span.record_exception` method. """ if self._span is None: - raise RuntimeError("Span has not been started") + raise RuntimeError('Span has not been started') # Check if the span has been sampled out first, since _record_exception is somewhat expensive. if not self._span.is_recording(): @@ -1803,7 +1709,7 @@ def set_level(self, level: LevelName | int): self._span.set_attributes(attributes) def _get_attribute(self, key: str, default: Any) -> Any: - attributes = getattr(self._span, "attributes", self._otlp_attributes) + attributes = getattr(self._span, 'attributes', self._otlp_attributes) return attributes.get(key, default) @@ -1832,18 +1738,13 @@ def __getattr__(self, _name: str) -> Any: def __enter__(self) -> NoopSpan: return self - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: Any, - ) -> None: + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: pass # Implement methods/properties that return something to get the type right. 
@property def message_template(self) -> str: # pragma: no cover - return "" + return '' @property def tags(self) -> Sequence[str]: # pragma: no cover @@ -1851,7 +1752,7 @@ def tags(self) -> Sequence[str]: # pragma: no cover @property def message(self) -> str: # pragma: no cover - return "" + return '' # This is required to make `span.message = ` not raise an error. @message.setter @@ -1890,10 +1791,10 @@ def _record_exception( span.set_status( trace_api.Status( status_code=trace_api.StatusCode.ERROR, - description=f"{exception.__class__.__name__}: {exception}", + description=f'{exception.__class__.__name__}: {exception}', ) ) - span.set_attributes(log_level_attributes("error")) + span.set_attributes(log_level_attributes('error')) attributes = {**(attributes or {})} if ValidationError is not None and isinstance(exception, ValidationError): @@ -1910,21 +1811,13 @@ def _record_exception( # OTEL's record_exception uses `traceback.format_exc()` which is for the current exception, # ignoring the passed exception. # So we override the stacktrace attribute with the correct one. 
- stacktrace = "".join( - traceback.format_exception( - type(exception), exception, exception.__traceback__ - ) - ) + stacktrace = ''.join(traceback.format_exception(type(exception), exception, exception.__traceback__)) attributes[SpanAttributes.EXCEPTION_STACKTRACE] = stacktrace - span.record_exception( - exception, attributes=attributes, timestamp=timestamp, escaped=escaped - ) + span.record_exception(exception, attributes=attributes, timestamp=timestamp, escaped=escaped) -AttributesValueType = TypeVar( - "AttributesValueType", bound=Union[Any, otel_types.AttributeValue] -) +AttributesValueType = TypeVar('AttributesValueType', bound=Union[Any, otel_types.AttributeValue]) def user_attributes(attributes: dict[str, Any]) -> dict[str, otel_types.AttributeValue]: @@ -1950,13 +1843,13 @@ def set_user_attribute( """ otel_value: otel_types.AttributeValue if value is None: - otel_value = cast("list[str]", otlp_attributes.get(NULL_ARGS_KEY, [])) + [key] + otel_value = cast('list[str]', otlp_attributes.get(NULL_ARGS_KEY, [])) + [key] key = NULL_ARGS_KEY elif isinstance(value, int): if value > OTLP_MAX_INT_SIZE: warnings.warn( - f"Integer value {value} is larger than the maximum OTLP integer size of {OTLP_MAX_INT_SIZE} (64-bits), " - " if you need support for sending larger integers, please open a feature request", + f'Integer value {value} is larger than the maximum OTLP integer size of {OTLP_MAX_INT_SIZE} (64-bits), ' + ' if you need support for sending larger integers, please open a feature request', UserWarning, ) otel_value = str(value) @@ -1970,5 +1863,5 @@ def set_user_attribute( return key, otel_value -_PARAMS = ParamSpec("_PARAMS") -_RETURN = TypeVar("_RETURN") +_PARAMS = ParamSpec('_PARAMS') +_RETURN = TypeVar('_RETURN') diff --git a/pyproject.toml b/pyproject.toml index b3a858ffa..5e5d14475 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ dev-dependencies = [ "celery>=5.4.0", "testcontainers", "mysql-connector-python~=8.0", + "pymysql", ] 
[tool.rye.scripts] diff --git a/requirements-dev.lock b/requirements-dev.lock index 6fe4c76e8..318f9db78 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -353,6 +353,7 @@ pymdown-extensions==10.8.1 # via mkdocs-material # via mkdocstrings pymongo==4.8.0 +pymysql==1.1.1 pyright==1.1.373 pytest==8.3.1 # via pytest-django diff --git a/tests/otel_integrations/test_mysql.py b/tests/otel_integrations/test_mysql.py index 04affb80d..2cbaa2950 100644 --- a/tests/otel_integrations/test_mysql.py +++ b/tests/otel_integrations/test_mysql.py @@ -1,42 +1,113 @@ -from unittest import mock - import mysql.connector +import pytest +from dirty_equals import IsInt from inline_snapshot import snapshot +from testcontainers.mysql import MySqlContainer import logfire from logfire.testing import TestExporter -def connect_and_execute_query(): - cnx = mysql.connector.connect(database='test') - cursor = cnx.cursor() - query = 'SELECT * FROM test' - cursor.execute(query) - return cnx, query - - -def test_mysql_instrumentation(exporter: TestExporter): - with mock.patch('mysql.connector.connect') as mock_connect: - mock_cursor = mock.MagicMock() - mock_connect.return_value.cursor.return_value = mock_cursor - mock_connect.return_value.user = 'test_user' - logfire.instrument_mysql() - connect_and_execute_query() - assert exporter.exported_spans_as_dict() == snapshot( - [ - { - 'name': 'SELECT', - 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, - 'parent': None, - 'start_time': 1000000000, - 'end_time': 2000000000, - 'attributes': { - 'logfire.span_type': 'span', - 'logfire.msg': 'SELECT * FROM test', - 'db.system': 'mysql', - 'db.statement': 'SELECT * FROM test', - 'db.user': 'test_user', - }, - } - ] - ) +@pytest.fixture(scope='module') +def mysql_container(): + with MySqlContainer() as mysql_container: + yield mysql_container + + +def get_mysql_connection(mysql_container: MySqlContainer): + host = mysql_container.get_container_host_ip() + port = 
mysql_container.get_exposed_port(3306) + connection = mysql.connector.connect(host=host, port=port, user='test', password='test', database='test') + return connection + + +def test_mysql_instrumentation(exporter: TestExporter, mysql_container: MySqlContainer): + logfire.instrument_mysql() + conn = get_mysql_connection(mysql_container) + cursor = conn.cursor() + cursor.execute('DROP TABLE IF EXISTS test') + cursor.execute('CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(255))') + assert exporter.exported_spans_as_dict() == snapshot( + [ + { + 'name': 'DROP', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'DROP TABLE IF EXISTS test', + 'db.system': 'mysql', + 'db.name': 'test', + 'db.statement': 'DROP TABLE IF EXISTS test', + 'db.user': 'test', + 'net.peer.name': 'localhost', + 'net.peer.port': IsInt(), + }, + }, + { + 'name': 'CREATE', + 'context': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'parent': None, + 'start_time': 3000000000, + 'end_time': 4000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(255))', + 'db.system': 'mysql', + 'db.name': 'test', + 'db.statement': 'CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(255))', + 'db.user': 'test', + 'net.peer.name': 'localhost', + 'net.peer.port': IsInt(), + }, + }, + ] + ) + + +def test_instrument_mysql_connection(exporter: TestExporter, mysql_container: MySqlContainer): + conn = get_mysql_connection(mysql_container) + conn = logfire.instrument_mysql(conn) # type: ignore + cursor = conn.cursor() # type: ignore + cursor.execute('DROP TABLE IF EXISTS test') # type: ignore + cursor.execute('CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(255))') # type: ignore + assert exporter.exported_spans_as_dict() == snapshot( + [ + { + 'name': 'DROP', + 'context': {'trace_id': 1, 
'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'DROP TABLE IF EXISTS test', + 'db.system': 'mysql', + 'db.name': 'test', + 'db.statement': 'DROP TABLE IF EXISTS test', + 'db.user': 'test', + 'net.peer.name': 'localhost', + 'net.peer.port': IsInt(), + }, + }, + { + 'name': 'CREATE', + 'context': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'parent': None, + 'start_time': 3000000000, + 'end_time': 4000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(255))', + 'db.system': 'mysql', + 'db.name': 'test', + 'db.statement': 'CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(255))', + 'db.user': 'test', + 'net.peer.name': 'localhost', + 'net.peer.port': IsInt(), + }, + }, + ] + )