Skip to content

Commit

Permalink
Support AsyncEngine to SQLAlchemy (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Dec 23, 2024
1 parent cb9ad81 commit 0d21837
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 18 deletions.
12 changes: 9 additions & 3 deletions logfire/_internal/integrations/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING

try:
Expand All @@ -13,6 +14,7 @@

if TYPE_CHECKING:
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from typing_extensions import TypedDict, Unpack

class CommenterOptions(TypedDict, total=False):
Expand All @@ -21,15 +23,19 @@ class CommenterOptions(TypedDict, total=False):
opentelemetry_values: bool

class SQLAlchemyInstrumentKwargs(TypedDict, total=False):
engine: Engine | None
enable_commenter: bool | None
commenter_options: CommenterOptions | None
skip_dep_check: bool


def instrument_sqlalchemy(**kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None:
def instrument_sqlalchemy(engine: AsyncEngine | Engine | None, **kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None:
"""Instrument the `sqlalchemy` module so that spans are automatically created for each query.
See the `Logfire.instrument_sqlalchemy` method for details.
"""
SQLAlchemyInstrumentor().instrument(**kwargs)
with contextlib.suppress(ImportError):
from sqlalchemy.ext.asyncio import AsyncEngine

if isinstance(engine, AsyncEngine):
engine = engine.sync_engine
return SQLAlchemyInstrumentor().instrument(engine=engine, **kwargs)
13 changes: 12 additions & 1 deletion logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
from fastapi import FastAPI
from flask.app import Flask
from opentelemetry.metrics import _Gauge as Gauge
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.websockets import WebSocket
Expand Down Expand Up @@ -1498,17 +1500,26 @@ def instrument_aiohttp_client(self, **kwargs: Any) -> None:
self._warn_if_not_initialized_for_instrumentation()
return instrument_aiohttp_client(self, **kwargs)

def instrument_sqlalchemy(self, **kwargs: Unpack[SQLAlchemyInstrumentKwargs]) -> None:
def instrument_sqlalchemy(
self,
engine: AsyncEngine | Engine | None = None,
**kwargs: Unpack[SQLAlchemyInstrumentKwargs],
) -> None:
"""Instrument the `sqlalchemy` module so that spans are automatically created for each query.
Uses the
[OpenTelemetry SQLAlchemy Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/sqlalchemy/sqlalchemy.html)
library, specifically `SQLAlchemyInstrumentor().instrument()`, to which it passes `**kwargs`.
Args:
engine: The `sqlalchemy` engine to instrument, or `None` to instrument all engines.
**kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods.
"""
from .integrations.sqlalchemy import instrument_sqlalchemy

self._warn_if_not_initialized_for_instrumentation()
return instrument_sqlalchemy(
engine=engine,
**{ # type: ignore
'tracer_provider': self._config.get_tracer_provider(),
'meter_provider': self._config.get_meter_provider(),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ dev = [
"logfire-api",
"requests",
"setuptools>=75.3.0",
"aiosqlite>=0.20.0",
]
docs = [
"mkdocs>=1.5.0",
Expand Down
205 changes: 191 additions & 14 deletions tests/otel_integrations/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import pytest
from inline_snapshot import snapshot
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql import text
from sqlalchemy.types import Integer, String
Expand All @@ -18,6 +20,19 @@
from logfire.testing import TestExporter


class Base(DeclarativeBase):
pass


# `auth` is in default scrubbing patterns, but `db.statement` attribute is in scrubbing SAFE_KEYS.
# So, logfire shouldn't redact `auth` in the `db.statement` attribute.
class AuthRecord(Base):
__tablename__ = 'auth_records'
id: Mapped[int] = mapped_column(primary_key=True)
number: Mapped[int] = mapped_column(Integer, nullable=False)
content: Mapped[str] = mapped_column(String, nullable=False)


@contextmanager
def sqlite_engine(path: Path) -> Iterator[Engine]:
path.unlink(missing_ok=True)
Expand All @@ -30,22 +45,8 @@ def sqlite_engine(path: Path) -> Iterator[Engine]:

def test_sqlalchemy_instrumentation(exporter: TestExporter):
with sqlite_engine(Path('example.db')) as engine:
# Need to ensure this import happens _after_ importing sqlalchemy
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor

logfire.instrument_sqlalchemy(engine=engine)

class Base(DeclarativeBase):
pass

# `auth` is in default scrubbing patterns, but `db.statement` attribute is in scrubbing SAFE_KEYS.
# So, logfire shouldn't redact `auth` in the `db.statement` attribute.
class AuthRecord(Base):
__tablename__ = 'auth_records'
id: Mapped[int] = mapped_column(primary_key=True)
number: Mapped[int] = mapped_column(Integer, nullable=False)
content: Mapped[str] = mapped_column(String, nullable=False)

Base.metadata.create_all(engine)

with Session(engine) as session:
Expand Down Expand Up @@ -207,6 +208,182 @@ class AuthRecord(Base):
SQLAlchemyInstrumentor().uninstrument()


@contextmanager
def sqlite_async_engine(path: Path) -> Iterator[AsyncEngine]:
path.unlink(missing_ok=True)
engine = create_async_engine(f'sqlite+aiosqlite:///{path}')
try:
yield engine
finally:
path.unlink()


@pytest.mark.anyio
async def test_sqlalchemy_async_instrumentation(exporter: TestExporter):
with sqlite_async_engine(Path('example.db')) as engine:
logfire.instrument_sqlalchemy(engine=engine)

async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

async with AsyncSession(engine) as session:
record = AuthRecord(id=1, number=2, content='abc')
await session.execute(text('select * from auth_records'))
session.add(record)
await session.commit()
await session.delete(record)
await session.commit()

assert exporter.exported_spans_as_dict() == snapshot(
[
{
'name': 'connect',
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'parent': None,
'start_time': 1000000000,
'end_time': 2000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'connect',
'db.name': 'example.db',
'db.system': 'sqlite',
'logfire.level_num': 5,
},
},
{
'name': 'PRAGMA example.db',
'context': {'trace_id': 2, 'span_id': 3, 'is_remote': False},
'parent': None,
'start_time': 3000000000,
'end_time': 4000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'PRAGMA main.table_info("auth_records")',
'db.statement': 'PRAGMA main.table_info("auth_records")',
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
{
'name': 'PRAGMA example.db',
'context': {'trace_id': 3, 'span_id': 5, 'is_remote': False},
'parent': None,
'start_time': 5000000000,
'end_time': 6000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'PRAGMA temp.table_info("auth_records")',
'db.statement': 'PRAGMA temp.table_info("auth_records")',
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
{
'name': 'CREATE example.db',
'context': {'trace_id': 4, 'span_id': 7, 'is_remote': False},
'parent': None,
'start_time': 7000000000,
'end_time': 8000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'CREATE TABLE auth_records ( id INTEGER … t VARCHAR NOT NULL, PRIMARY KEY (id)\n)',
'db.statement': '\nCREATE TABLE auth_records (\n\tid INTEGER NOT NULL, \n\tnumber INTEGER NOT NULL, \n\tcontent VARCHAR NOT NULL, \n\tPRIMARY KEY (id)\n)\n\n',
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
{
'name': 'connect',
'context': {'trace_id': 5, 'span_id': 9, 'is_remote': False},
'parent': None,
'start_time': 9000000000,
'end_time': 10000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'connect',
'db.name': 'example.db',
'db.system': 'sqlite',
'logfire.level_num': 5,
},
},
{
'name': 'select example.db',
'context': {'trace_id': 6, 'span_id': 11, 'is_remote': False},
'parent': None,
'start_time': 11000000000,
'end_time': 12000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'select * from auth_records',
'db.statement': 'select * from auth_records',
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
{
'name': 'INSERT example.db',
'context': {'trace_id': 7, 'span_id': 13, 'is_remote': False},
'parent': None,
'start_time': 13000000000,
'end_time': 14000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'INSERT INTO auth_records (id, number, content) VALUES (?, ?, ?)',
'db.statement': 'INSERT INTO auth_records (id, number, content) VALUES (?, ?, ?)',
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
{
'name': 'connect',
'context': {'trace_id': 8, 'span_id': 15, 'is_remote': False},
'parent': None,
'start_time': 15000000000,
'end_time': 16000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'connect',
'db.name': 'example.db',
'db.system': 'sqlite',
'logfire.level_num': 5,
},
},
{
'name': 'SELECT example.db',
'context': {'trace_id': 9, 'span_id': 17, 'is_remote': False},
'parent': None,
'start_time': 17000000000,
'end_time': 18000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'SELECT auth_recor…ds_content FROM auth_records WHERE …',
'db.statement': """\
SELECT auth_records.id AS auth_records_id, auth_records.number AS auth_records_number, auth_records.content AS auth_records_content \nFROM auth_records \nWHERE auth_records.id = ?\
""",
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
{
'name': 'DELETE example.db',
'context': {'trace_id': 10, 'span_id': 19, 'is_remote': False},
'parent': None,
'start_time': 19000000000,
'end_time': 20000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'DELETE FROM auth_records WHERE auth_records.id = ?',
'db.statement': 'DELETE FROM auth_records WHERE auth_records.id = ?',
'db.system': 'sqlite',
'db.name': 'example.db',
},
},
]
)

SQLAlchemyInstrumentor().uninstrument()


def test_missing_opentelemetry_dependency() -> None:
with mock.patch.dict('sys.modules', {'opentelemetry.instrumentation.sqlalchemy': None}):
with pytest.raises(RuntimeError) as exc_info:
Expand Down
14 changes: 14 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0d21837

Please sign in to comment.