Skip to content

Commit

Permalink
Add type hints to Psycopg
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 4, 2024
1 parent f393546 commit 66a1747
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,26 @@
---
"""

from __future__ import annotations

import logging
import typing
from typing import Collection
from typing import Any, Callable, Collection, TypeVar

import psycopg # pylint: disable=import-self
from psycopg import (
AsyncCursor as pg_async_cursor, # pylint: disable=import-self,no-name-in-module
)
from psycopg import (
Cursor as pg_cursor, # pylint: disable=no-name-in-module,import-self
)
from psycopg.sql import Composed # pylint: disable=no-name-in-module

from opentelemetry.instrumentation import dbapi
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.psycopg.package import _instruments
from opentelemetry.instrumentation.psycopg.version import __version__
from opentelemetry.trace import TracerProvider

_logger = logging.getLogger(__name__)
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"

Connection = TypeVar("Connection", psycopg.Connection, psycopg.AsyncConnection)
Cursor = TypeVar("Cursor", psycopg.Cursor, psycopg.AsyncCursor)


class PsycopgInstrumentor(BaseInstrumentor):
_CONNECTION_ATTRIBUTES = {
Expand All @@ -136,7 +135,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any):
"""Integrate with PostgreSQL Psycopg library.
Psycopg: http://initd.org/psycopg/
"""
Expand Down Expand Up @@ -181,7 +180,7 @@ def _instrument(self, **kwargs):
commenter_options=commenter_options,
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
""" "Disable Psycopg instrumentation"""
dbapi.unwrap_connect(psycopg, "connect") # pylint: disable=no-member
dbapi.unwrap_connect(
Expand All @@ -195,7 +194,9 @@ def _uninstrument(self, **kwargs):

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@staticmethod
def instrument_connection(connection, tracer_provider=None):
def instrument_connection(
connection: Connection, tracer_provider: TracerProvider | None = None
) -> Connection:
if not hasattr(connection, "_is_instrumented_by_opentelemetry"):
connection._is_instrumented_by_opentelemetry = False

Expand All @@ -215,7 +216,7 @@ def instrument_connection(connection, tracer_provider=None):

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@staticmethod
def uninstrument_connection(connection):
def uninstrument_connection(connection: Connection) -> Connection:
connection.cursor_factory = getattr(
connection, _OTEL_CURSOR_FACTORY_KEY, None
)
Expand All @@ -227,9 +228,9 @@ def uninstrument_connection(connection):
class DatabaseApiIntegration(dbapi.DatabaseApiIntegration):
def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
args: typing.Tuple[typing.Any, typing.Any],
kwargs: typing.Dict[typing.Any, typing.Any],
connect_method: Callable[..., Any],
args: tuple[Any, Any],
kwargs: dict[Any, Any],
):
"""Add object proxy to connection object."""
base_cursor_factory = kwargs.pop("cursor_factory", None)
Expand All @@ -245,9 +246,9 @@ def wrapped_connection(
class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration):
async def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
args: typing.Tuple[typing.Any, typing.Any],
kwargs: typing.Dict[typing.Any, typing.Any],
connect_method: Callable[..., Any],
args: tuple[Any, Any],
kwargs: dict[Any, Any],
):
"""Add object proxy to connection object."""
base_cursor_factory = kwargs.pop("cursor_factory", None)
Expand All @@ -263,7 +264,7 @@ async def wrapped_connection(


class CursorTracer(dbapi.CursorTracer):
def get_operation_name(self, cursor, args):
def get_operation_name(self, cursor: Cursor, args: list[Any]) -> str:
if not args:
return ""

Expand All @@ -278,7 +279,7 @@ def get_operation_name(self, cursor, args):

return ""

def get_statement(self, cursor, args):
def get_statement(self, cursor: Cursor, args: list[Any]) -> str:
if not args:
return ""

Expand All @@ -288,7 +289,11 @@ def get_statement(self, cursor, args):
return statement


def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
def _new_cursor_factory(
db_api: DatabaseApiIntegration | None = None,
base_factory: type[psycopg.Cursor] | None = None,
tracer_provider: TracerProvider | None = None,
):
if not db_api:
db_api = DatabaseApiIntegration(
__name__,
Expand All @@ -298,21 +303,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
tracer_provider=tracer_provider,
)

base_factory = base_factory or pg_cursor
base_factory = base_factory or psycopg.Cursor
_cursor_tracer = CursorTracer(db_api)

class TracedCursorFactory(base_factory):
def execute(self, *args, **kwargs):
def execute(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().execute, *args, **kwargs
)

def executemany(self, *args, **kwargs):
def executemany(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().executemany, *args, **kwargs
)

def callproc(self, *args, **kwargs):
def callproc(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().callproc, *args, **kwargs
)
Expand All @@ -321,7 +326,9 @@ def callproc(self, *args, **kwargs):


def _new_cursor_async_factory(
db_api=None, base_factory=None, tracer_provider=None
db_api: DatabaseApiAsyncIntegration | None = None,
base_factory: type[psycopg.AsyncCursor] | None = None,
tracer_provider: TracerProvider | None = None,
):
if not db_api:
db_api = DatabaseApiAsyncIntegration(
Expand All @@ -331,21 +338,21 @@ def _new_cursor_async_factory(
version=__version__,
tracer_provider=tracer_provider,
)
base_factory = base_factory or pg_async_cursor
base_factory = base_factory or psycopg.AsyncCursor
_cursor_tracer = CursorTracer(db_api)

class TracedCursorAsyncFactory(base_factory):
async def execute(self, *args, **kwargs):
async def execute(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().execute, *args, **kwargs
)

async def executemany(self, *args, **kwargs):
async def executemany(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().executemany, *args, **kwargs
)

async def callproc(self, *args, **kwargs):
async def callproc(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().callproc, *args, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations


_instruments = ("psycopg >= 3.1.0",)
_instruments: tuple[str, ...] = ("psycopg >= 3.1.0",)

0 comments on commit 66a1747

Please sign in to comment.