diff --git a/CHANGELOG.md b/CHANGELOG.md index a28c5039c9..4198628292 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `opentelemetry-instrumentation-starlette` Add type hints to the instrumentation + ([#3045](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3045)) - `opentelemetry-distro` default to OTLP log exporter. ([#3042](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3042)) - `opentelemetry-instrumentation-sqlalchemy` Update unit tests to run with SQLALchemy 2 diff --git a/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/py.typed b/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py b/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py index 820fa29411..5007bda50a 100644 --- a/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py @@ -170,7 +170,9 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A --- """ -from typing import Collection +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Collection, cast from starlette import applications from starlette.routing import Match @@ -184,18 +186,29 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.starlette.package import _instruments from opentelemetry.instrumentation.starlette.version import __version__ -from opentelemetry.metrics import get_meter +from opentelemetry.metrics import MeterProvider, get_meter from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import get_tracer +from opentelemetry.trace import TracerProvider, get_tracer from opentelemetry.util.http import get_excluded_urls +if TYPE_CHECKING: + from typing import TypedDict, Unpack + + class InstrumentKwargs(TypedDict, total=False): + tracer_provider: TracerProvider + meter_provider: MeterProvider + server_request_hook: ServerRequestHook + client_request_hook: ClientRequestHook + client_response_hook: ClientResponseHook + + _excluded_urls = get_excluded_urls("STARLETTE") class StarletteInstrumentor(BaseInstrumentor): - """An instrumentor for starlette + """An instrumentor for Starlette. - See `BaseInstrumentor` + See `BaseInstrumentor`. """ _original_starlette = None @@ -206,8 +219,8 @@ def instrument_app( server_request_hook: ServerRequestHook = None, client_request_hook: ClientRequestHook = None, client_response_hook: ClientResponseHook = None, - meter_provider=None, - tracer_provider=None, + meter_provider: MeterProvider | None = None, + tracer_provider: TracerProvider | None = None, ): """Instrument an uninstrumented Starlette application.""" tracer = get_tracer( @@ -253,7 +266,7 @@ def uninstrument_app(app: applications.Starlette): def instrumentation_dependencies(self) -> Collection[str]: return _instruments - def _instrument(self, **kwargs): + def _instrument(self, **kwargs: Unpack[InstrumentKwargs]): self._original_starlette = applications.Starlette _InstrumentedStarlette._tracer_provider = kwargs.get("tracer_provider") _InstrumentedStarlette._server_request_hook = kwargs.get( @@ -269,7 +282,7 @@ def _instrument(self, **kwargs): applications.Starlette = _InstrumentedStarlette - def _uninstrument(self, **kwargs): + def _uninstrument(self, **kwargs: Any): """uninstrumenting all created apps by user""" for instance in _InstrumentedStarlette._instrumented_starlette_apps: self.uninstrument_app(instance) @@ -278,14 +291,14 @@ def _uninstrument(self, **kwargs): class _InstrumentedStarlette(applications.Starlette): - _tracer_provider = None - _meter_provider = None + _tracer_provider: TracerProvider | None = None + _meter_provider: MeterProvider | None = None _server_request_hook: ServerRequestHook = None _client_request_hook: ClientRequestHook = None _client_response_hook: ClientResponseHook = None - _instrumented_starlette_apps = set() + _instrumented_starlette_apps: set[applications.Starlette] = set() - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) tracer = get_tracer( __name__, @@ -318,21 +331,22 @@ def __del__(self): _InstrumentedStarlette._instrumented_starlette_apps.remove(self) -def _get_route_details(scope): +def _get_route_details(scope: dict[str, Any]) -> str | None: """ - Function to retrieve Starlette route from scope. + Function to retrieve Starlette route from ASGI scope. TODO: there is currently no way to retrieve http.route from a starlette application from scope. See: https://github.com/encode/starlette/pull/804 Args: - scope: A Starlette scope + scope: The ASGI scope that contains the Starlette application in the "app" key. + Returns: - A string containing the route or None + The path to the route if found, otherwise None. """ - app = scope["app"] - route = None + app = cast(applications.Starlette, scope["app"]) + route: str | None = None for starlette_route in app.routes: match, _ = starlette_route.matches(scope) @@ -344,18 +358,20 @@ def _get_route_details(scope): return route -def _get_default_span_details(scope): - """ - Callback to retrieve span name and attributes from scope. +def _get_default_span_details( + scope: dict[str, Any], +) -> tuple[str, dict[str, Any]]: + """Callback to retrieve span name and attributes from ASGI scope. Args: - scope: A Starlette scope + scope: The ASGI scope that contains the Starlette application in the "app" key. + Returns: - A tuple of span name and attributes + A tuple of span name and attributes. """ route = _get_route_details(scope) - method = scope.get("method", "") - attributes = {} + method: str = scope.get("method", "") + attributes: dict[str, Any] = {} if route: attributes[SpanAttributes.HTTP_ROUTE] = route if method and route: # http