diff --git a/logfire/__init__.py b/logfire/__init__.py
index 60b2c160e..e69105a68 100644
--- a/logfire/__init__.py
+++ b/logfire/__init__.py
@@ -14,6 +14,7 @@
 )
 from ._internal.constants import LevelName
 from ._internal.exporters.file import load_file as load_spans_from_file
+from ._internal.exporters.tail_sampling import TailSamplingOptions
 from ._internal.main import Logfire, LogfireSpan
 from ._internal.scrubbing import ScrubMatch
 from ._internal.utils import suppress_instrumentation
@@ -111,4 +112,5 @@ def loguru_handler() -> dict[str, Any]:
     'suppress_instrumentation',
     'StructlogProcessor',
     'LogfireLoggingHandler',
+    'TailSamplingOptions',
 )
diff --git a/logfire/_internal/config.py b/logfire/_internal/config.py
index 8a7db579d..c8cb0f5e4 100644
--- a/logfire/_internal/config.py
+++ b/logfire/_internal/config.py
@@ -67,9 +67,10 @@
 from .exporters.fallback import FallbackSpanExporter
 from .exporters.file import FileSpanExporter
 from .exporters.otlp import OTLPExporterHttpSession, RetryFewerSpansSpanExporter
-from .exporters.processor_wrapper import SpanProcessorWrapper
+from .exporters.processor_wrapper import MainSpanProcessorWrapper
 from .exporters.quiet_metrics import QuietMetricExporter
 from .exporters.remove_pending import RemovePendingSpansExporter
+from .exporters.tail_sampling import TailSamplingOptions, TailSamplingProcessor
 from .integrations.executors import instrument_executors
 from .metrics import ProxyMeterProvider, configure_metrics
 from .scrubbing import Scrubber, ScrubCallback
@@ -159,6 +160,7 @@ def configure(
     scrubbing_patterns: Sequence[str] | None = None,
     scrubbing_callback: ScrubCallback | None = None,
     inspect_arguments: bool | None = None,
+    tail_sampling: TailSamplingOptions | None = None,
 ) -> None:
     """Configure the logfire SDK.
 
@@ -216,6 +218,7 @@ def configure(
             [f-string magic](https://docs.pydantic.dev/logfire/guides/onboarding_checklist/add_manual_tracing/#f-strings).
             If `None` uses the `LOGFIRE_INSPECT_ARGUMENTS` environment variable.
             Defaults to `True` if and only if the Python version is at least 3.11.
+        tail_sampling: Tail sampling options. Not ready for general use.
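+            For example, `TailSamplingOptions(level='warn', duration=5)` keeps whole traces that
+            contain a span/log at `warn` level or higher, or that last longer than 5 seconds.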
""" if processors is not None: # pragma: no cover raise ValueError( @@ -253,6 +256,7 @@ def configure( scrubbing_patterns=scrubbing_patterns, scrubbing_callback=scrubbing_callback, inspect_arguments=inspect_arguments, + tail_sampling=tail_sampling, ) @@ -332,6 +336,12 @@ class _LogfireConfigData: scrubbing_callback: ScrubCallback | None """A function that is called for each match found by the scrubber.""" + inspect_arguments: bool + """Whether to enable f-string magic""" + + tail_sampling: TailSamplingOptions | None + """Tail sampling options""" + def _load_configuration( self, # note that there are no defaults here so that the only place @@ -360,6 +370,7 @@ def _load_configuration( scrubbing_patterns: Sequence[str] | None, scrubbing_callback: ScrubCallback | None, inspect_arguments: bool | None, + tail_sampling: TailSamplingOptions | None, ) -> None: """Merge the given parameters with the environment variables file configurations.""" param_manager = ParamManager.create(config_dir) @@ -414,6 +425,12 @@ def _load_configuration( if get_version(pydantic.__version__) < get_version('2.5.0'): # pragma: no cover raise RuntimeError('The Pydantic plugin requires Pydantic 2.5.0 or newer.') + + if isinstance(tail_sampling, dict): + # This is particularly for deserializing from a dict as in executors.py + tail_sampling = TailSamplingOptions(**tail_sampling) # type: ignore + self.tail_sampling = tail_sampling + self.fast_shutdown = fast_shutdown self.id_generator = id_generator or RandomIdGenerator() @@ -457,6 +474,7 @@ def __init__( scrubbing_patterns: Sequence[str] | None = None, scrubbing_callback: ScrubCallback | None = None, inspect_arguments: bool | None = None, + tail_sampling: TailSamplingOptions | None = None, ) -> None: """Create a new LogfireConfig. @@ -490,6 +508,7 @@ def __init__( scrubbing_patterns=scrubbing_patterns, scrubbing_callback=scrubbing_callback, inspect_arguments=inspect_arguments, + tail_sampling=tail_sampling, ) # initialize with no-ops so that we don't impact OTEL's global config just because logfire is installed # that is, we defer setting logfire as the otel global config until `configure` is called @@ -527,6 +546,7 @@ def configure( scrubbing_patterns: Sequence[str] | None, scrubbing_callback: ScrubCallback | None, inspect_arguments: bool | None, + tail_sampling: TailSamplingOptions | None, ) -> None: with self._lock: self._initialized = False @@ -554,6 +574,7 @@ def configure( scrubbing_patterns, scrubbing_callback, inspect_arguments, + tail_sampling, ) self.initialize() @@ -592,8 +613,15 @@ def _initialize(self) -> ProxyTracerProvider: # Both recommend generating a UUID. resource = Resource({ResourceAttributes.SERVICE_INSTANCE_ID: uuid4().hex}).merge(resource) + # Avoid using the usual sampler if we're using tail-based sampling. + # The TailSamplingProcessor will handle the random sampling part as well. + sampler = ( + ParentBasedTraceIdRatio(self.trace_sample_rate) + if self.trace_sample_rate < 1 and self.tail_sampling is None + else None + ) tracer_provider = SDKTracerProvider( - sampler=ParentBasedTraceIdRatio(self.trace_sample_rate), + sampler=sampler, resource=resource, id_generator=self.id_generator, ) @@ -607,7 +635,16 @@ def add_span_processor(span_processor: SpanProcessor) -> None: # Most span processors added to the tracer provider should also be recorded in the `processors` list # so that they can be used by the final pending span processor. # This means that `tracer_provider.add_span_processor` should only appear in two places. 
-            span_processor = SpanProcessorWrapper(span_processor, self.scrubber)
+            if self.tail_sampling:
+                span_processor = TailSamplingProcessor(
+                    span_processor,
+                    self.tail_sampling,
+                    # If self.trace_sample_rate < 1 then that ratio of spans should be included randomly by this.
+                    # In that case the tracer provider doesn't need to do any sampling, see above.
+                    # Otherwise we're not using random sampling, so 0% of traces should be included purely at random.
+                    self.trace_sample_rate if self.trace_sample_rate < 1 else 0,
+                )
+            span_processor = MainSpanProcessorWrapper(span_processor, self.scrubber)
             tracer_provider.add_span_processor(span_processor)
             processors.append(span_processor)
diff --git a/logfire/_internal/exporters/processor_wrapper.py b/logfire/_internal/exporters/processor_wrapper.py
index 06e076045..e7c7ed05b 100644
--- a/logfire/_internal/exporters/processor_wrapper.py
+++ b/logfire/_internal/exporters/processor_wrapper.py
@@ -16,9 +16,10 @@
 )
 from ..scrubbing import Scrubber
 from ..utils import ReadableSpanDict, is_instrumentation_suppressed, span_to_dict, truncate_string
+from .wrapper import WrapperSpanProcessor
 
 
-class SpanProcessorWrapper(SpanProcessor):
+class MainSpanProcessorWrapper(WrapperSpanProcessor):
     """Wrapper around other processors to intercept starting and ending spans with our own global logic.
 
     Suppresses starting/ending if the current context has a `suppress_instrumentation` value.
@@ -26,7 +27,7 @@ class MainSpanProcessorWrapper(WrapperSpanProcessor):
     """
 
     def __init__(self, processor: SpanProcessor, scrubber: Scrubber) -> None:
-        self.processor = processor
+        super().__init__(processor)
         self.scrubber = scrubber
 
     def on_start(
@@ -37,7 +38,7 @@ def on_start(
         if is_instrumentation_suppressed():
             return
         _set_log_level_on_asgi_send_receive_spans(span)
-        self.processor.on_start(span, parent_context)
+        super().on_start(span, parent_context)
 
     def on_end(self, span: ReadableSpan) -> None:
         if is_instrumentation_suppressed():
             return
@@ -47,13 +48,7 @@ def on_end(self, span: ReadableSpan) -> None:
         span_dict = span_to_dict(span)
         _tweak_http_spans(span_dict)
         self.scrubber.scrub_span(span_dict)
         span = ReadableSpan(**span_dict)
-        self.processor.on_end(span)
-
-    def shutdown(self) -> None:
-        self.processor.shutdown()
-
-    def force_flush(self, timeout_millis: int = 30000) -> bool:
-        return self.processor.force_flush(timeout_millis)  # pragma: no cover
+        super().on_end(span)
 
 
 def _set_log_level_on_asgi_send_receive_spans(span: Span) -> None:
diff --git a/logfire/_internal/exporters/tail_sampling.py b/logfire/_internal/exporters/tail_sampling.py
new file mode 100644
index 000000000..19c0ee928
--- /dev/null
+++ b/logfire/_internal/exporters/tail_sampling.py
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+import random
+from dataclasses import dataclass
+from functools import cached_property
+from threading import Lock
+
+from opentelemetry import context
+from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
+
+from logfire._internal.constants import (
+    ATTRIBUTES_LOG_LEVEL_NUM_KEY,
+    LEVEL_NUMBERS,
+    ONE_SECOND_IN_NANOSECONDS,
+    LevelName,
+)
+from logfire._internal.exporters.wrapper import WrapperSpanProcessor
+
+
+@dataclass
+class TailSamplingOptions:
+    level: LevelName | None = 'notice'
+    """
+    Include all spans/logs with level greater than or equal to this level.
+    If None, spans are not included based on level.
+    """
+
+    duration: float | None = 1.0
+    """
+    Include all spans/logs in a trace whose duration exceeds this number of seconds.
+    If None, spans are not included based on duration.
+ """ + + +@dataclass +class TraceBuffer: + """Arguments of `on_start` and `on_end` for spans in a single trace.""" + + started: list[tuple[Span, context.Context | None]] + ended: list[ReadableSpan] + + @cached_property + def first_span(self) -> Span: + return self.started[0][0] + + +class TailSamplingProcessor(WrapperSpanProcessor): + """Passes spans to the wrapped processor if any span in a trace meets the sampling criteria.""" + + def __init__(self, processor: SpanProcessor, options: TailSamplingOptions, random_rate: float) -> None: + super().__init__(processor) + self.duration: float = ( + float('inf') if options.duration is None else options.duration * ONE_SECOND_IN_NANOSECONDS + ) + self.level: float | int = float('inf') if options.level is None else LEVEL_NUMBERS[options.level] + self.random_rate = random_rate + + # A TraceBuffer is typically created for each new trace. + # If a span meets the sampling criteria, the buffer is dropped and all spans within are pushed + # to the wrapped processor. + # So when more spans arrive and there's no buffer, they get passed through immediately. + self.traces: dict[int, TraceBuffer] = {} + + # Code that touches self.traces and its contents should be protected by this lock. + self.lock = Lock() + + def on_start(self, span: Span, parent_context: context.Context | None = None) -> None: + dropped = False + buffer = None + + with self.lock: + # span.context could supposedly be None, not sure how. + if span.context: # pragma: no branch + trace_id = span.context.trace_id + # If span.parent is None, it's the root span of a trace. + # If random.random() <= self.random_rate, immediately include this trace, + # meaning no buffer for it. + if span.parent is None and random.random() > self.random_rate: + self.traces[trace_id] = TraceBuffer([], []) + + buffer = self.traces.get(trace_id) + if buffer is not None: + # This trace's spans haven't met the criteria yet, so add this span to the buffer. + buffer.started.append((span, parent_context)) + dropped = self.check_span(span, buffer) + # The opposite case is handled outside the lock since it may take some time. + + # This code may take longer since it calls the wrapped processor which might do anything. + # It shouldn't be inside the lock to avoid blocking other threads. + # Since it's not in the lock, it shouldn't touch self.traces or its contents. + if buffer is None: + super().on_start(span, parent_context) + elif dropped: + self.push_buffer(buffer) + + def on_end(self, span: ReadableSpan) -> None: + # This has a very similar structure and reasoning to on_start. + + dropped = False + buffer = None + + with self.lock: + if span.context: # pragma: no branch + trace_id = span.context.trace_id + buffer = self.traces.get(trace_id) + if buffer is not None: + buffer.ended.append(span) + dropped = self.check_span(span, buffer) + if span.parent is None: + # This is the root span, so the trace is hopefully complete. + # Delete the buffer to save memory. + self.traces.pop(trace_id, None) + + if buffer is None: + super().on_end(span) + elif dropped: + self.push_buffer(buffer) + + def check_span(self, span: ReadableSpan, buffer: TraceBuffer) -> bool: + """If the span meets the sampling criteria, drop the buffer and return True. Otherwise, return False.""" + # span.end_time and span.start_time are in nanoseconds and can be None. 
+        if (span.end_time or span.start_time or 0) - (buffer.first_span.start_time or float('inf')) > self.duration:
+            self.drop_buffer(buffer)
+            return True
+
+        attributes = span.attributes or {}
+        level = attributes.get(ATTRIBUTES_LOG_LEVEL_NUM_KEY)
+        if not isinstance(level, int):
+            level = LEVEL_NUMBERS['info']
+        if level >= self.level:
+            self.drop_buffer(buffer)
+            return True
+
+        return False
+
+    def drop_buffer(self, buffer: TraceBuffer) -> None:
+        span_context = buffer.first_span.context
+        assert span_context is not None
+        del self.traces[span_context.trace_id]
+
+    def push_buffer(self, buffer: TraceBuffer) -> None:
+        for started in buffer.started:
+            super().on_start(*started)
+        for span in buffer.ended:
+            super().on_end(span)
diff --git a/logfire/_internal/exporters/wrapper.py b/logfire/_internal/exporters/wrapper.py
index d85061c85..433b5a93a 100644
--- a/logfire/_internal/exporters/wrapper.py
+++ b/logfire/_internal/exporters/wrapper.py
@@ -1,8 +1,11 @@
-from typing import Any, Dict, Optional, Sequence
+from __future__ import annotations
 
+from typing import Any, Sequence
+
+from opentelemetry import context
 from opentelemetry.sdk.metrics.export import AggregationTemporality, MetricExporter, MetricExportResult, MetricsData
 from opentelemetry.sdk.metrics.view import Aggregation
-from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
 from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
@@ -28,8 +31,8 @@ class WrapperMetricExporter(MetricExporter):
     def __init__(
         self,
         exporter: MetricExporter,
-        preferred_temporality: Optional[Dict[type, AggregationTemporality]] = None,
-        preferred_aggregation: Optional[Dict[type, Aggregation]] = None,
+        preferred_temporality: dict[type, AggregationTemporality] | None = None,
+        preferred_aggregation: dict[type, Aggregation] | None = None,
     ) -> None:
         super().__init__(preferred_temporality=preferred_temporality, preferred_aggregation=preferred_aggregation)  # type: ignore
         self.wrapped_exporter = exporter
@@ -42,3 +45,22 @@ def force_flush(self, timeout_millis: float = 10_000) -> bool:
 
     def shutdown(self, timeout_millis: float = 30_000, **kwargs: Any) -> None:
         self.wrapped_exporter.shutdown(timeout_millis, **kwargs)  # type: ignore
+
+
+class WrapperSpanProcessor(SpanProcessor):
+    """A base class for SpanProcessors that wrap another processor."""
+
+    def __init__(self, processor: SpanProcessor) -> None:
+        self.processor = processor
+
+    def on_start(self, span: Span, parent_context: context.Context | None = None) -> None:
+        self.processor.on_start(span, parent_context)
+
+    def on_end(self, span: ReadableSpan) -> None:
+        self.processor.on_end(span)
+
+    def shutdown(self) -> None:
+        self.processor.shutdown()
+
+    def force_flush(self, timeout_millis: int = 30000) -> bool:
+        return self.processor.force_flush(timeout_millis)  # pragma: no cover
diff --git a/tests/test_configure.py b/tests/test_configure.py
index 4ba05d17c..66ee60dfe 100644
--- a/tests/test_configure.py
+++ b/tests/test_configure.py
@@ -776,13 +776,14 @@ def test_config_serializable():
         send_to_logfire=False,
         pydantic_plugin=logfire.PydanticPlugin(record='all'),
         console=logfire.ConsoleOptions(verbose=True),
+        tail_sampling=logfire.TailSamplingOptions(),
     )
 
     for field in dataclasses.fields(GLOBAL_CONFIG):
         # Check that the full set of dataclass fields is known.
         # If a new field appears here, make sure it gets deserialized properly in configure, and tested here.
         assert dataclasses.is_dataclass(getattr(GLOBAL_CONFIG, field.name)) == (
-            field.name in ['pydantic_plugin', 'console']
+            field.name in ['pydantic_plugin', 'console', 'tail_sampling']
         )
 
     serialized = serialize_config()
@@ -799,6 +800,7 @@ def normalize(s: dict[str, Any]) -> dict[str, Any]:
 
     assert isinstance(GLOBAL_CONFIG.pydantic_plugin, logfire.PydanticPlugin)
     assert isinstance(GLOBAL_CONFIG.console, logfire.ConsoleOptions)
+    assert isinstance(GLOBAL_CONFIG.tail_sampling, logfire.TailSamplingOptions)
 
 
 def test_config_serializable_console_false():
diff --git a/tests/test_tail_sampling.py b/tests/test_tail_sampling.py
new file mode 100644
index 000000000..5e24d1eb7
--- /dev/null
+++ b/tests/test_tail_sampling.py
@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import itertools
+from typing import Any
+from unittest.mock import Mock
+
+import pytest
+from inline_snapshot import snapshot
+
+import logfire
+from logfire.testing import TestExporter
+
+
+def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
+    # Use the default TailSamplingOptions.level of 'notice'.
+    # Set duration to None so that spans are not included merely for having a long duration.
+    logfire.configure(**config_kwargs, tail_sampling=logfire.TailSamplingOptions(duration=None))
+
+    with logfire.span('ignored span'):
+        logfire.debug('ignored debug')
+    logfire.notice('notice')
+    with logfire.span('ignored span'):
+        logfire.info('ignored info')
+    logfire.warn('warn')
+
+    # Include this whole tree because of the inner error
+    with logfire.span('span'):
+        with logfire.span('span2'):
+            logfire.error('error')
+
+    # Include this whole tree because of the outer fatal
+    with logfire.span('span3', _level='fatal'):
+        logfire.trace('trace')
+
+    assert exporter.exported_spans_as_dict(_include_pending_spans=True) == snapshot(
+        [
+            {
+                'name': 'notice',
+                'context': {'trace_id': 2, 'span_id': 4, 'is_remote': False},
+                'parent': None,
+                'start_time': 4000000000,
+                'end_time': 4000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 10,
+                    'logfire.msg_template': 'notice',
+                    'logfire.msg': 'notice',
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                },
+            },
+            {
+                'name': 'warn',
+                'context': {'trace_id': 4, 'span_id': 8, 'is_remote': False},
+                'parent': None,
+                'start_time': 8000000000,
+                'end_time': 8000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 13,
+                    'logfire.msg_template': 'warn',
+                    'logfire.msg': 'warn',
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                },
+            },
+            {
+                'name': 'span (pending)',
+                'context': {'trace_id': 5, 'span_id': 10, 'is_remote': False},
+                'parent': {'trace_id': 5, 'span_id': 9, 'is_remote': False},
+                'start_time': 9000000000,
+                'end_time': 9000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span',
+                    'logfire.msg': 'span',
+                    'logfire.span_type': 'pending_span',
+                    'logfire.pending_parent_id': '0000000000000000',
+                },
+            },
+            {
+                'name': 'span2 (pending)',
+                'context': {'trace_id': 5, 'span_id': 12, 'is_remote': False},
+                'parent': {'trace_id': 5, 'span_id': 11, 'is_remote': False},
+                'start_time': 10000000000,
+                'end_time': 10000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span2',
+                    'logfire.msg': 'span2',
+                    'logfire.span_type': 'pending_span',
+                    'logfire.pending_parent_id': '0000000000000009',
+                },
+            },
+            {
+                'name': 'error',
+                'context': {'trace_id': 5, 'span_id': 13, 'is_remote': False},
+                'parent': {'trace_id': 5, 'span_id': 11, 'is_remote': False},
+                'start_time': 11000000000,
+                'end_time': 11000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 17,
+                    'logfire.msg_template': 'error',
+                    'logfire.msg': 'error',
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                },
+            },
+            {
+                'name': 'span2',
+                'context': {'trace_id': 5, 'span_id': 11, 'is_remote': False},
+                'parent': {'trace_id': 5, 'span_id': 9, 'is_remote': False},
+                'start_time': 10000000000,
+                'end_time': 12000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span2',
+                    'logfire.msg': 'span2',
+                    'logfire.span_type': 'span',
+                },
+            },
+            {
+                'name': 'span',
+                'context': {'trace_id': 5, 'span_id': 9, 'is_remote': False},
+                'parent': None,
+                'start_time': 9000000000,
+                'end_time': 13000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span',
+                    'logfire.msg': 'span',
+                    'logfire.span_type': 'span',
+                },
+            },
+            {
+                'name': 'span3 (pending)',
+                'context': {'trace_id': 6, 'span_id': 15, 'is_remote': False},
+                'parent': {'trace_id': 6, 'span_id': 14, 'is_remote': False},
+                'start_time': 14000000000,
+                'end_time': 14000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span3',
+                    'logfire.msg': 'span3',
+                    'logfire.level_num': 21,
+                    'logfire.span_type': 'pending_span',
+                    'logfire.pending_parent_id': '0000000000000000',
+                },
+            },
+            {
+                'name': 'trace',
+                'context': {'trace_id': 6, 'span_id': 16, 'is_remote': False},
+                'parent': {'trace_id': 6, 'span_id': 14, 'is_remote': False},
+                'start_time': 15000000000,
+                'end_time': 15000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 1,
+                    'logfire.msg_template': 'trace',
+                    'logfire.msg': 'trace',
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                },
+            },
+            {
+                'name': 'span3',
+                'context': {'trace_id': 6, 'span_id': 14, 'is_remote': False},
+                'parent': None,
+                'start_time': 14000000000,
+                'end_time': 16000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_level_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span3',
+                    'logfire.msg': 'span3',
+                    'logfire.level_num': 21,
+                    'logfire.span_type': 'span',
+                },
+            },
+        ]
+    )
+
+
+def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
+    # Set level to None so that spans are not included merely for having a high level.
+    logfire.configure(**config_kwargs, tail_sampling=logfire.TailSamplingOptions(level=None, duration=3))
+
+    logfire.error('short1')
+    with logfire.span('span'):
+        logfire.error('short2')
+
+    # This has a total fake duration of 3s, which doesn't get included because we use >, not >=.
+    with logfire.span('span2'):
+        with logfire.span('span3', _level='error'):
+            pass
+
+    # This has a total fake duration of 4s, which does get included.
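+    # (Illustrative comment: the fake clock in the test fixtures advances one second per recorded
+    # timestamp, so in the snapshot below span4 starts at 9s and ends at 13s, a 4s duration.)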
+    with logfire.span('span4'):
+        with logfire.span('span5'):
+            logfire.info('long1')
+
+    assert exporter.exported_spans_as_dict(_include_pending_spans=True) == snapshot(
+        [
+            {
+                'name': 'span4 (pending)',
+                'context': {'trace_id': 4, 'span_id': 10, 'is_remote': False},
+                'parent': {'trace_id': 4, 'span_id': 9, 'is_remote': False},
+                'start_time': 9000000000,
+                'end_time': 9000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_duration_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span4',
+                    'logfire.msg': 'span4',
+                    'logfire.span_type': 'pending_span',
+                    'logfire.pending_parent_id': '0000000000000000',
+                },
+            },
+            {
+                'name': 'span5 (pending)',
+                'context': {'trace_id': 4, 'span_id': 12, 'is_remote': False},
+                'parent': {'trace_id': 4, 'span_id': 11, 'is_remote': False},
+                'start_time': 10000000000,
+                'end_time': 10000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_duration_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span5',
+                    'logfire.msg': 'span5',
+                    'logfire.span_type': 'pending_span',
+                    'logfire.pending_parent_id': '0000000000000009',
+                },
+            },
+            {
+                'name': 'long1',
+                'context': {'trace_id': 4, 'span_id': 13, 'is_remote': False},
+                'parent': {'trace_id': 4, 'span_id': 11, 'is_remote': False},
+                'start_time': 11000000000,
+                'end_time': 11000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 9,
+                    'logfire.msg_template': 'long1',
+                    'logfire.msg': 'long1',
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_duration_sampling',
+                    'code.lineno': 123,
+                },
+            },
+            {
+                'name': 'span5',
+                'context': {'trace_id': 4, 'span_id': 11, 'is_remote': False},
+                'parent': {'trace_id': 4, 'span_id': 9, 'is_remote': False},
+                'start_time': 10000000000,
+                'end_time': 12000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_duration_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span5',
+                    'logfire.msg': 'span5',
+                    'logfire.span_type': 'span',
+                },
+            },
+            {
+                'name': 'span4',
+                'context': {'trace_id': 4, 'span_id': 9, 'is_remote': False},
+                'parent': None,
+                'start_time': 9000000000,
+                'end_time': 13000000000,
+                'attributes': {
+                    'code.filepath': 'test_tail_sampling.py',
+                    'code.function': 'test_duration_sampling',
+                    'code.lineno': 123,
+                    'logfire.msg_template': 'span4',
+                    'logfire.msg': 'span4',
+                    'logfire.span_type': 'span',
+                },
+            },
+        ]
+    )
+
+
+def test_random_sampling(config_kwargs: dict[str, Any], exporter: TestExporter, monkeypatch: pytest.MonkeyPatch):
+    logfire.configure(
+        **config_kwargs,
+        tail_sampling=logfire.TailSamplingOptions(),
+        trace_sample_rate=0.3,
+    )
+    # random.random() cycles through [0.1, 0.6, 0.9], i.e. it returns a value < 0.3 a third of the time.
+    monkeypatch.setattr('random.random', Mock(side_effect=itertools.cycle([0.1, 0.6, 0.9])))
+
+    # These spans should all be included because their level is above the default 'notice' threshold.
+    for _ in range(10):
+        logfire.error('error')
+    assert len(exporter.exported_spans) == 10
+
+    # A third of these spans (i.e. an extra 10) should be included because of the trace_sample_rate.
+    # None of them meet the tail sampling criteria.
+    for _ in range(30):
+        logfire.info('info')
+    assert len(exporter.exported_spans) == 20
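
For reference, a minimal usage sketch of the API added in this diff (the `level` and `duration` values here are illustrative; the dataclass defaults are `level='notice'` and `duration=1.0`):

```python
import logfire

# Each trace is buffered until some span/log in it reaches 'error' or the
# trace has been running for more than 2 seconds; traces that never qualify
# are discarded when their root span ends.
logfire.configure(tail_sampling=logfire.TailSamplingOptions(level='error', duration=2.0))

with logfire.span('checkout'):
    logfire.info('buffered, exported only if the trace qualifies')
    logfire.error('this exports the whole trace, including the info above')
```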