Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tail sampling #263

Merged
merged 15 commits into from
Jun 21, 2024
2 changes: 2 additions & 0 deletions logfire/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from ._internal.constants import LevelName
from ._internal.exporters.file import load_file as load_spans_from_file
from ._internal.exporters.tail_sampling import TailSamplingOptions
from ._internal.main import Logfire, LogfireSpan
from ._internal.scrubbing import ScrubMatch
from ._internal.utils import suppress_instrumentation
Expand Down Expand Up @@ -111,4 +112,5 @@ def loguru_handler() -> dict[str, Any]:
'suppress_instrumentation',
'StructlogProcessor',
'LogfireLoggingHandler',
'TailSamplingOptions',
)
43 changes: 40 additions & 3 deletions logfire/_internal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@
from .exporters.fallback import FallbackSpanExporter
from .exporters.file import FileSpanExporter
from .exporters.otlp import OTLPExporterHttpSession, RetryFewerSpansSpanExporter
from .exporters.processor_wrapper import SpanProcessorWrapper
from .exporters.processor_wrapper import MainSpanProcessorWrapper
from .exporters.quiet_metrics import QuietMetricExporter
from .exporters.remove_pending import RemovePendingSpansExporter
from .exporters.tail_sampling import TailSamplingOptions, TailSamplingProcessor
from .integrations.executors import instrument_executors
from .metrics import ProxyMeterProvider, configure_metrics
from .scrubbing import Scrubber, ScrubCallback
Expand Down Expand Up @@ -159,6 +160,7 @@ def configure(
scrubbing_patterns: Sequence[str] | None = None,
scrubbing_callback: ScrubCallback | None = None,
inspect_arguments: bool | None = None,
tail_sampling: TailSamplingOptions | None = None,
) -> None:
"""Configure the logfire SDK.

Expand Down Expand Up @@ -216,6 +218,7 @@ def configure(
[f-string magic](https://docs.pydantic.dev/logfire/guides/onboarding_checklist/add_manual_tracing/#f-strings).
If `None` uses the `LOGFIRE_INSPECT_ARGUMENTS` environment variable.
Defaults to `True` if and only if the Python version is at least 3.11.
tail_sampling: Tail sampling options. Not ready for general use.
"""
if processors is not None: # pragma: no cover
raise ValueError(
Expand Down Expand Up @@ -253,6 +256,7 @@ def configure(
scrubbing_patterns=scrubbing_patterns,
scrubbing_callback=scrubbing_callback,
inspect_arguments=inspect_arguments,
tail_sampling=tail_sampling,
)


Expand Down Expand Up @@ -332,6 +336,12 @@ class _LogfireConfigData:
scrubbing_callback: ScrubCallback | None
"""A function that is called for each match found by the scrubber."""

inspect_arguments: bool
"""Whether to enable f-string magic"""

tail_sampling: TailSamplingOptions | None
"""Tail sampling options"""
Comment on lines +342 to +343
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could just be a typeddict?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that, I've also wanted ConsoleOptions and PydanticPlugin to be typed dicts, but they seemed worse. They can't define defaults in the class, and passing in a plain dict is actually not that user friendly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense


def _load_configuration(
self,
# note that there are no defaults here so that the only place
Expand Down Expand Up @@ -360,6 +370,7 @@ def _load_configuration(
scrubbing_patterns: Sequence[str] | None,
scrubbing_callback: ScrubCallback | None,
inspect_arguments: bool | None,
tail_sampling: TailSamplingOptions | None,
) -> None:
"""Merge the given parameters with the environment variables file configurations."""
param_manager = ParamManager.create(config_dir)
Expand Down Expand Up @@ -414,6 +425,12 @@ def _load_configuration(

if get_version(pydantic.__version__) < get_version('2.5.0'): # pragma: no cover
raise RuntimeError('The Pydantic plugin requires Pydantic 2.5.0 or newer.')

if isinstance(tail_sampling, dict):
# This is particularly for deserializing from a dict as in executors.py
tail_sampling = TailSamplingOptions(**tail_sampling) # type: ignore
self.tail_sampling = tail_sampling

self.fast_shutdown = fast_shutdown

self.id_generator = id_generator or RandomIdGenerator()
Expand Down Expand Up @@ -457,6 +474,7 @@ def __init__(
scrubbing_patterns: Sequence[str] | None = None,
scrubbing_callback: ScrubCallback | None = None,
inspect_arguments: bool | None = None,
tail_sampling: TailSamplingOptions | None = None,
) -> None:
"""Create a new LogfireConfig.

Expand Down Expand Up @@ -490,6 +508,7 @@ def __init__(
scrubbing_patterns=scrubbing_patterns,
scrubbing_callback=scrubbing_callback,
inspect_arguments=inspect_arguments,
tail_sampling=tail_sampling,
)
# initialize with no-ops so that we don't impact OTEL's global config just because logfire is installed
# that is, we defer setting logfire as the otel global config until `configure` is called
Expand Down Expand Up @@ -527,6 +546,7 @@ def configure(
scrubbing_patterns: Sequence[str] | None,
scrubbing_callback: ScrubCallback | None,
inspect_arguments: bool | None,
tail_sampling: TailSamplingOptions | None,
) -> None:
with self._lock:
self._initialized = False
Expand Down Expand Up @@ -554,6 +574,7 @@ def configure(
scrubbing_patterns,
scrubbing_callback,
inspect_arguments,
tail_sampling,
)
self.initialize()

Expand Down Expand Up @@ -592,8 +613,15 @@ def _initialize(self) -> ProxyTracerProvider:
# Both recommend generating a UUID.
resource = Resource({ResourceAttributes.SERVICE_INSTANCE_ID: uuid4().hex}).merge(resource)

# Avoid using the usual sampler if we're using tail-based sampling.
# The TailSamplingProcessor will handle the random sampling part as well.
sampler = (
ParentBasedTraceIdRatio(self.trace_sample_rate)
if self.trace_sample_rate < 1 and self.tail_sampling is None
else None
)
Comment on lines +616 to +622
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we document / tell users they can't combine these? Is there a world where I want to head sample down to 10% (to reduce overhead in the SDK) and then tail sample down to 1%? There's still advantages to head sampling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They can combine them, see the PR body or the new test_random_sampling.

Is there a world where I want to head sample down to 10% (to reduce overhead in the SDK) and then tail sample down to 1%?

I don't know what this means. Tail sample down to a percentage?

It sounds like you want to be able to discard most spans up front randomly regardless of whether tail-sampling would include them, so that a span only gets through if it's 'notable' AND 'lucky'.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly

tracer_provider = SDKTracerProvider(
sampler=ParentBasedTraceIdRatio(self.trace_sample_rate),
sampler=sampler,
resource=resource,
id_generator=self.id_generator,
)
Expand All @@ -607,7 +635,16 @@ def add_span_processor(span_processor: SpanProcessor) -> None:
# Most span processors added to the tracer provider should also be recorded in the `processors` list
# so that they can be used by the final pending span processor.
# This means that `tracer_provider.add_span_processor` should only appear in two places.
span_processor = SpanProcessorWrapper(span_processor, self.scrubber)
if self.tail_sampling:
span_processor = TailSamplingProcessor(
span_processor,
self.tail_sampling,
# If self.trace_sample_rate < 1 then that ratio of spans should be included randomly by this.
# In that case the tracer provider doesn't need to do any sampling, see above.
# Otherwise we're not using any random sampling, so 0% of spans should be included 'randomly'.
self.trace_sample_rate if self.trace_sample_rate < 1 else 0,
)
span_processor = MainSpanProcessorWrapper(span_processor, self.scrubber)
tracer_provider.add_span_processor(span_processor)
processors.append(span_processor)

Expand Down
15 changes: 5 additions & 10 deletions logfire/_internal/exporters/processor_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
)
from ..scrubbing import Scrubber
from ..utils import ReadableSpanDict, is_instrumentation_suppressed, span_to_dict, truncate_string
from .wrapper import WrapperSpanProcessor


class SpanProcessorWrapper(SpanProcessor):
class MainSpanProcessorWrapper(WrapperSpanProcessor):
"""Wrapper around other processors to intercept starting and ending spans with our own global logic.

Suppresses starting/ending if the current context has a `suppress_instrumentation` value.
Tweaks the send/receive span names generated by the ASGI middleware.
"""

def __init__(self, processor: SpanProcessor, scrubber: Scrubber) -> None:
self.processor = processor
super().__init__(processor)
self.scrubber = scrubber

def on_start(
Expand All @@ -37,7 +38,7 @@ def on_start(
if is_instrumentation_suppressed():
return
_set_log_level_on_asgi_send_receive_spans(span)
self.processor.on_start(span, parent_context)
super().on_start(span, parent_context)

def on_end(self, span: ReadableSpan) -> None:
if is_instrumentation_suppressed():
Expand All @@ -47,13 +48,7 @@ def on_end(self, span: ReadableSpan) -> None:
_tweak_http_spans(span_dict)
self.scrubber.scrub_span(span_dict)
span = ReadableSpan(**span_dict)
self.processor.on_end(span)

def shutdown(self) -> None:
self.processor.shutdown()

def force_flush(self, timeout_millis: int = 30000) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the method we were using to get logfire to work with AWS lambda? If so, I guess we need it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's moved to the base class.

return self.processor.force_flush(timeout_millis) # pragma: no cover
super().on_end(span)


def _set_log_level_on_asgi_send_receive_spans(span: Span) -> None:
Expand Down
145 changes: 145 additions & 0 deletions logfire/_internal/exporters/tail_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

import random
from dataclasses import dataclass
from functools import cached_property
from threading import Lock

from opentelemetry import context
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor

from logfire._internal.constants import (
ATTRIBUTES_LOG_LEVEL_NUM_KEY,
LEVEL_NUMBERS,
ONE_SECOND_IN_NANOSECONDS,
LevelName,
)
from logfire._internal.exporters.wrapper import WrapperSpanProcessor


@dataclass
class TailSamplingOptions:
    """Criteria for including a whole trace when tail sampling is enabled.

    A trace is included if ANY span/log in it meets at least one of the
    criteria below (see `TailSamplingProcessor.check_span`).
    """

    level: LevelName | None = 'notice'
    """
    Include all spans/logs with level greater than or equal to this level.
    If None, spans are not included based on level.
    """

    duration: float | None = 1.0
    """
    Include all spans/logs with duration greater than this duration in seconds.
    If None, spans are not included based on duration.
    """


@dataclass
class TraceBuffer:
    """Arguments of `on_start` and `on_end` for spans in a single trace."""

    # (span, parent_context) pairs in the order `on_start` received them.
    started: list[tuple[Span, context.Context | None]]
    # Ended spans in the order `on_end` received them.
    ended: list[ReadableSpan]

    @cached_property
    def first_span(self) -> Span:
        # The first span started in this trace, i.e. its root span.
        span, _parent_context = self.started[0]
        return span


class TailSamplingProcessor(WrapperSpanProcessor):
    """Passes spans to the wrapped processor if any span in a trace meets the sampling criteria."""

    def __init__(self, processor: SpanProcessor, options: TailSamplingOptions, random_rate: float) -> None:
        """Wrap `processor`, buffering each trace until `options` or `random_rate` admit it.

        `random_rate` is the probability that a trace is included up front,
        without any buffering at all.
        """
        super().__init__(processor)

        # Thresholds beyond which a span causes its whole trace to be included.
        # `inf` disables a criterion. Span times are in nanoseconds, so the
        # configured duration (seconds) is converted here once.
        if options.duration is None:
            self.duration: float = float('inf')
        else:
            self.duration = options.duration * ONE_SECOND_IN_NANOSECONDS
        if options.level is None:
            self.level: float | int = float('inf')
        else:
            self.level = LEVEL_NUMBERS[options.level]

        self.random_rate = random_rate

        # A TraceBuffer is typically created for each new trace.
        # If a span meets the sampling criteria, the buffer is dropped and all spans within are pushed
        # to the wrapped processor.
        # So when more spans arrive and there's no buffer, they get passed through immediately.
        self.traces: dict[int, TraceBuffer] = {}

        # Code that touches self.traces and its contents should be protected by this lock.
        self.lock = Lock()

    def on_start(self, span: Span, parent_context: context.Context | None = None) -> None:
        met_criteria = False
        buffer = None

        with self.lock:
            # span.context could supposedly be None, not sure how.
            if span.context:  # pragma: no branch
                trace_id = span.context.trace_id
                # A root span (span.parent is None) decides the trace's fate:
                # if random.random() <= self.random_rate the trace is included
                # immediately (no buffer is created for it); otherwise start buffering.
                if span.parent is None and random.random() > self.random_rate:
                    self.traces[trace_id] = TraceBuffer([], [])

                buffer = self.traces.get(trace_id)
                if buffer is not None:
                    # This trace hasn't met the criteria yet: buffer the span and re-check.
                    buffer.started.append((span, parent_context))
                    met_criteria = self.check_span(span, buffer)
                # The opposite case is handled outside the lock since it may take some time.

        # Calls into the wrapped processor might do anything and take a while,
        # so they happen outside the lock to avoid blocking other threads.
        # Being outside the lock, they must not touch self.traces or its contents.
        if buffer is None:
            super().on_start(span, parent_context)
        elif met_criteria:
            self.push_buffer(buffer)

    def on_end(self, span: ReadableSpan) -> None:
        # Mirrors the structure and locking reasoning of on_start.
        met_criteria = False
        buffer = None

        with self.lock:
            if span.context:  # pragma: no branch
                trace_id = span.context.trace_id
                buffer = self.traces.get(trace_id)
                if buffer is not None:
                    buffer.ended.append(span)
                    met_criteria = self.check_span(span, buffer)
                    if span.parent is None:
                        # The root span ended, so the trace is hopefully complete.
                        # Delete the buffer to save memory.
                        self.traces.pop(trace_id, None)

        if buffer is None:
            super().on_end(span)
        elif met_criteria:
            self.push_buffer(buffer)

    def check_span(self, span: ReadableSpan, buffer: TraceBuffer) -> bool:
        """If the span meets the sampling criteria, drop the buffer and return True. Otherwise, return False."""
        # span.end_time and span.start_time are in nanoseconds and can be None.
        latest = span.end_time or span.start_time or 0
        earliest = buffer.first_span.start_time
        if earliest is None:
            # An unknown trace start can never exceed the duration threshold.
            earliest = float('inf')
        if latest - earliest > self.duration:
            self.drop_buffer(buffer)
            return True

        # Spans without an explicit level attribute are treated as 'info'.
        level = (span.attributes or {}).get(ATTRIBUTES_LOG_LEVEL_NUM_KEY)
        if not isinstance(level, int):
            level = LEVEL_NUMBERS['info']
        if level >= self.level:
            self.drop_buffer(buffer)
            return True

        return False

    def drop_buffer(self, buffer: TraceBuffer) -> None:
        """Stop buffering this trace. Called while holding self.lock (via check_span)."""
        first_context = buffer.first_span.context
        assert first_context is not None
        del self.traces[first_context.trace_id]

    def push_buffer(self, buffer: TraceBuffer) -> None:
        """Replay every buffered span event into the wrapped processor."""
        for start_args in buffer.started:
            super().on_start(*start_args)
        for ended_span in buffer.ended:
            super().on_end(ended_span)
30 changes: 26 additions & 4 deletions logfire/_internal/exporters/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any, Dict, Optional, Sequence
from __future__ import annotations

from typing import Any, Sequence

from opentelemetry import context
from opentelemetry.sdk.metrics.export import AggregationTemporality, MetricExporter, MetricExportResult, MetricsData
from opentelemetry.sdk.metrics.view import Aggregation
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult


Expand All @@ -28,8 +31,8 @@ class WrapperMetricExporter(MetricExporter):
def __init__(
self,
exporter: MetricExporter,
preferred_temporality: Optional[Dict[type, AggregationTemporality]] = None,
preferred_aggregation: Optional[Dict[type, Aggregation]] = None,
preferred_temporality: dict[type, AggregationTemporality] | None = None,
preferred_aggregation: dict[type, Aggregation] | None = None,
) -> None:
super().__init__(preferred_temporality=preferred_temporality, preferred_aggregation=preferred_aggregation) # type: ignore
self.wrapped_exporter = exporter
Expand All @@ -42,3 +45,22 @@ def force_flush(self, timeout_millis: float = 10_000) -> bool:

def shutdown(self, timeout_millis: float = 30_000, **kwargs: Any) -> None:
self.wrapped_exporter.shutdown(timeout_millis, **kwargs) # type: ignore


class WrapperSpanProcessor(SpanProcessor):
    """A base class for SpanProcessors that wrap another processor.

    Every method simply delegates to the wrapped processor, so subclasses
    only need to override the hooks they care about and call `super()`
    to forward to the wrapped processor.
    """

    def __init__(self, processor: SpanProcessor) -> None:
        # The underlying processor that all calls are forwarded to.
        self.processor = processor

    def on_start(self, span: Span, parent_context: context.Context | None = None) -> None:
        # Forward span-start events unchanged.
        self.processor.on_start(span, parent_context)

    def on_end(self, span: ReadableSpan) -> None:
        # Forward span-end events unchanged.
        self.processor.on_end(span)

    def shutdown(self) -> None:
        # Shut down the wrapped processor.
        self.processor.shutdown()

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        # Flush the wrapped processor; returns whether it reported success.
        return self.processor.force_flush(timeout_millis)  # pragma: no cover
Loading