From 869f0f8678b60d05b9da07bf22b66a8740e33ba5 Mon Sep 17 00:00:00 2001 From: Joshua Fenton Date: Mon, 20 May 2024 16:52:56 +1000 Subject: [PATCH] Harden logic for retrieving meter + tracer providers --- .../instrumentation/aws_lambda/__init__.py | 197 ++++++++++++++---- .../tests/test_util_functions.py | 96 ++++++++- 2 files changed, 254 insertions(+), 39 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py index b7b74edec1..a9b7bdbb7b 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py @@ -71,21 +71,17 @@ def custom_event_context_extractor(lambda_event): import os import time from importlib import import_module -from typing import Any, Callable, Collection +from typing import Any, Callable, Collection, Optional, cast from urllib.parse import urlencode -from wrapt import wrap_function_wrapper - from opentelemetry.context.context import Context -from opentelemetry.instrumentation.aws_lambda.package import _instruments -from opentelemetry.instrumentation.aws_lambda.version import __version__ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap from opentelemetry.metrics import MeterProvider, get_meter_provider from opentelemetry.propagate import get_global_textmap from opentelemetry.propagators.aws.aws_xray_propagator import ( - TRACE_HEADER_KEY, AwsXRayPropagator, + TRACE_HEADER_KEY, ) from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.trace import SpanAttributes @@ -98,6 +94,10 @@ def custom_event_context_extractor(lambda_event): ) from 
class FlushableMeterProvider(MeterProvider):
    """
    Interface for a MeterProvider that can be flushed.

    Used in this module only as a static type (annotations and ``cast``).
    """

    def force_flush(self, timeout: int) -> None:
        """
        Force flush all meter data.

        Args:
            timeout: The maximum amount of time to wait for the flush to complete, in milliseconds.

        Returns: None
        """

        ...


class FlushableTracerProvider(TracerProvider):
    """
    Interface for a TracerProvider that can be flushed.

    Used in this module only as a static type (annotations and ``cast``).
    """

    def force_flush(self, timeout: int) -> None:
        """
        Force flush all tracer data.

        Args:
            timeout: The maximum amount of time to wait for the flush to complete, in milliseconds.

        Returns: None
        """

        ...


def provider_warning(provider_type: str):
    """
    Build a stand-in for a missing `force_flush` method.

    Args:
        provider_type: Provider type to include in the warning message.

    Returns: A callable that mimics `force_flush`: it accepts (and ignores) any
        positional or keyword arguments, logs a warning and returns None.
    """

    def _warn_missing_force_flush(*_args: Any, **_kwargs: Any) -> None:
        # Any call shape is accepted so the stub cannot raise TypeError when it
        # is invoked without a timeout or with a differently named keyword.
        logger.warning(
            "%s was missing `force_flush` method. "
            "This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation.",
            provider_type,
        )

    return _warn_missing_force_flush


def _ensure_force_flush(provider: Any, provider_type: str) -> None:
    """Attach a warning-only `force_flush` to ``provider`` if it has none."""

    if hasattr(provider, "force_flush"):
        return

    try:
        provider.force_flush = provider_warning(provider_type=provider_type)
    except (AttributeError, TypeError):
        # The object rejects new attributes. `flush` below already logs any
        # exception raised while calling `force_flush`, so degrade to a warning
        # here instead of failing instrumentation.
        logger.warning(
            "%s has no `force_flush` method and a fallback could not be attached.",
            provider_type,
        )


def determine_tracer_provider(
    tracer_provider_override: Optional[Any],
) -> FlushableTracerProvider:
    """
    Determine the TracerProvider to use for the instrumentation.

    The override is used only when it is a `TracerProvider` instance; otherwise the
    global provider from `get_tracer_provider()` is used, and a warning is logged if
    a non-None override had to be ignored. If the chosen provider has no
    `force_flush` attribute, a dummy one that logs a warning is attached.

    Args:
        tracer_provider_override: Optional TracerProvider to use for the instrumentation.

    Returns: TracerProvider that is flushable.
    """

    if isinstance(tracer_provider_override, TracerProvider):
        tracer_provider = tracer_provider_override
    else:
        if tracer_provider_override is not None:
            logger.warning(
                "Ignoring `tracer_provider` of type %s: not a TracerProvider. "
                "Falling back to the global TracerProvider.",
                type(tracer_provider_override).__name__,
            )
        tracer_provider = get_tracer_provider()

    _ensure_force_flush(tracer_provider, "TracerProvider")

    return cast(FlushableTracerProvider, tracer_provider)


def determine_meter_provider(
    meter_provider_override: Optional[Any],
) -> FlushableMeterProvider:
    """
    Determine the MeterProvider to use for the instrumentation.

    The override is used only when it is a `MeterProvider` instance; otherwise the
    global provider from `get_meter_provider()` is used, and a warning is logged if
    a non-None override had to be ignored. If the chosen provider has no
    `force_flush` attribute, a dummy one that logs a warning is attached.

    Args:
        meter_provider_override: Optional MeterProvider to use for the instrumentation.

    Returns: MeterProvider that is flushable.
    """

    if isinstance(meter_provider_override, MeterProvider):
        meter_provider = meter_provider_override
    else:
        if meter_provider_override is not None:
            logger.warning(
                "Ignoring `meter_provider` of type %s: not a MeterProvider. "
                "Falling back to the global MeterProvider.",
                type(meter_provider_override).__name__,
            )
        meter_provider = get_meter_provider()

    _ensure_force_flush(meter_provider, "MeterProvider")

    return cast(FlushableMeterProvider, meter_provider)


def flush(
    tracer_provider: FlushableTracerProvider,
    meter_provider: FlushableMeterProvider,
) -> None:
    """
    Flushes the tracer and meter providers.

    Both flushes share one time budget obtained from `determine_flush_timeout()`:
    traces are flushed first, and metrics get whatever is left. Exceptions raised
    by either `force_flush` are logged and swallowed.

    Args:
        tracer_provider: Tracer provider to flush.
        meter_provider: Meter provider to flush.

    Returns: None
    """

    flush_timeout = determine_flush_timeout()

    # NOTE: `time.time` is kept as the clock because the unit tests patch it.
    started = time.time()

    try:
        # NOTE: `force_flush` before function quit in case of Lambda freeze.
        tracer_provider.force_flush(flush_timeout)
    except Exception:  # pylint: disable=broad-except
        logger.exception("TracerProvider failed to flush traces")

    # Elapsed seconds are converted to milliseconds to match `flush_timeout`.
    rem = int(flush_timeout - (time.time() - started) * 1000)
    if rem <= 0:
        # Surface the skipped flush instead of dropping metrics silently.
        logger.warning(
            "Flush timeout of %s ms was used up flushing traces; skipping MeterProvider flush",
            flush_timeout,
        )
        return

    try:
        # NOTE: `force_flush` before function quit in case of Lambda freeze.
        meter_provider.force_flush(rem)
    except Exception:  # pylint: disable=broad-except
        logger.exception("MeterProvider failed to flush metrics")
- _meter_provider.force_flush(rem) - except Exception: # pylint: disable=broad-except - logger.exception("MeterProvider failed to flush metrics") - else: - logger.warning( - "MeterProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation." - ) + flush( + meter_provider=meter_provider, + tracer_provider=tracer_provider, + ) if exception is not None: raise exception.with_traceback(exception.__traceback__) @@ -526,17 +644,20 @@ def _instrument(self, **kwargs): _instrument( self._wrapped_module_name, self._wrapped_function_name, - flush_timeout=determine_flush_timeout(), event_context_extractor=kwargs.get( "event_context_extractor", _default_event_context_extractor ), - tracer_provider=kwargs.get("tracer_provider"), disable_aws_context_propagation=is_aws_context_propagation_disabled( disable_aws_context_propagation_override=kwargs.get( "disable_aws_context_propagation", False ) ), - meter_provider=kwargs.get("meter_provider"), + meter_provider=determine_meter_provider( + meter_provider_override=kwargs.get("meter_provider") + ), + tracer_provider=determine_tracer_provider( + tracer_provider_override=kwargs.get("tracer_provider") + ), ) def _uninstrument(self, **kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_util_functions.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_util_functions.py index 4a0d14b4ce..fc5a2b909a 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_util_functions.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_util_functions.py @@ -14,13 +14,14 @@ from opentelemetry.instrumentation.aws_lambda import ( determine_flush_timeout, + flush, is_aws_context_propagation_disabled, ) from opentelemetry.instrumentation.aws_lambda import ( OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION, ) -from unittest.mock import patch 
from unittest import TestCase
from unittest.mock import MagicMock, patch


class TestFlush(TestCase):
    """Unit tests for `flush`, run with a frozen clock and a 1000 ms flush timeout."""

    # `logger = logging.getLogger(__name__)` in the package `__init__` gives this name.
    LOGGER_NAME = "opentelemetry.instrumentation.aws_lambda"

    def setUp(self):
        # Patchers are started here rather than stacked as decorators: stacked
        # `@patch` decorators inject mocks bottom-up, which makes positional
        # mock parameters easy to bind to the wrong patch.
        for patcher in (
            patch("time.time", return_value=0),
            patch(
                "opentelemetry.instrumentation.aws_lambda.determine_flush_timeout",
                return_value=1000,
            ),
        ):
            patcher.start()
            self.addCleanup(patcher.stop)

    def test_successful_flush(self):
        tracer_provider = MagicMock()
        meter_provider = MagicMock()

        flush(tracer_provider, meter_provider)

        tracer_provider.force_flush.assert_called_once_with(1000)
        # The clock is frozen, so the whole budget is still available.
        meter_provider.force_flush.assert_called_once_with(1000)

    def test_missing_force_flush_in_tracer_provider(self):
        tracer_provider = MagicMock()
        del tracer_provider.force_flush  # Remove force_flush method
        meter_provider = MagicMock()

        # The AttributeError raised by the missing method is logged, not propagated.
        with self.assertLogs(self.LOGGER_NAME, level="ERROR") as logs:
            flush(tracer_provider, meter_provider)

        self.assertIn("TracerProvider failed to flush traces", logs.output[0])
        self.assertFalse(hasattr(tracer_provider, "force_flush"))
        meter_provider.force_flush.assert_called_once_with(1000)

    def test_missing_force_flush_in_meter_provider(self):
        tracer_provider = MagicMock()
        meter_provider = MagicMock()
        del meter_provider.force_flush  # Remove force_flush method

        with self.assertLogs(self.LOGGER_NAME, level="ERROR") as logs:
            flush(tracer_provider, meter_provider)

        self.assertIn("MeterProvider failed to flush metrics", logs.output[0])
        self.assertFalse(hasattr(meter_provider, "force_flush"))
        tracer_provider.force_flush.assert_called_once_with(1000)

    def test_exception_in_force_flush_tracer_provider(self):
        tracer_provider = MagicMock()
        tracer_provider.force_flush.side_effect = Exception(
            "Tracer flush exception"
        )
        meter_provider = MagicMock()

        with self.assertLogs(self.LOGGER_NAME, level="ERROR"):
            flush(tracer_provider, meter_provider)

        # A failing trace flush must not prevent the metrics flush.
        meter_provider.force_flush.assert_called_once_with(1000)

    def test_exception_in_force_flush_meter_provider(self):
        tracer_provider = MagicMock()
        meter_provider = MagicMock()
        meter_provider.force_flush.side_effect = Exception(
            "Meter flush exception"
        )

        with self.assertLogs(self.LOGGER_NAME, level="ERROR"):
            flush(tracer_provider, meter_provider)

        tracer_provider.force_flush.assert_called_once_with(1000)
        meter_provider.force_flush.assert_called_once_with(1000)

    def test_meter_flush_skipped_when_timeout_exhausted(self):
        tracer_provider = MagicMock()
        meter_provider = MagicMock()

        # First reading is 0 s, every later reading is 2 s: the 1000 ms budget
        # is spent before the metrics flush. A callable is used instead of a
        # finite list because logging may also read the clock.
        readings = iter([0])
        with patch("time.time", side_effect=lambda: next(readings, 2)):
            flush(tracer_provider, meter_provider)

        tracer_provider.force_flush.assert_called_once_with(1000)
        meter_provider.force_flush.assert_not_called()