Tail sampling (#263)
alexmojaki authored Jun 21, 2024
1 parent 0829bee commit e811c48
Showing 7 changed files with 541 additions and 18 deletions.
2 changes: 2 additions & 0 deletions logfire/__init__.py
@@ -14,6 +14,7 @@
)
from ._internal.constants import LevelName
from ._internal.exporters.file import load_file as load_spans_from_file
from ._internal.exporters.tail_sampling import TailSamplingOptions
from ._internal.main import Logfire, LogfireSpan
from ._internal.scrubbing import ScrubMatch
from ._internal.utils import suppress_instrumentation
@@ -111,4 +112,5 @@ def loguru_handler() -> dict[str, Any]:
'suppress_instrumentation',
'StructlogProcessor',
'LogfireLoggingHandler',
'TailSamplingOptions',
)
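
With TailSamplingOptions exported at the package root, the feature is switched on
through logfire.configure. A minimal usage sketch (hypothetical values; the commit
itself marks the option as not ready for general use):

    import logfire

    # Keep every trace that contains a record at error level or above, or that
    # runs longer than 5 seconds; other traces are buffered and then discarded.
    logfire.configure(
        tail_sampling=logfire.TailSamplingOptions(level='error', duration=5.0),
    )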
43 changes: 40 additions & 3 deletions logfire/_internal/config.py
@@ -67,9 +67,10 @@
from .exporters.fallback import FallbackSpanExporter
from .exporters.file import FileSpanExporter
from .exporters.otlp import OTLPExporterHttpSession, RetryFewerSpansSpanExporter
from .exporters.processor_wrapper import SpanProcessorWrapper
from .exporters.processor_wrapper import MainSpanProcessorWrapper
from .exporters.quiet_metrics import QuietMetricExporter
from .exporters.remove_pending import RemovePendingSpansExporter
from .exporters.tail_sampling import TailSamplingOptions, TailSamplingProcessor
from .integrations.executors import instrument_executors
from .metrics import ProxyMeterProvider, configure_metrics
from .scrubbing import Scrubber, ScrubCallback
@@ -159,6 +160,7 @@ def configure(
scrubbing_patterns: Sequence[str] | None = None,
scrubbing_callback: ScrubCallback | None = None,
inspect_arguments: bool | None = None,
tail_sampling: TailSamplingOptions | None = None,
) -> None:
"""Configure the logfire SDK.
@@ -216,6 +218,7 @@
[f-string magic](https://docs.pydantic.dev/logfire/guides/onboarding_checklist/add_manual_tracing/#f-strings).
If `None` uses the `LOGFIRE_INSPECT_ARGUMENTS` environment variable.
Defaults to `True` if and only if the Python version is at least 3.11.
tail_sampling: Tail sampling options. Not ready for general use.
"""
if processors is not None: # pragma: no cover
raise ValueError(
@@ -253,6 +256,7 @@
scrubbing_patterns=scrubbing_patterns,
scrubbing_callback=scrubbing_callback,
inspect_arguments=inspect_arguments,
tail_sampling=tail_sampling,
)


@@ -332,6 +336,12 @@ class _LogfireConfigData:
scrubbing_callback: ScrubCallback | None
"""A function that is called for each match found by the scrubber."""

inspect_arguments: bool
"""Whether to enable f-string magic"""

tail_sampling: TailSamplingOptions | None
"""Tail sampling options"""

def _load_configuration(
self,
# note that there are no defaults here so that the only place
@@ -360,6 +370,7 @@ def _load_configuration(
scrubbing_patterns: Sequence[str] | None,
scrubbing_callback: ScrubCallback | None,
inspect_arguments: bool | None,
tail_sampling: TailSamplingOptions | None,
) -> None:
"""Merge the given parameters with the environment variables file configurations."""
param_manager = ParamManager.create(config_dir)
@@ -414,6 +425,12 @@

if get_version(pydantic.__version__) < get_version('2.5.0'): # pragma: no cover
raise RuntimeError('The Pydantic plugin requires Pydantic 2.5.0 or newer.')

if isinstance(tail_sampling, dict):
# This is particularly for deserializing from a dict as in executors.py
tail_sampling = TailSamplingOptions(**tail_sampling) # type: ignore
self.tail_sampling = tail_sampling

self.fast_shutdown = fast_shutdown

self.id_generator = id_generator or RandomIdGenerator()
@@ -457,6 +474,7 @@ def __init__(
scrubbing_patterns: Sequence[str] | None = None,
scrubbing_callback: ScrubCallback | None = None,
inspect_arguments: bool | None = None,
tail_sampling: TailSamplingOptions | None = None,
) -> None:
"""Create a new LogfireConfig.
@@ -490,6 +508,7 @@ def __init__(
scrubbing_patterns=scrubbing_patterns,
scrubbing_callback=scrubbing_callback,
inspect_arguments=inspect_arguments,
tail_sampling=tail_sampling,
)
# initialize with no-ops so that we don't impact OTEL's global config just because logfire is installed
# that is, we defer setting logfire as the otel global config until `configure` is called
@@ -527,6 +546,7 @@ def configure(
scrubbing_patterns: Sequence[str] | None,
scrubbing_callback: ScrubCallback | None,
inspect_arguments: bool | None,
tail_sampling: TailSamplingOptions | None,
) -> None:
with self._lock:
self._initialized = False
@@ -554,6 +574,7 @@
scrubbing_patterns,
scrubbing_callback,
inspect_arguments,
tail_sampling,
)
self.initialize()

@@ -592,8 +613,15 @@ def _initialize(self) -> ProxyTracerProvider:
# Both recommend generating a UUID.
resource = Resource({ResourceAttributes.SERVICE_INSTANCE_ID: uuid4().hex}).merge(resource)

# Avoid using the usual sampler if we're using tail-based sampling.
# The TailSamplingProcessor will handle the random sampling part as well.
sampler = (
ParentBasedTraceIdRatio(self.trace_sample_rate)
if self.trace_sample_rate < 1 and self.tail_sampling is None
else None
)
tracer_provider = SDKTracerProvider(
sampler=ParentBasedTraceIdRatio(self.trace_sample_rate),
sampler=sampler,
resource=resource,
id_generator=self.id_generator,
)
@@ -607,7 +635,16 @@ def add_span_processor(span_processor: SpanProcessor) -> None:
# Most span processors added to the tracer provider should also be recorded in the `processors` list
# so that they can be used by the final pending span processor.
# This means that `tracer_provider.add_span_processor` should only appear in two places.
span_processor = SpanProcessorWrapper(span_processor, self.scrubber)
if self.tail_sampling:
span_processor = TailSamplingProcessor(
span_processor,
self.tail_sampling,
# If self.trace_sample_rate < 1 then that ratio of spans should be included randomly by this.
# In that case the tracer provider doesn't need to do any sampling, see above.
# Otherwise we're not using any random sampling, so 0% of spans should be included 'randomly'.
self.trace_sample_rate if self.trace_sample_rate < 1 else 0,
)
span_processor = MainSpanProcessorWrapper(span_processor, self.scrubber)
tracer_provider.add_span_processor(span_processor)
processors.append(span_processor)

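The net effect of these two hunks: when tail sampling is enabled, the SDK's head
sampler is skipped entirely and the random part of sampling moves into
TailSamplingProcessor. A small sketch of that hand-off (illustrative, not code
from the commit):

    def choose_rates(trace_sample_rate, tail_sampling):
        # Head sampling only applies when tail sampling is off.
        use_head_sampler = trace_sample_rate < 1 and tail_sampling is None
        # Rate handed to TailSamplingProcessor for its random "include anyway" check.
        random_rate = trace_sample_rate if trace_sample_rate < 1 else 0
        return use_head_sampler, random_rate

    assert choose_rates(0.25, None) == (True, 0.25)  # head sampling as before
    assert choose_rates(1.0, None) == (False, 0)     # no sampling at all
    # With tail sampling enabled: choose_rates(0.25, options) == (False, 0.25),
    # so the processor itself includes 25% of traces at random.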
15 changes: 5 additions & 10 deletions logfire/_internal/exporters/processor_wrapper.py
@@ -16,17 +16,18 @@
)
from ..scrubbing import Scrubber
from ..utils import ReadableSpanDict, is_instrumentation_suppressed, span_to_dict, truncate_string
from .wrapper import WrapperSpanProcessor


class SpanProcessorWrapper(SpanProcessor):
class MainSpanProcessorWrapper(WrapperSpanProcessor):
"""Wrapper around other processors to intercept starting and ending spans with our own global logic.
Suppresses starting/ending if the current context has a `suppress_instrumentation` value.
Tweaks the send/receive span names generated by the ASGI middleware.
"""

def __init__(self, processor: SpanProcessor, scrubber: Scrubber) -> None:
self.processor = processor
super().__init__(processor)
self.scrubber = scrubber

def on_start(
@@ -37,7 +38,7 @@ def on_start(
if is_instrumentation_suppressed():
return
_set_log_level_on_asgi_send_receive_spans(span)
self.processor.on_start(span, parent_context)
super().on_start(span, parent_context)

def on_end(self, span: ReadableSpan) -> None:
if is_instrumentation_suppressed():
@@ -47,13 +48,7 @@ def on_end(self, span: ReadableSpan) -> None:
_tweak_http_spans(span_dict)
self.scrubber.scrub_span(span_dict)
span = ReadableSpan(**span_dict)
self.processor.on_end(span)

def shutdown(self) -> None:
self.processor.shutdown()

def force_flush(self, timeout_millis: int = 30000) -> bool:
return self.processor.force_flush(timeout_millis) # pragma: no cover
super().on_end(span)


def _set_log_level_on_asgi_send_receive_spans(span: Span) -> None:
145 changes: 145 additions & 0 deletions logfire/_internal/exporters/tail_sampling.py
@@ -0,0 +1,145 @@
from __future__ import annotations

import random
from dataclasses import dataclass
from functools import cached_property
from threading import Lock

from opentelemetry import context
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor

from logfire._internal.constants import (
ATTRIBUTES_LOG_LEVEL_NUM_KEY,
LEVEL_NUMBERS,
ONE_SECOND_IN_NANOSECONDS,
LevelName,
)
from logfire._internal.exporters.wrapper import WrapperSpanProcessor


@dataclass
class TailSamplingOptions:
level: LevelName | None = 'notice'
"""
Include all spans/logs with level greater than or equal to this level.
If None, spans are not included based on level.
"""

duration: float | None = 1.0
"""
Include all spans/logs with duration greater than this duration in seconds.
If None, spans are not included based on duration.
"""


@dataclass
class TraceBuffer:
"""Arguments of `on_start` and `on_end` for spans in a single trace."""

started: list[tuple[Span, context.Context | None]]
ended: list[ReadableSpan]

@cached_property
def first_span(self) -> Span:
return self.started[0][0]


class TailSamplingProcessor(WrapperSpanProcessor):
"""Passes spans to the wrapped processor if any span in a trace meets the sampling criteria."""

def __init__(self, processor: SpanProcessor, options: TailSamplingOptions, random_rate: float) -> None:
super().__init__(processor)
self.duration: float = (
float('inf') if options.duration is None else options.duration * ONE_SECOND_IN_NANOSECONDS
)
self.level: float | int = float('inf') if options.level is None else LEVEL_NUMBERS[options.level]
self.random_rate = random_rate

# A TraceBuffer is typically created for each new trace.
# If a span meets the sampling criteria, the buffer is dropped and all spans within are pushed
# to the wrapped processor.
# So when more spans arrive and there's no buffer, they get passed through immediately.
self.traces: dict[int, TraceBuffer] = {}

# Code that touches self.traces and its contents should be protected by this lock.
self.lock = Lock()

def on_start(self, span: Span, parent_context: context.Context | None = None) -> None:
dropped = False
buffer = None

with self.lock:
# span.context could supposedly be None, not sure how.
if span.context: # pragma: no branch
trace_id = span.context.trace_id
# If span.parent is None, it's the root span of a trace.
# If random.random() <= self.random_rate, immediately include this trace,
# meaning no buffer for it.
if span.parent is None and random.random() > self.random_rate:
self.traces[trace_id] = TraceBuffer([], [])

buffer = self.traces.get(trace_id)
if buffer is not None:
# This trace's spans haven't met the criteria yet, so add this span to the buffer.
buffer.started.append((span, parent_context))
dropped = self.check_span(span, buffer)
# The opposite case is handled outside the lock since it may take some time.

# This code may take longer since it calls the wrapped processor which might do anything.
# It shouldn't be inside the lock to avoid blocking other threads.
# Since it's not in the lock, it shouldn't touch self.traces or its contents.
if buffer is None:
super().on_start(span, parent_context)
elif dropped:
self.push_buffer(buffer)

def on_end(self, span: ReadableSpan) -> None:
# This has a very similar structure and reasoning to on_start.

dropped = False
buffer = None

with self.lock:
if span.context: # pragma: no branch
trace_id = span.context.trace_id
buffer = self.traces.get(trace_id)
if buffer is not None:
buffer.ended.append(span)
dropped = self.check_span(span, buffer)
if span.parent is None:
# This is the root span, so the trace is hopefully complete.
# Delete the buffer to save memory.
self.traces.pop(trace_id, None)

if buffer is None:
super().on_end(span)
elif dropped:
self.push_buffer(buffer)

def check_span(self, span: ReadableSpan, buffer: TraceBuffer) -> bool:
"""If the span meets the sampling criteria, drop the buffer and return True. Otherwise, return False."""
# span.end_time and span.start_time are in nanoseconds and can be None.
if (span.end_time or span.start_time or 0) - (buffer.first_span.start_time or float('inf')) > self.duration:
self.drop_buffer(buffer)
return True

attributes = span.attributes or {}
level = attributes.get(ATTRIBUTES_LOG_LEVEL_NUM_KEY)
if not isinstance(level, int):
level = LEVEL_NUMBERS['info']
if level >= self.level:
self.drop_buffer(buffer)
return True

return False

def drop_buffer(self, buffer: TraceBuffer) -> None:
span_context = buffer.first_span.context
assert span_context is not None
del self.traces[span_context.trace_id]

def push_buffer(self, buffer: TraceBuffer) -> None:
for started in buffer.started:
super().on_start(*started)
for span in buffer.ended:
super().on_end(span)
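
To see the buffering behaviour in isolation, the new processor can be wired into
a plain OpenTelemetry TracerProvider. A minimal sketch, assuming the internal
import path from this diff and the SDK's in-memory exporter (not part of the
commit):

    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

    from logfire._internal.exporters.tail_sampling import TailSamplingOptions, TailSamplingProcessor

    exporter = InMemorySpanExporter()
    processor = TailSamplingProcessor(
        SimpleSpanProcessor(exporter),
        TailSamplingOptions(level='error', duration=None),  # keep only traces containing an error
        random_rate=0,  # never include a trace at random
    )
    provider = TracerProvider()
    provider.add_span_processor(processor)
    tracer = provider.get_tracer('demo')

    # This root span carries no error-level record and no duration criterion is
    # set, so the whole trace sits in its buffer and is discarded when the root
    # span ends.
    with tracer.start_as_current_span('fast_and_quiet'):
        pass
    assert exporter.get_finished_spans() == ()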
30 changes: 26 additions & 4 deletions logfire/_internal/exporters/wrapper.py
@@ -1,8 +1,11 @@
from typing import Any, Dict, Optional, Sequence
from __future__ import annotations

from typing import Any, Sequence

from opentelemetry import context
from opentelemetry.sdk.metrics.export import AggregationTemporality, MetricExporter, MetricExportResult, MetricsData
from opentelemetry.sdk.metrics.view import Aggregation
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult


@@ -28,8 +31,8 @@ class WrapperMetricExporter(MetricExporter):
def __init__(
self,
exporter: MetricExporter,
preferred_temporality: Optional[Dict[type, AggregationTemporality]] = None,
preferred_aggregation: Optional[Dict[type, Aggregation]] = None,
preferred_temporality: dict[type, AggregationTemporality] | None = None,
preferred_aggregation: dict[type, Aggregation] | None = None,
) -> None:
super().__init__(preferred_temporality=preferred_temporality, preferred_aggregation=preferred_aggregation) # type: ignore
self.wrapped_exporter = exporter
@@ -42,3 +45,22 @@ def force_flush(self, timeout_millis: float = 10_000) -> bool:

def shutdown(self, timeout_millis: float = 30_000, **kwargs: Any) -> None:
self.wrapped_exporter.shutdown(timeout_millis, **kwargs) # type: ignore


class WrapperSpanProcessor(SpanProcessor):
"""A base class for SpanProcessors that wrap another processor."""

def __init__(self, processor: SpanProcessor) -> None:
self.processor = processor

def on_start(self, span: Span, parent_context: context.Context | None = None) -> None:
self.processor.on_start(span, parent_context)

def on_end(self, span: ReadableSpan) -> None:
self.processor.on_end(span)

def shutdown(self) -> None:
self.processor.shutdown()

def force_flush(self, timeout_millis: int = 30000) -> bool:
return self.processor.force_flush(timeout_millis) # pragma: no cover
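
The new base class makes delegating processors one-liners to subclass;
TailSamplingProcessor above is its first user. As another, purely hypothetical
illustration (reusing the names defined in this file), a processor that filters
ended spans by name:

    class DropHealthCheckSpans(WrapperSpanProcessor):
        """Hypothetical example: suppress spans from health-check endpoints."""

        def on_end(self, span: ReadableSpan) -> None:
            if not span.name.startswith('GET /health'):
                super().on_end(span)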
