Skip to content

Commit

Permalink
Harden logic for retrieving meter + tracer providers
Browse files Browse the repository at this point in the history
  • Loading branch information
jbfenton committed Jun 7, 2024
1 parent ffbc78c commit 869f0f8
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,17 @@ def custom_event_context_extractor(lambda_event):
import os
import time
from importlib import import_module
from typing import Any, Callable, Collection
from typing import Any, Callable, Collection, Optional, cast
from urllib.parse import urlencode

from wrapt import wrap_function_wrapper

from opentelemetry.context.context import Context
from opentelemetry.instrumentation.aws_lambda.package import _instruments
from opentelemetry.instrumentation.aws_lambda.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.metrics import MeterProvider, get_meter_provider
from opentelemetry.propagate import get_global_textmap
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
AwsXRayPropagator,
TRACE_HEADER_KEY,
)
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
Expand All @@ -98,6 +94,10 @@ def custom_event_context_extractor(lambda_event):
)
from opentelemetry.trace.propagation import get_current_span
from opentelemetry.trace.status import Status, StatusCode
from wrapt import wrap_function_wrapper

from opentelemetry.instrumentation.aws_lambda.package import _instruments
from opentelemetry.instrumentation.aws_lambda.version import __version__

logger = logging.getLogger(__name__)

Expand All @@ -112,6 +112,114 @@ def custom_event_context_extractor(lambda_event):
)


class FlushableMeterProvider(MeterProvider):
"""
Interface for a MeterProvider that can be flushed.
"""

def force_flush(self, timeout: int) -> None:
"""
Force flush all meter data.
Args:
timeout: The maximum amount of time to wait for the flush to complete, in milliseconds.
Returns: None
"""

...


class FlushableTracerProvider(TracerProvider):
"""
Interface for a TracerProvider that can be flushed.
"""

def force_flush(self, timeout: int) -> None:
"""
Force flush all tracer data.
Args:
timeout: The maximum amount of time to wait for the flush to complete, in milliseconds.
Returns: None
"""

...


def provider_warning(provider_type: str):
"""
Log a warning that the provider is missing the `force_flush` method.
Args:
provider_type: Provider type to include in the warning message.
Returns: A function that mimics the `force_flush` method, that will log a warning.
"""

return lambda timeout: logger.warning(
f"{provider_type} was missing `force_flush` method. "
"This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
)


def determine_tracer_provider(
tracer_provider_override: Optional[Any],
) -> FlushableTracerProvider:
"""
Determine the TracerProvider to use for the instrumentation.
The TracerProvider must have a `force_flush` method to be used in the instrumentation.
If the TracerProvider does not have a `force_flush` method, a dummy `force_flush` method is added that logs a warning.
Args:
tracer_provider_override: Optional TracerProvider to use for the instrumentation.
Returns: TracerProvider that is flushable.
"""

if not isinstance(tracer_provider_override, TracerProvider):
tracer_provider = get_tracer_provider()
else:
tracer_provider = tracer_provider_override

if not hasattr(tracer_provider, "force_flush"):
tracer_provider.force_flush = provider_warning(
provider_type="TracerProvider"
)

return cast(FlushableTracerProvider, tracer_provider)


def determine_meter_provider(
meter_provider_override: Optional[Any],
) -> FlushableMeterProvider:
"""
Determine the MeterProvider to use for the instrumentation.
The MeterProvider must have a `force_flush` method to be used in the instrumentation.
If the MeterProvider does not have a `force_flush` method, a dummy `force_flush` method is added that logs a warning.
Args:
meter_provider_override: Optional MeterProvider to use for the instrumentation.
Returns: MeterProvider that is flushable.
"""

if not isinstance(meter_provider_override, MeterProvider):
meter_provider = get_meter_provider()
else:
meter_provider = meter_provider_override

if not hasattr(meter_provider, "force_flush"):
meter_provider.force_flush = provider_warning(
provider_type="MeterProvider"
)

return cast(FlushableMeterProvider, meter_provider)


def _default_event_context_extractor(lambda_event: Any) -> Context:
"""Default way of extracting the context from the Lambda Event.
Expand Down Expand Up @@ -279,15 +387,47 @@ def _set_api_gateway_v2_proxy_attributes(
return span


def flush(
tracer_provider: FlushableTracerProvider,
meter_provider: FlushableMeterProvider,
) -> None:
"""
Flushes the tracer and meter providers.
Args:
tracer_provider: Tracer provider to flush.
meter_provider: Meter provider to flush.
Returns: None
"""

flush_timeout = determine_flush_timeout()

now = time.time()

try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
tracer_provider.force_flush(flush_timeout)
except Exception: # pylint: disable=broad-except
logger.exception("TracerProvider failed to flush traces")

rem = int(flush_timeout - (time.time() - now) * 1000)
if rem > 0:
try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
meter_provider.force_flush(rem)
except Exception: # pylint: disable=broad-except
logger.exception("MeterProvider failed to flush metrics")


# pylint: disable=too-many-statements
def _instrument(
wrapped_module_name,
wrapped_function_name,
flush_timeout,
event_context_extractor: Callable[[Any], Context],
disable_aws_context_propagation: bool,
tracer_provider: TracerProvider = None,
meter_provider: MeterProvider = None,
tracer_provider: FlushableTracerProvider,
meter_provider: FlushableMeterProvider,
):
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
Expand Down Expand Up @@ -392,32 +532,10 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
result.get("statusCode"),
)

now = time.time()
_tracer_provider = tracer_provider or get_tracer_provider()
if hasattr(_tracer_provider, "force_flush"):
try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
_tracer_provider.force_flush(flush_timeout)
except Exception: # pylint: disable=broad-except
logger.exception("TracerProvider failed to flush traces")
else:
logger.warning(
"TracerProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
)

_meter_provider = meter_provider or get_meter_provider()
if hasattr(_meter_provider, "force_flush"):
rem = flush_timeout - (time.time() - now) * 1000
if rem > 0:
try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
_meter_provider.force_flush(rem)
except Exception: # pylint: disable=broad-except
logger.exception("MeterProvider failed to flush metrics")
else:
logger.warning(
"MeterProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
)
flush(
meter_provider=meter_provider,
tracer_provider=tracer_provider,
)

if exception is not None:
raise exception.with_traceback(exception.__traceback__)
Expand Down Expand Up @@ -526,17 +644,20 @@ def _instrument(self, **kwargs):
_instrument(
self._wrapped_module_name,
self._wrapped_function_name,
flush_timeout=determine_flush_timeout(),
event_context_extractor=kwargs.get(
"event_context_extractor", _default_event_context_extractor
),
tracer_provider=kwargs.get("tracer_provider"),
disable_aws_context_propagation=is_aws_context_propagation_disabled(
disable_aws_context_propagation_override=kwargs.get(
"disable_aws_context_propagation", False
)
),
meter_provider=kwargs.get("meter_provider"),
meter_provider=determine_meter_provider(
meter_provider_override=kwargs.get("meter_provider")
),
tracer_provider=determine_tracer_provider(
tracer_provider_override=kwargs.get("tracer_provider")
),
)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

from opentelemetry.instrumentation.aws_lambda import (
determine_flush_timeout,
flush,
is_aws_context_propagation_disabled,
)
from opentelemetry.instrumentation.aws_lambda import (
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT,
OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION,
)
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from unittest import TestCase


Expand Down Expand Up @@ -84,3 +85,96 @@ def test_env_var_t(self):
def test_env_var_not_set(self):
result = is_aws_context_propagation_disabled(False)
self.assertFalse(result)


class TestFlush(TestCase):

@patch("time.time", return_value=0)
@patch(
"opentelemetry.instrumentation.aws_lambda.determine_flush_timeout",
return_value=1000,
)
def test_successful_flush(self, mock_time, mock_determine_flush_timeout):
tracer_provider = MagicMock()
meter_provider = MagicMock()
tracer_provider.force_flush = MagicMock()
meter_provider.force_flush = MagicMock()

flush(tracer_provider, meter_provider)

tracer_provider.force_flush.assert_called_once_with(1000)
meter_provider.force_flush.assert_called_once_with(1000)

@patch("time.time", return_value=0)
@patch(
"opentelemetry.instrumentation.aws_lambda.determine_flush_timeout",
return_value=1000,
)
def test_missing_force_flush_in_tracer_provider(
self, mock_time, mock_determine_flush_timeout
):
tracer_provider = MagicMock()
del tracer_provider.force_flush # Remove force_flush method
meter_provider = MagicMock()
meter_provider.force_flush = MagicMock()

flush(tracer_provider, meter_provider)

assert "force_flush" not in dir(tracer_provider)
meter_provider.force_flush.assert_called_once_with(1000)

@patch("time.time", return_value=0)
@patch(
"opentelemetry.instrumentation.aws_lambda.determine_flush_timeout",
return_value=1000,
)
def test_missing_force_flush_in_meter_provider(
self, mock_time, mock_determine_flush_timeout
):
tracer_provider = MagicMock()
tracer_provider.force_flush = MagicMock()
meter_provider = MagicMock()
del meter_provider.force_flush # Remove force_flush method

flush(tracer_provider, meter_provider)

assert "force_flush" not in dir(meter_provider)
tracer_provider.force_flush.assert_called_once_with(1000)

@patch("time.time", return_value=0)
@patch(
"opentelemetry.instrumentation.aws_lambda.determine_flush_timeout",
return_value=1000,
)
def test_exception_in_force_flush_tracer_provider(
self, mock_time, mock_determine_flush_timeout
):
tracer_provider = MagicMock()
tracer_provider.force_flush = MagicMock(
side_effect=Exception("Tracer flush exception")
)
meter_provider = MagicMock()
meter_provider.force_flush = MagicMock()

flush(tracer_provider, meter_provider)

meter_provider.force_flush.assert_called_once_with(1000)

@patch("time.time", return_value=0)
@patch(
"opentelemetry.instrumentation.aws_lambda.determine_flush_timeout",
return_value=1000,
)
def test_exception_in_force_flush_meter_provider(
self, mock_time, mock_determine_flush_timeout
):
tracer_provider = MagicMock()
tracer_provider.force_flush = MagicMock()
meter_provider = MagicMock()
meter_provider.force_flush = MagicMock(
side_effect=Exception("Meter flush exception")
)

flush(tracer_provider, meter_provider)

tracer_provider.force_flush.assert_called_once_with(1000)

0 comments on commit 869f0f8

Please sign in to comment.