From 0d21837412a5285511865be8b6683485b03cf7d8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 23 Dec 2024 11:03:34 +0100 Subject: [PATCH] Support `AsyncEngine` to SQLAlchemy (#717) --- logfire/_internal/integrations/sqlalchemy.py | 12 +- logfire/_internal/main.py | 13 +- pyproject.toml | 1 + tests/otel_integrations/test_sqlalchemy.py | 205 +++++++++++++++++-- uv.lock | 14 ++ 5 files changed, 227 insertions(+), 18 deletions(-) diff --git a/logfire/_internal/integrations/sqlalchemy.py b/logfire/_internal/integrations/sqlalchemy.py index e234b189c..63edbceac 100644 --- a/logfire/_internal/integrations/sqlalchemy.py +++ b/logfire/_internal/integrations/sqlalchemy.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from typing import TYPE_CHECKING try: @@ -13,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine from typing_extensions import TypedDict, Unpack class CommenterOptions(TypedDict, total=False): @@ -21,15 +23,19 @@ class CommenterOptions(TypedDict, total=False): opentelemetry_values: bool class SQLAlchemyInstrumentKwargs(TypedDict, total=False): - engine: Engine | None enable_commenter: bool | None commenter_options: CommenterOptions | None skip_dep_check: bool -def instrument_sqlalchemy(**kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None: +def instrument_sqlalchemy(engine: AsyncEngine | Engine | None, **kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None: """Instrument the `sqlalchemy` module so that spans are automatically created for each query. See the `Logfire.instrument_sqlalchemy` method for details. """ - SQLAlchemyInstrumentor().instrument(**kwargs) + with contextlib.suppress(ImportError): + from sqlalchemy.ext.asyncio import AsyncEngine + + if isinstance(engine, AsyncEngine): + engine = engine.sync_engine + return SQLAlchemyInstrumentor().instrument(engine=engine, **kwargs) diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index d3c17752d..72a9de6ad 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -74,6 +74,8 @@ from fastapi import FastAPI from flask.app import Flask from opentelemetry.metrics import _Gauge as Gauge + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine from starlette.applications import Starlette from starlette.requests import Request from starlette.websockets import WebSocket @@ -1498,17 +1500,26 @@ def instrument_aiohttp_client(self, **kwargs: Any) -> None: self._warn_if_not_initialized_for_instrumentation() return instrument_aiohttp_client(self, **kwargs) - def instrument_sqlalchemy(self, **kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None: + def instrument_sqlalchemy( + self, + engine: AsyncEngine | Engine | None = None, + **kwargs: Unpack[SQLAlchemyInstrumentKwargs], + ) -> None: """Instrument the `sqlalchemy` module so that spans are automatically created for each query. Uses the [OpenTelemetry SQLAlchemy Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/sqlalchemy/sqlalchemy.html) library, specifically `SQLAlchemyInstrumentor().instrument()`, to which it passes `**kwargs`. + + Args: + engine: The `sqlalchemy` engine to instrument, or `None` to instrument all engines. + **kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods. """ from .integrations.sqlalchemy import instrument_sqlalchemy self._warn_if_not_initialized_for_instrumentation() return instrument_sqlalchemy( + engine=engine, **{ # type: ignore 'tracer_provider': self._config.get_tracer_provider(), 'meter_provider': self._config.get_meter_provider(), diff --git a/pyproject.toml b/pyproject.toml index 5b8b64adf..c088900d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,6 +159,7 @@ dev = [ "logfire-api", "requests", "setuptools>=75.3.0", + "aiosqlite>=0.20.0", ] docs = [ "mkdocs>=1.5.0", diff --git a/tests/otel_integrations/test_sqlalchemy.py b/tests/otel_integrations/test_sqlalchemy.py index c82a5411c..dff2b3212 100644 --- a/tests/otel_integrations/test_sqlalchemy.py +++ b/tests/otel_integrations/test_sqlalchemy.py @@ -8,7 +8,9 @@ import pytest from inline_snapshot import snapshot +from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from sqlalchemy.engine import Engine, create_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql import text from sqlalchemy.types import Integer, String @@ -18,6 +20,19 @@ from logfire.testing import TestExporter +class Base(DeclarativeBase): + pass + + +# `auth` is in default scrubbing patterns, but `db.statement` attribute is in scrubbing SAFE_KEYS. +# So, logfire shouldn't redact `auth` in the `db.statement` attribute. +class AuthRecord(Base): + __tablename__ = 'auth_records' + id: Mapped[int] = mapped_column(primary_key=True) + number: Mapped[int] = mapped_column(Integer, nullable=False) + content: Mapped[str] = mapped_column(String, nullable=False) + + @contextmanager def sqlite_engine(path: Path) -> Iterator[Engine]: path.unlink(missing_ok=True) @@ -30,22 +45,8 @@ def sqlite_engine(path: Path) -> Iterator[Engine]: def test_sqlalchemy_instrumentation(exporter: TestExporter): with sqlite_engine(Path('example.db')) as engine: - # Need to ensure this import happens _after_ importing sqlalchemy - from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor - logfire.instrument_sqlalchemy(engine=engine) - class Base(DeclarativeBase): - pass - - # `auth` is in default scrubbing patterns, but `db.statement` attribute is in scrubbing SAFE_KEYS. - # So, logfire shouldn't redact `auth` in the `db.statement` attribute. - class AuthRecord(Base): - __tablename__ = 'auth_records' - id: Mapped[int] = mapped_column(primary_key=True) - number: Mapped[int] = mapped_column(Integer, nullable=False) - content: Mapped[str] = mapped_column(String, nullable=False) - Base.metadata.create_all(engine) with Session(engine) as session: @@ -207,6 +208,182 @@ class AuthRecord(Base): SQLAlchemyInstrumentor().uninstrument() +@contextmanager +def sqlite_async_engine(path: Path) -> Iterator[AsyncEngine]: + path.unlink(missing_ok=True) + engine = create_async_engine(f'sqlite+aiosqlite:///{path}') + try: + yield engine + finally: + path.unlink() + + +@pytest.mark.anyio +async def test_sqlalchemy_async_instrumentation(exporter: TestExporter): + with sqlite_async_engine(Path('example.db')) as engine: + logfire.instrument_sqlalchemy(engine=engine) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async with AsyncSession(engine) as session: + record = AuthRecord(id=1, number=2, content='abc') + await session.execute(text('select * from auth_records')) + session.add(record) + await session.commit() + await session.delete(record) + await session.commit() + + assert exporter.exported_spans_as_dict() == snapshot( + [ + { + 'name': 'connect', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'connect', + 'db.name': 'example.db', + 'db.system': 'sqlite', + 'logfire.level_num': 5, + }, + }, + { + 'name': 'PRAGMA example.db', + 'context': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'parent': None, + 'start_time': 3000000000, + 'end_time': 4000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'PRAGMA main.table_info("auth_records")', + 'db.statement': 'PRAGMA main.table_info("auth_records")', + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + { + 'name': 'PRAGMA example.db', + 'context': {'trace_id': 3, 'span_id': 5, 'is_remote': False}, + 'parent': None, + 'start_time': 5000000000, + 'end_time': 6000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'PRAGMA temp.table_info("auth_records")', + 'db.statement': 'PRAGMA temp.table_info("auth_records")', + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + { + 'name': 'CREATE example.db', + 'context': {'trace_id': 4, 'span_id': 7, 'is_remote': False}, + 'parent': None, + 'start_time': 7000000000, + 'end_time': 8000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'CREATE TABLE auth_records ( id INTEGER … t VARCHAR NOT NULL, PRIMARY KEY (id)\n)', + 'db.statement': '\nCREATE TABLE auth_records (\n\tid INTEGER NOT NULL, \n\tnumber INTEGER NOT NULL, \n\tcontent VARCHAR NOT NULL, \n\tPRIMARY KEY (id)\n)\n\n', + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + { + 'name': 'connect', + 'context': {'trace_id': 5, 'span_id': 9, 'is_remote': False}, + 'parent': None, + 'start_time': 9000000000, + 'end_time': 10000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'connect', + 'db.name': 'example.db', + 'db.system': 'sqlite', + 'logfire.level_num': 5, + }, + }, + { + 'name': 'select example.db', + 'context': {'trace_id': 6, 'span_id': 11, 'is_remote': False}, + 'parent': None, + 'start_time': 11000000000, + 'end_time': 12000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'select * from auth_records', + 'db.statement': 'select * from auth_records', + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + { + 'name': 'INSERT example.db', + 'context': {'trace_id': 7, 'span_id': 13, 'is_remote': False}, + 'parent': None, + 'start_time': 13000000000, + 'end_time': 14000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'INSERT INTO auth_records (id, number, content) VALUES (?, ?, ?)', + 'db.statement': 'INSERT INTO auth_records (id, number, content) VALUES (?, ?, ?)', + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + { + 'name': 'connect', + 'context': {'trace_id': 8, 'span_id': 15, 'is_remote': False}, + 'parent': None, + 'start_time': 15000000000, + 'end_time': 16000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'connect', + 'db.name': 'example.db', + 'db.system': 'sqlite', + 'logfire.level_num': 5, + }, + }, + { + 'name': 'SELECT example.db', + 'context': {'trace_id': 9, 'span_id': 17, 'is_remote': False}, + 'parent': None, + 'start_time': 17000000000, + 'end_time': 18000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'SELECT auth_recor…ds_content FROM auth_records WHERE …', + 'db.statement': """\ +SELECT auth_records.id AS auth_records_id, auth_records.number AS auth_records_number, auth_records.content AS auth_records_content \nFROM auth_records \nWHERE auth_records.id = ?\ +""", + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + { + 'name': 'DELETE example.db', + 'context': {'trace_id': 10, 'span_id': 19, 'is_remote': False}, + 'parent': None, + 'start_time': 19000000000, + 'end_time': 20000000000, + 'attributes': { + 'logfire.span_type': 'span', + 'logfire.msg': 'DELETE FROM auth_records WHERE auth_records.id = ?', + 'db.statement': 'DELETE FROM auth_records WHERE auth_records.id = ?', + 'db.system': 'sqlite', + 'db.name': 'example.db', + }, + }, + ] + ) + + SQLAlchemyInstrumentor().uninstrument() + + def test_missing_opentelemetry_dependency() -> None: with mock.patch.dict('sys.modules', {'opentelemetry.instrumentation.sqlalchemy': None}): with pytest.raises(RuntimeError) as exc_info: diff --git a/uv.lock b/uv.lock index c16023b7c..670291e96 100644 --- a/uv.lock +++ b/uv.lock @@ -147,6 +147,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 }, ] +[[package]] +name = "aiosqlite" +version = "0.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/3a/22ff5415bf4d296c1e92b07fd746ad42c96781f13295a074d58e77747848/aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7", size = 21691 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c4/c93eb22025a2de6b83263dfe3d7df2e19138e345bca6f18dba7394120930/aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6", size = 15564 }, +] + [[package]] name = "amqp" version = "5.3.1" @@ -1466,6 +1478,7 @@ wsgi = [ [package.dev-dependencies] dev = [ { name = "aiohttp" }, + { name = "aiosqlite" }, { name = "anthropic" }, { name = "anyio" }, { name = "asyncpg" }, @@ -1581,6 +1594,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "aiohttp", specifier = ">=3.10.9" }, + { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "anthropic", specifier = ">=0.27.0" }, { name = "anyio", specifier = "<4.4.0" }, { name = "asyncpg" },