From 66a1747f0fbad7b948a318ffdc4d681f6b63cded Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 4 Dec 2024 18:25:12 +0100 Subject: [PATCH] Add type hints to Psycopg --- .../instrumentation/psycopg/__init__.py | 67 ++++++++++--------- .../instrumentation/psycopg/package.py | 4 +- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py b/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py index e986ec0d46..7668fa806e 100644 --- a/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py @@ -101,27 +101,26 @@ --- """ +from __future__ import annotations + import logging -import typing -from typing import Collection +from typing import Any, Callable, Collection, TypeVar import psycopg # pylint: disable=import-self -from psycopg import ( - AsyncCursor as pg_async_cursor, # pylint: disable=import-self,no-name-in-module -) -from psycopg import ( - Cursor as pg_cursor, # pylint: disable=no-name-in-module,import-self -) from psycopg.sql import Composed # pylint: disable=no-name-in-module from opentelemetry.instrumentation import dbapi from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.psycopg.package import _instruments from opentelemetry.instrumentation.psycopg.version import __version__ +from opentelemetry.trace import TracerProvider _logger = logging.getLogger(__name__) _OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory" +Connection = TypeVar("Connection", psycopg.Connection, psycopg.AsyncConnection) +Cursor = TypeVar("Cursor", psycopg.Cursor, psycopg.AsyncCursor) + class PsycopgInstrumentor(BaseInstrumentor): _CONNECTION_ATTRIBUTES = { @@ -136,7 +135,7 @@ class PsycopgInstrumentor(BaseInstrumentor): def instrumentation_dependencies(self) -> Collection[str]: return _instruments - def _instrument(self, **kwargs): + def _instrument(self, **kwargs: Any): """Integrate with PostgreSQL Psycopg library. Psycopg: http://initd.org/psycopg/ """ @@ -181,7 +180,7 @@ def _instrument(self, **kwargs): commenter_options=commenter_options, ) - def _uninstrument(self, **kwargs): + def _uninstrument(self, **kwargs: Any): """ "Disable Psycopg instrumentation""" dbapi.unwrap_connect(psycopg, "connect") # pylint: disable=no-member dbapi.unwrap_connect( @@ -195,7 +194,9 @@ def _uninstrument(self, **kwargs): # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql @staticmethod - def instrument_connection(connection, tracer_provider=None): + def instrument_connection( + connection: Connection, tracer_provider: TracerProvider | None = None + ) -> Connection: if not hasattr(connection, "_is_instrumented_by_opentelemetry"): connection._is_instrumented_by_opentelemetry = False @@ -215,7 +216,7 @@ def instrument_connection(connection, tracer_provider=None): # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql @staticmethod - def uninstrument_connection(connection): + def uninstrument_connection(connection: Connection) -> Connection: connection.cursor_factory = getattr( connection, _OTEL_CURSOR_FACTORY_KEY, None ) @@ -227,9 +228,9 @@ def uninstrument_connection(connection): class DatabaseApiIntegration(dbapi.DatabaseApiIntegration): def wrapped_connection( self, - connect_method: typing.Callable[..., typing.Any], - args: typing.Tuple[typing.Any, typing.Any], - kwargs: typing.Dict[typing.Any, typing.Any], + connect_method: Callable[..., Any], + args: tuple[Any, Any], + kwargs: dict[Any, Any], ): """Add object proxy to connection object.""" base_cursor_factory = kwargs.pop("cursor_factory", None) @@ -245,9 +246,9 @@ def wrapped_connection( class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration): async def wrapped_connection( self, - connect_method: typing.Callable[..., typing.Any], - args: typing.Tuple[typing.Any, typing.Any], - kwargs: typing.Dict[typing.Any, typing.Any], + connect_method: Callable[..., Any], + args: tuple[Any, Any], + kwargs: dict[Any, Any], ): """Add object proxy to connection object.""" base_cursor_factory = kwargs.pop("cursor_factory", None) @@ -263,7 +264,7 @@ async def wrapped_connection( class CursorTracer(dbapi.CursorTracer): - def get_operation_name(self, cursor, args): + def get_operation_name(self, cursor: Cursor, args: list[Any]) -> str: if not args: return "" @@ -278,7 +279,7 @@ def get_operation_name(self, cursor, args): return "" - def get_statement(self, cursor, args): + def get_statement(self, cursor: Cursor, args: list[Any]) -> str: if not args: return "" @@ -288,7 +289,11 @@ def get_statement(self, cursor, args): return statement -def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None): +def _new_cursor_factory( + db_api: DatabaseApiIntegration | None = None, + base_factory: type[psycopg.Cursor] | None = None, + tracer_provider: TracerProvider | None = None, +): if not db_api: db_api = DatabaseApiIntegration( __name__, @@ -298,21 +303,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None): tracer_provider=tracer_provider, ) - base_factory = base_factory or pg_cursor + base_factory = base_factory or psycopg.Cursor _cursor_tracer = CursorTracer(db_api) class TracedCursorFactory(base_factory): - def execute(self, *args, **kwargs): + def execute(self, *args: Any, **kwargs: Any): return _cursor_tracer.traced_execution( self, super().execute, *args, **kwargs ) - def executemany(self, *args, **kwargs): + def executemany(self, *args: Any, **kwargs: Any): return _cursor_tracer.traced_execution( self, super().executemany, *args, **kwargs ) - def callproc(self, *args, **kwargs): + def callproc(self, *args: Any, **kwargs: Any): return _cursor_tracer.traced_execution( self, super().callproc, *args, **kwargs ) @@ -321,7 +326,9 @@ def callproc(self, *args, **kwargs): def _new_cursor_async_factory( - db_api=None, base_factory=None, tracer_provider=None + db_api: DatabaseApiAsyncIntegration | None = None, + base_factory: type[psycopg.AsyncCursor] | None = None, + tracer_provider: TracerProvider | None = None, ): if not db_api: db_api = DatabaseApiAsyncIntegration( @@ -331,21 +338,21 @@ def _new_cursor_async_factory( version=__version__, tracer_provider=tracer_provider, ) - base_factory = base_factory or pg_async_cursor + base_factory = base_factory or psycopg.AsyncCursor _cursor_tracer = CursorTracer(db_api) class TracedCursorAsyncFactory(base_factory): - async def execute(self, *args, **kwargs): + async def execute(self, *args: Any, **kwargs: Any): return await _cursor_tracer.traced_execution( self, super().execute, *args, **kwargs ) - async def executemany(self, *args, **kwargs): + async def executemany(self, *args: Any, **kwargs: Any): return await _cursor_tracer.traced_execution( self, super().executemany, *args, **kwargs ) - async def callproc(self, *args, **kwargs): + async def callproc(self, *args: Any, **kwargs: Any): return await _cursor_tracer.traced_execution( self, super().callproc, *args, **kwargs ) diff --git a/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/package.py b/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/package.py index 635edfb4db..a3ee72d1ae 100644 --- a/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/package.py +++ b/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/package.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations - -_instruments = ("psycopg >= 3.1.0",) +_instruments: tuple[str, ...] = ("psycopg >= 3.1.0",)