From 63a0950c677f14f67fed4e31940dca513ed1df83 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 17 Sep 2024 17:39:06 +0200 Subject: [PATCH] Tail sampling (#407) Co-authored-by: Marcelo Trylesinski --- docs/api/sampling.md | 1 + docs/guides/advanced/sampling.md | 312 ++++++++++++++++-- logfire-api/logfire_api/__init__.py | 2 +- logfire-api/logfire_api/__init__.pyi | 4 +- logfire-api/logfire_api/_internal/config.pyi | 17 +- .../logfire_api/_internal/constants.pyi | 6 +- logfire-api/logfire_api/sampling/__init__.pyi | 3 + .../logfire_api/sampling/_tail_sampling.pyi | 89 +++++ logfire/__init__.py | 5 +- logfire/_internal/config.py | 87 +++-- logfire/_internal/config_params.py | 2 +- logfire/_internal/constants.py | 6 +- logfire/_internal/exporters/tail_sampling.py | 145 -------- logfire/sampling/__init__.py | 9 + logfire/sampling/_tail_sampling.py | 280 ++++++++++++++++ mkdocs.yml | 1 + tests/test_configure.py | 8 +- tests/test_logfire.py | 3 +- tests/test_logfire_api.py | 6 +- tests/test_pydantic_plugin.py | 10 +- tests/test_sampling.py | 24 +- tests/test_tail_sampling.py | 236 +++++++++++-- 22 files changed, 965 insertions(+), 291 deletions(-) create mode 100644 docs/api/sampling.md create mode 100644 logfire-api/logfire_api/sampling/__init__.pyi create mode 100644 logfire-api/logfire_api/sampling/_tail_sampling.pyi delete mode 100644 logfire/_internal/exporters/tail_sampling.py create mode 100644 logfire/sampling/__init__.py create mode 100644 logfire/sampling/_tail_sampling.py diff --git a/docs/api/sampling.md b/docs/api/sampling.md new file mode 100644 index 000000000..cc2ebc0a2 --- /dev/null +++ b/docs/api/sampling.md @@ -0,0 +1 @@ +::: logfire.sampling diff --git a/docs/guides/advanced/sampling.md b/docs/guides/advanced/sampling.md index 571507594..9973c2bf9 100644 --- a/docs/guides/advanced/sampling.md +++ b/docs/guides/advanced/sampling.md @@ -1,38 +1,312 @@ # Sampling -Sampling is the practice of discarding some traces or spans in order to reduce the amount of data that needs to be stored and analyzed. Sampling is a trade-off between cost and completeness of data. +Sampling is the practice of discarding some traces or spans in order to reduce the amount of data that needs to be +stored and analyzed. Sampling is a trade-off between cost and completeness of data. -To configure sampling for the SDK: +_Head sampling_ means the decision to sample is made at the beginning of a trace. This is simpler and more common. -- Set the [`trace_sample_rate`][logfire.configure(trace_sample_rate)] option of [`logfire.configure()`][logfire.configure] to a number between 0 and 1, or -- Set the `LOGFIRE_TRACE_SAMPLE_RATE` environment variable, or -- Set the `trace_sample_rate` config file option. +_Tail sampling_ means the decision to sample is delayed, possibly until the end of a trace. This means there is more +information available to make the decision, but this adds complexity. -See [Configuration](../../reference/configuration.md) for more information. +Sampling usually happens at the trace level, meaning entire traces are kept or discarded. This way the remaining traces +are generally complete. 
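+
+Both kinds of sampling are configured by passing a [`logfire.SamplingOptions`][logfire.sampling.SamplingOptions]
+instance to [`logfire.configure()`][logfire.configure]. The sections below walk through the options in detail;
+as a quick sketch, the `head` option accepts either a plain rate or a custom
+[OpenTelemetry `Sampler`](https://opentelemetry-python.readthedocs.io/en/latest/sdk/trace.sampling.html):
+
+```python
+import logfire
+from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased
+
+# Equivalent to head=0.25: keep a random quarter of traces, with child spans
+# following the sampling decision made for the root of their trace.
+logfire.configure(sampling=logfire.SamplingOptions(head=ParentBased(TraceIdRatioBased(0.25))))
+```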
+ +## Random head sampling + +Here's an example of randomly sampling 50% of traces: + +```python +import logfire + +logfire.configure(sampling=logfire.SamplingOptions(head=0.5)) + +for x in range(10): + with logfire.span(f'span {x}'): + logfire.info(f'log {x}') +``` + +This outputs something like: + +``` +11:09:29.041 span 0 +11:09:29.041 log 0 +11:09:29.041 span 1 +11:09:29.042 log 1 +11:09:29.042 span 4 +11:09:29.042 log 4 +11:09:29.042 span 5 +11:09:29.042 log 5 +11:09:29.042 span 7 +11:09:29.042 log 7 +``` + +Note that 5 out of 10 traces are kept, and that the child log is kept if and only if the parent span is kept. + +## Tail sampling by level and duration + +Random head sampling often works well, but you may not want to lose any traces which indicate problems. In this case, +you can use tail sampling. Here's a simple example: + +```python +import time + +import logfire + +logfire.configure(sampling=logfire.SamplingOptions.level_or_duration()) + +for x in range(3): + # None of these are logged + with logfire.span('excluded span'): + logfire.info(f'info {x}') + + # All of these are logged + with logfire.span('included span'): + logfire.error(f'error {x}') + +for t in range(1, 10, 2): + with logfire.span(f'span with duration {t}'): + time.sleep(t) +``` + +This outputs something like: + +``` +11:37:45.484 included span +11:37:45.484 error 0 +11:37:45.485 included span +11:37:45.485 error 1 +11:37:45.485 included span +11:37:45.485 error 2 +11:37:49.493 span with duration 5 +11:37:54.499 span with duration 7 +11:38:01.505 span with duration 9 +``` + +[`logfire.SamplingOptions.level_or_duration()`][logfire.sampling.SamplingOptions.level_or_duration] creates an instance +of [`logfire.SamplingOptions`][logfire.sampling.SamplingOptions] with simple tail sampling. With no arguments, +it means that a trace will be included if and only if it has at least one span/log that: + +1. has a log level greater than `info` (the default of any span), or +2. has a duration greater than 5 seconds. + +This way you won't lose information about warnings/errors or long-running operations. You can customize what to keep +with the `level_threshold` and `duration_threshold` arguments. + +## Combining head and tail sampling + +You can combine head and tail sampling. For example: + +```python +import logfire + +logfire.configure(sampling=logfire.SamplingOptions.level_or_duration(head=0.1)) +``` + +This will only keep 10% of traces, even if they have a high log level or duration. Traces that don't meet the tail +sampling criteria will be discarded every time. + +## Keeping a fraction of all traces + +To keep some traces even if they don't meet the tail sampling criteria, you can use the `background_rate` argument. For +example, this script: + +```python +import logfire + +logfire.configure(sampling=logfire.SamplingOptions.level_or_duration(background_rate=0.3)) + +for x in range(10): + logfire.info(f'info {x}') +for x in range(5): + logfire.error(f'error {x}') +``` + +will output something like: + +``` +12:24:40.293 info 2 +12:24:40.293 info 3 +12:24:40.293 info 7 +12:24:40.294 error 0 +12:24:40.294 error 1 +12:24:40.294 error 2 +12:24:40.294 error 3 +12:24:40.295 error 4 +``` + +i.e. about 30% of the info logs and 100% of the error logs are kept. + +(Technical note: the trace ID is compared against the head and background rates to determine inclusion, so the +probabilities don't depend on the number of spans in the trace, and the rates give the probabilities directly without +needing any further calculations. 
For example, with a head sample rate of `0.6` and a background rate of `0.3`, the
+chance of a non-notable trace being included is `0.3`, not `0.6 * 0.3`.)
+
+## Caveats of tail sampling
+
+### Memory usage
+
+For tail sampling to work, all the spans in a trace must be kept in memory until either the trace is included by
+sampling or the trace is completed and discarded. In the above example, the spans named `included span` don't have a
+high enough level to be included, so they are kept in memory until the error logs cause the entire trace to be included.
+This means that traces with a large number of spans can consume a lot of memory, whereas without tail sampling the spans
+would be regularly exported and freed from memory without waiting for the rest of the trace.
+
+In practice this is usually OK, because such large traces will usually exceed the duration threshold, at which point the
+trace will be included and the spans will be exported and freed. This works because the duration is measured as the time
+between the start of the trace and the start/end of the most recent span, so the tail sampler can know that a span will
+exceed the duration threshold even before it's complete. For example, running this script:
 
 ```python
+import time
+
 import logfire
 
-logfire.configure(trace_sample_rate=0.5)
+logfire.configure(sampling=logfire.SamplingOptions.level_or_duration())
+
+with logfire.span('span'):
+    for x in range(1, 10):
+        time.sleep(1)
+        logfire.info(f'info {x}')
+```
+
+will do nothing for the first 5 seconds, before suddenly logging all this at once:
 
-with logfire.span("my_span"):  # This span will be sampled 50% of the time
-    pass
 ```
+12:29:43.063 span
+12:29:44.065 info 1
+12:29:45.066 info 2
+12:29:46.072 info 3
+12:29:47.076 info 4
+12:29:48.082 info 5
+```
+
+followed by additional logs once per second. This is despite the fact that at this stage the outer span hasn't completed
+yet and the inner logs each have 0 duration.
+
+However, memory usage can still be a problem in any of the following cases:
+
+- The duration threshold is set to a high value
+- Spans are produced extremely rapidly
+- Spans contain large attributes
+
+## Custom head and tail sampling
+
+For more control, you can pass a function as the `tail` argument of [`logfire.SamplingOptions`][logfire.sampling.SamplingOptions].
+It will be called with a [`TailSamplingSpanInfo`][logfire.sampling.TailSamplingSpanInfo] every time a span starts or ends,
+and should return the probability (between 0.0 and 1.0) that the whole trace should be included. For example:
+
+```python
+import logfire
+
+
+def get_tail_sample_rate(span_info: logfire.sampling.TailSamplingSpanInfo) -> float:
+    if span_info.duration >= 1:
+        return 0.5  # (1)!
+
+    if span_info.level >= 'warn':  # (2)!
+        return 0.3  # (3)!
+
+    return 0.1  # (4)!
+
+
+logfire.configure(
+    sampling=logfire.SamplingOptions(
+        head=0.5,  # (5)!
+        tail=get_tail_sample_rate,
+    ),
+)
+```
+
+1. Keep 50% of traces with duration >= 1 second
+2. `span_info.level` is a [special object][logfire.sampling.SpanLevel] that can be compared to log level names
+3. Keep 30% of traces with a warning or error and with duration < 1 second
+4. Keep 10% of other traces
+5. Discard 50% of traces at the beginning to reduce the overhead of generating spans. This is optional, but improves
+   performance, and we know that `get_tail_sample_rate` will always return at most 0.5 so the other 50% of traces will
+   be discarded anyway. The probabilities are not independent - this will not discard traces that would otherwise have
+   been kept by tail sampling.
diff --git a/logfire-api/logfire_api/__init__.py b/logfire-api/logfire_api/__init__.py
index 02a4c60d4..ba4d972ae 100644
--- a/logfire-api/logfire_api/__init__.py
+++ b/logfire-api/logfire_api/__init__.py
@@ -173,7 +173,7 @@ def suppress_instrumentation():
     class ConsoleOptions:
         def __init__(self, *args, **kwargs) -> None: ...
 
-    class TailSamplingOptions:
+    class SamplingOptions:
         def __init__(self, *args, **kwargs) -> None: ...
 
class ScrubbingOptions: diff --git a/logfire-api/logfire_api/__init__.pyi b/logfire-api/logfire_api/__init__.pyi index c7de505b8..084d2d7cc 100644 --- a/logfire-api/logfire_api/__init__.pyi +++ b/logfire-api/logfire_api/__init__.pyi @@ -3,15 +3,15 @@ from ._internal.auto_trace.rewrite_ast import no_auto_trace as no_auto_trace from ._internal.config import ConsoleOptions as ConsoleOptions, METRICS_PREFERRED_TEMPORALITY as METRICS_PREFERRED_TEMPORALITY, PydanticPlugin as PydanticPlugin, configure as configure from ._internal.constants import LevelName as LevelName from ._internal.exporters.file import load_file as load_spans_from_file -from ._internal.exporters.tail_sampling import TailSamplingOptions as TailSamplingOptions from ._internal.main import Logfire as Logfire, LogfireSpan as LogfireSpan from ._internal.scrubbing import ScrubMatch as ScrubMatch, ScrubbingOptions as ScrubbingOptions from ._internal.utils import suppress_instrumentation as suppress_instrumentation from .integrations.logging import LogfireLoggingHandler as LogfireLoggingHandler from .integrations.structlog import LogfireProcessor as StructlogProcessor from .version import VERSION as VERSION +from logfire.sampling import SamplingOptions as SamplingOptions -__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'ConsoleOptions', 'PydanticPlugin', 'configure', 'span', 'instrument', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_fastapi', 'instrument_openai', 'instrument_anthropic', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_sqlalchemy', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'AutoTraceModule', 'with_tags', 'with_settings', 'shutdown', 'load_spans_from_file', 'no_auto_trace', 'METRICS_PREFERRED_TEMPORALITY', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'TailSamplingOptions'] +__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'ConsoleOptions', 'PydanticPlugin', 'configure', 'span', 'instrument', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_fastapi', 'instrument_openai', 'instrument_anthropic', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_sqlalchemy', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'AutoTraceModule', 'with_tags', 'with_settings', 'shutdown', 'load_spans_from_file', 'no_auto_trace', 'METRICS_PREFERRED_TEMPORALITY', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'SamplingOptions'] DEFAULT_LOGFIRE_INSTANCE = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span diff --git a/logfire-api/logfire_api/_internal/config.pyi b/logfire-api/logfire_api/_internal/config.pyi index 5c261a7ce..825763979 100644 --- a/logfire-api/logfire_api/_internal/config.pyi +++ b/logfire-api/logfire_api/_internal/config.pyi @@ -10,7 +10,6 @@ from .exporters.otlp import OTLPExporterHttpSession as OTLPExporterHttpSession, from 
.exporters.processor_wrapper import MainSpanProcessorWrapper as MainSpanProcessorWrapper from .exporters.quiet_metrics import QuietMetricExporter as QuietMetricExporter from .exporters.remove_pending import RemovePendingSpansExporter as RemovePendingSpansExporter -from .exporters.tail_sampling import TailSamplingOptions as TailSamplingOptions, TailSamplingProcessor as TailSamplingProcessor from .exporters.test import TestExporter as TestExporter from .integrations.executors import instrument_executors as instrument_executors from .main import FastLogfireSpan as FastLogfireSpan, LogfireSpan as LogfireSpan @@ -23,6 +22,8 @@ from _typeshed import Incomplete from dataclasses import dataclass from functools import cached_property from logfire.exceptions import LogfireConfigError as LogfireConfigError +from logfire.sampling import SamplingOptions as SamplingOptions +from logfire.sampling._tail_sampling import TailSamplingProcessor as TailSamplingProcessor from logfire.version import VERSION as VERSION from opentelemetry import metrics from opentelemetry.sdk.metrics.export import MetricReader as MetricReader @@ -55,7 +56,7 @@ class PydanticPlugin: include: set[str] = ... exclude: set[str] = ... -def configure(*, send_to_logfire: bool | Literal['if-token-present'] | None = None, token: str | None = None, project_name: str | None = None, service_name: str | None = None, service_version: str | None = None, trace_sample_rate: float | None = None, console: ConsoleOptions | Literal[False] | None = None, show_summary: bool | None = None, config_dir: Path | str | None = None, data_dir: Path | str | None = None, base_url: str | None = None, collect_system_metrics: None = None, id_generator: IdGenerator | None = None, ns_timestamp_generator: Callable[[], int] | None = None, processors: None = None, additional_span_processors: Sequence[SpanProcessor] | None = None, metric_readers: None = None, additional_metric_readers: Sequence[MetricReader] | None = None, pydantic_plugin: PydanticPlugin | None = None, fast_shutdown: bool = False, scrubbing_patterns: Sequence[str] | None = None, scrubbing_callback: ScrubCallback | None = None, scrubbing: ScrubbingOptions | Literal[False] | None = None, inspect_arguments: bool | None = None, tail_sampling: TailSamplingOptions | None = None) -> None: +def configure(*, send_to_logfire: bool | Literal['if-token-present'] | None = None, token: str | None = None, project_name: str | None = None, service_name: str | None = None, service_version: str | None = None, trace_sample_rate: float | None = None, console: ConsoleOptions | Literal[False] | None = None, show_summary: bool | None = None, config_dir: Path | str | None = None, data_dir: Path | str | None = None, base_url: str | None = None, collect_system_metrics: None = None, id_generator: IdGenerator | None = None, ns_timestamp_generator: Callable[[], int] | None = None, processors: None = None, additional_span_processors: Sequence[SpanProcessor] | None = None, metric_readers: None = None, additional_metric_readers: Sequence[MetricReader] | None = None, pydantic_plugin: PydanticPlugin | None = None, fast_shutdown: bool = False, scrubbing_patterns: Sequence[str] | None = None, scrubbing_callback: ScrubCallback | None = None, scrubbing: ScrubbingOptions | Literal[False] | None = None, inspect_arguments: bool | None = None, sampling: SamplingOptions | None = None) -> None: """Configure the logfire SDK. 
Args:
@@ -71,8 +72,7 @@ def configure(*, send_to_logfire: bool | Literal['if-token-present'] | None = No
         service_name: Name of this service. Defaults to the `LOGFIRE_SERVICE_NAME` environment variable.
         service_version: Version of this service. Defaults to the `LOGFIRE_SERVICE_VERSION` environment variable, or
             the current git commit hash if available.
-        trace_sample_rate: Sampling ratio for spans. Defaults to the `LOGFIRE_SAMPLING_RATIO` environment variable, or
-            the `OTEL_TRACES_SAMPLER_ARG` environment variable, or to `1.0`.
+        trace_sample_rate: Deprecated, use `sampling` instead.
         console: Whether to control terminal output. If `None` uses the `LOGFIRE_CONSOLE_*` environment variables,
             otherwise defaults to `ConsoleOption(colors='auto', indent_spans=True, include_timestamps=True, verbose=False)`.
             If `False` disables console output. It can also be disabled by setting `LOGFIRE_CONSOLE` environment variable to `false`.
@@ -101,7 +101,7 @@
             [f-string magic](https://docs.pydantic.dev/logfire/guides/onboarding_checklist/add_manual_tracing/#f-strings).
             If `None` uses the `LOGFIRE_INSPECT_ARGUMENTS` environment variable.
             Defaults to `True` if and only if the Python version is at least 3.11.
-        tail_sampling: Tail sampling options. Not ready for general use.
+        sampling: Sampling options. See the [sampling guide](https://docs.pydantic.dev/logfire/guides/advanced/sampling/).
    """

@dataclasses.dataclass
@@ -121,7 +121,6 @@ class _LogfireConfigData:
     project_name: str | None
     service_name: str
     service_version: str | None
-    trace_sample_rate: float
     console: ConsoleOptions | Literal[False] | None
     show_summary: bool
     data_dir: Path
@@ -132,17 +131,17 @@
     fast_shutdown: bool
     scrubbing: ScrubbingOptions | Literal[False]
     inspect_arguments: bool
-    tail_sampling: TailSamplingOptions | None
+    sampling: SamplingOptions
 
 class LogfireConfig(_LogfireConfigData):
-    def __init__(self, base_url: str | None = None, send_to_logfire: bool | None = None, token: str | None = None, project_name: str | None = None, service_name: str | None = None, service_version: str | None = None, trace_sample_rate: float | None = None, console: ConsoleOptions | Literal[False] | None = None, show_summary: bool | None = None, config_dir: Path | None = None, data_dir: Path | None = None, id_generator: IdGenerator | None = None, ns_timestamp_generator: Callable[[], int] | None = None, additional_span_processors: Sequence[SpanProcessor] | None = None, additional_metric_readers: Sequence[MetricReader] | None = None, pydantic_plugin: PydanticPlugin | None = None, fast_shutdown: bool = False, scrubbing: ScrubbingOptions | Literal[False] | None = None, inspect_arguments: bool | None = None, tail_sampling: TailSamplingOptions | None = None) -> None:
+    def __init__(self, base_url: str | None = None, send_to_logfire: bool | None = None, token: str | None = None, project_name: str | None = None, service_name: str | None = None, service_version: str | None = None, console: ConsoleOptions | Literal[False] | None = None, show_summary: bool | None = None, config_dir: Path | None = None, data_dir: Path | None = None, id_generator: IdGenerator | None = None, ns_timestamp_generator: Callable[[], int] | None = None, additional_span_processors: Sequence[SpanProcessor] | None = None, additional_metric_readers: Sequence[MetricReader] | None = None, pydantic_plugin: PydanticPlugin | None = None, fast_shutdown: bool = False, scrubbing: ScrubbingOptions | Literal[False] | None = None, inspect_arguments: bool | None = None, 
sampling: SamplingOptions | None = None) -> None: """Create a new LogfireConfig. Users should never need to call this directly, instead use `logfire.configure`. See `_LogfireConfigData` for parameter documentation. """ - def configure(self, base_url: str | None, send_to_logfire: bool | Literal['if-token-present'] | None, token: str | None, project_name: str | None, service_name: str | None, service_version: str | None, trace_sample_rate: float | None, console: ConsoleOptions | Literal[False] | None, show_summary: bool | None, config_dir: Path | None, data_dir: Path | None, id_generator: IdGenerator | None, ns_timestamp_generator: Callable[[], int] | None, additional_span_processors: Sequence[SpanProcessor] | None, additional_metric_readers: Sequence[MetricReader] | None, pydantic_plugin: PydanticPlugin | None, fast_shutdown: bool, scrubbing: ScrubbingOptions | Literal[False] | None, inspect_arguments: bool | None, tail_sampling: TailSamplingOptions | None) -> None: ... + def configure(self, base_url: str | None, send_to_logfire: bool | Literal['if-token-present'] | None, token: str | None, project_name: str | None, service_name: str | None, service_version: str | None, console: ConsoleOptions | Literal[False] | None, show_summary: bool | None, config_dir: Path | None, data_dir: Path | None, id_generator: IdGenerator | None, ns_timestamp_generator: Callable[[], int] | None, additional_span_processors: Sequence[SpanProcessor] | None, additional_metric_readers: Sequence[MetricReader] | None, pydantic_plugin: PydanticPlugin | None, fast_shutdown: bool, scrubbing: ScrubbingOptions | Literal[False] | None, inspect_arguments: bool | None, sampling: SamplingOptions | None) -> None: ... def initialize(self) -> ProxyTracerProvider: """Configure internals to start exporting traces and metrics.""" def force_flush(self, timeout_millis: int = 30000) -> bool: diff --git a/logfire-api/logfire_api/_internal/constants.pyi b/logfire-api/logfire_api/_internal/constants.pyi index ec4d84314..27d3b7b23 100644 --- a/logfire-api/logfire_api/_internal/constants.pyi +++ b/logfire-api/logfire_api/_internal/constants.pyi @@ -3,9 +3,9 @@ from opentelemetry.util import types as otel_types LOGFIRE_ATTRIBUTES_NAMESPACE: str LevelName: Incomplete -LEVEL_NUMBERS: Incomplete -NUMBER_TO_LEVEL: Incomplete -LOGGING_TO_OTEL_LEVEL_NUMBERS: Incomplete +LEVEL_NUMBERS: dict[LevelName, int] +NUMBER_TO_LEVEL: dict[int, LevelName] +LOGGING_TO_OTEL_LEVEL_NUMBERS: dict[int, int] ATTRIBUTES_LOG_LEVEL_NAME_KEY: Incomplete ATTRIBUTES_LOG_LEVEL_NUM_KEY: Incomplete diff --git a/logfire-api/logfire_api/sampling/__init__.pyi b/logfire-api/logfire_api/sampling/__init__.pyi new file mode 100644 index 000000000..59e1f9785 --- /dev/null +++ b/logfire-api/logfire_api/sampling/__init__.pyi @@ -0,0 +1,3 @@ +from ._tail_sampling import SamplingOptions as SamplingOptions, SpanLevel as SpanLevel, TailSamplingSpanInfo as TailSamplingSpanInfo + +__all__ = ['SamplingOptions', 'SpanLevel', 'TailSamplingSpanInfo'] diff --git a/logfire-api/logfire_api/sampling/_tail_sampling.pyi b/logfire-api/logfire_api/sampling/_tail_sampling.pyi new file mode 100644 index 000000000..51dc1c885 --- /dev/null +++ b/logfire-api/logfire_api/sampling/_tail_sampling.pyi @@ -0,0 +1,89 @@ +from _typeshed import Incomplete +from dataclasses import dataclass +from functools import cached_property +from logfire._internal.constants import ATTRIBUTES_LOG_LEVEL_NUM_KEY as ATTRIBUTES_LOG_LEVEL_NUM_KEY, LEVEL_NUMBERS as LEVEL_NUMBERS, LevelName as LevelName, NUMBER_TO_LEVEL as 
NUMBER_TO_LEVEL, ONE_SECOND_IN_NANOSECONDS as ONE_SECOND_IN_NANOSECONDS +from logfire._internal.exporters.wrapper import WrapperSpanProcessor as WrapperSpanProcessor +from opentelemetry import context +from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor +from opentelemetry.sdk.trace.sampling import Sampler +from typing import Callable, Literal +from typing_extensions import Self + +@dataclass +class SpanLevel: + """A convenience class for comparing span/log levels. + + Can be compared to log level names (strings) such as 'info' or 'error' using + `<`, `>`, `<=`, or `>=`, so e.g. `level >= 'error'` is valid. + + Will raise an exception if compared to a non-string or an invalid level name. + """ + number: int + @property + def name(self) -> LevelName | None: + """The human-readable name of the level, or `None` if the number is invalid.""" + def __eq__(self, other: object): ... + def __hash__(self): ... + def __lt__(self, other: LevelName): ... + def __gt__(self, other: LevelName): ... + def __ge__(self, other: LevelName): ... + def __le__(self, other: LevelName): ... + +@dataclass +class TraceBuffer: + """Arguments of `SpanProcessor.on_start` and `SpanProcessor.on_end` for spans in a single trace. + + These are stored until either the trace is included by tail sampling or it's completed and discarded. + """ + started: list[tuple[Span, context.Context | None]] + ended: list[ReadableSpan] + @cached_property + def first_span(self) -> Span: ... + @cached_property + def trace_id(self) -> int: ... + +@dataclass +class TailSamplingSpanInfo: + """Argument passed to the [`SamplingOptions.tail`][logfire.sampling.SamplingOptions.tail] callback.""" + span: ReadableSpan + context: context.Context | None + event: Literal['start', 'end'] + buffer: TraceBuffer + @property + def level(self) -> SpanLevel: + """The log level of the span.""" + @property + def duration(self) -> float: + """The time in seconds between the start of the trace and the start/end of this span.""" + +@dataclass +class SamplingOptions: + """Options for [`logfire.configure(sampling=...)`][logfire.configure(sampling)].""" + head: float | Sampler = ... + tail: Callable[[TailSamplingSpanInfo], float] | None = ... + @classmethod + def level_or_duration(cls, *, head: float | Sampler = 1.0, level_threshold: LevelName | None = 'notice', duration_threshold: float | None = 5.0, background_rate: float = 0.0) -> Self: + """Returns a `SamplingOptions` instance that tail samples traces based on their log level and duration. + + If a trace has at least one span/log that has a log level greater than or equal to `level_threshold`, + or if the duration of the whole trace is greater than `duration_threshold` seconds, + then the whole trace will be included. + Otherwise, the probability is `background_rate`. + + The `head` parameter is the same as in the `SamplingOptions` constructor. + """ + +def check_trace_id_ratio(trace_id: int, rate: float) -> bool: ... + +class TailSamplingProcessor(WrapperSpanProcessor): + """Passes spans to the wrapped processor if any span in a trace meets the sampling criteria.""" + get_tail_sample_rate: Incomplete + traces: Incomplete + lock: Incomplete + def __init__(self, processor: SpanProcessor, get_tail_sample_rate: Callable[[TailSamplingSpanInfo], float]) -> None: ... + def on_start(self, span: Span, parent_context: context.Context | None = None) -> None: ... + def on_end(self, span: ReadableSpan) -> None: ... 
+ def check_span(self, span_info: TailSamplingSpanInfo) -> bool: + """If the span meets the sampling criteria, drop the buffer and return True. Otherwise, return False.""" + def drop_buffer(self, buffer: TraceBuffer) -> None: ... + def push_buffer(self, buffer: TraceBuffer) -> None: ... diff --git a/logfire/__init__.py b/logfire/__init__.py index 7e469b417..0b6e1b32d 100644 --- a/logfire/__init__.py +++ b/logfire/__init__.py @@ -4,12 +4,13 @@ from typing import Any +from logfire.sampling import SamplingOptions + from ._internal.auto_trace import AutoTraceModule from ._internal.auto_trace.rewrite_ast import no_auto_trace from ._internal.config import METRICS_PREFERRED_TEMPORALITY, ConsoleOptions, PydanticPlugin, configure from ._internal.constants import LevelName from ._internal.exporters.file import load_file as load_spans_from_file -from ._internal.exporters.tail_sampling import TailSamplingOptions from ._internal.main import Logfire, LogfireSpan from ._internal.scrubbing import ScrubbingOptions, ScrubMatch from ._internal.utils import suppress_instrumentation @@ -131,5 +132,5 @@ def loguru_handler() -> dict[str, Any]: 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', - 'TailSamplingOptions', + 'SamplingOptions', ) diff --git a/logfire/_internal/config.py b/logfire/_internal/config.py index e9c3f548b..063183ca3 100644 --- a/logfire/_internal/config.py +++ b/logfire/_internal/config.py @@ -46,13 +46,15 @@ from opentelemetry.sdk.trace import SpanProcessor, TracerProvider as SDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import IdGenerator, RandomIdGenerator -from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio, Sampler from opentelemetry.semconv.resource import ResourceAttributes from rich.console import Console from rich.prompt import Confirm, Prompt from typing_extensions import Self, Unpack from logfire.exceptions import LogfireConfigError +from logfire.sampling import SamplingOptions +from logfire.sampling._tail_sampling import TailSamplingProcessor from logfire.version import VERSION from .auth import DEFAULT_FILE, DefaultFile, is_logged_in @@ -74,7 +76,6 @@ from .exporters.processor_wrapper import MainSpanProcessorWrapper from .exporters.quiet_metrics import QuietMetricExporter from .exporters.remove_pending import RemovePendingSpansExporter -from .exporters.tail_sampling import TailSamplingOptions, TailSamplingProcessor from .exporters.test import TestExporter from .integrations.executors import instrument_executors from .metrics import ProxyMeterProvider @@ -155,7 +156,6 @@ def configure( # noqa: D417 token: str | None = None, service_name: str | None = None, service_version: str | None = None, - trace_sample_rate: float | None = None, console: ConsoleOptions | Literal[False] | None = None, show_summary: bool | None = None, config_dir: Path | str | None = None, @@ -169,7 +169,7 @@ def configure( # noqa: D417 fast_shutdown: bool = False, scrubbing: ScrubbingOptions | Literal[False] | None = None, inspect_arguments: bool | None = None, - tail_sampling: TailSamplingOptions | None = None, + sampling: SamplingOptions | None = None, **deprecated_kwargs: Unpack[DeprecatedKwargs], ) -> None: """Configure the logfire SDK. @@ -182,8 +182,6 @@ def configure( # noqa: D417 service_name: Name of this service. Defaults to the `LOGFIRE_SERVICE_NAME` environment variable. 
service_version: Version of this service. Defaults to the `LOGFIRE_SERVICE_VERSION` environment variable, or the current git commit hash if available. - trace_sample_rate: Sampling ratio for spans. Defaults to the `LOGFIRE_SAMPLING_RATIO` environment variable, or - the `OTEL_TRACES_SAMPLER_ARG` environment variable, or to `1.0`. console: Whether to control terminal output. If `None` uses the `LOGFIRE_CONSOLE_*` environment variables, otherwise defaults to `ConsoleOption(colors='auto', indent_spans=True, include_timestamps=True, verbose=False)`. If `False` disables console output. It can also be disabled by setting `LOGFIRE_CONSOLE` environment variable to `false`. @@ -207,7 +205,7 @@ def configure( # noqa: D417 [f-string magic](https://docs.pydantic.dev/logfire/guides/onboarding_checklist/add_manual_tracing/#f-strings). If `None` uses the `LOGFIRE_INSPECT_ARGUMENTS` environment variable. Defaults to `True` if and only if the Python version is at least 3.11. - tail_sampling: Tail sampling options. Not ready for general use. + sampling: Sampling options. See the [sampling guide](https://docs.pydantic.dev/logfire/guides/advanced/sampling/). """ processors = deprecated_kwargs.pop('processors', None) # type: ignore if processors is not None: # pragma: no cover @@ -258,6 +256,20 @@ def configure( # noqa: D417 DeprecationWarning, ) + trace_sample_rate: float | None = deprecated_kwargs.pop('trace_sample_rate', None) # type: ignore + if trace_sample_rate is not None: + if sampling: + raise ValueError( + 'Cannot specify both `trace_sample_rate` and `sampling`. ' + 'Use `sampling.head` instead of `trace_sample_rate`.' + ) + else: + sampling = SamplingOptions(head=trace_sample_rate) + warnings.warn( + 'The `trace_sample_rate` argument is deprecated. ' + 'Use `sampling=logfire.SamplingOptions(head=...)` instead.', + ) + if deprecated_kwargs: raise TypeError(f'configure() got unexpected keyword arguments: {", ".join(deprecated_kwargs)}') @@ -267,7 +279,6 @@ def configure( # noqa: D417 token=token, service_name=service_name, service_version=service_version, - trace_sample_rate=trace_sample_rate, console=console, show_summary=show_summary, config_dir=Path(config_dir) if config_dir else None, @@ -280,7 +291,7 @@ def configure( # noqa: D417 fast_shutdown=fast_shutdown, scrubbing=scrubbing, inspect_arguments=inspect_arguments, - tail_sampling=tail_sampling, + sampling=sampling, ) @@ -318,9 +329,6 @@ class _LogfireConfigData: service_version: str | None """The version of this service""" - trace_sample_rate: float - """The sampling ratio for spans""" - console: ConsoleOptions | Literal[False] | None """Options for controlling console output""" @@ -351,8 +359,8 @@ class _LogfireConfigData: inspect_arguments: bool """Whether to enable f-string magic""" - tail_sampling: TailSamplingOptions | None - """Tail sampling options""" + sampling: SamplingOptions + """Sampling options""" def _load_configuration( self, @@ -364,7 +372,6 @@ def _load_configuration( token: str | None, service_name: str | None, service_version: str | None, - trace_sample_rate: float | None, console: ConsoleOptions | Literal[False] | None, show_summary: bool | None, config_dir: Path | None, @@ -377,7 +384,7 @@ def _load_configuration( fast_shutdown: bool, scrubbing: ScrubbingOptions | Literal[False] | None, inspect_arguments: bool | None, - tail_sampling: TailSamplingOptions | None, + sampling: SamplingOptions | None, ) -> None: """Merge the given parameters with the environment variables file configurations.""" param_manager = 
ParamManager.create(config_dir) @@ -387,7 +394,6 @@ def _load_configuration( self.token = param_manager.load_param('token', token) self.service_name = param_manager.load_param('service_name', service_name) self.service_version = param_manager.load_param('service_version', service_version) - self.trace_sample_rate = param_manager.load_param('trace_sample_rate', trace_sample_rate) self.show_summary = param_manager.load_param('show_summary', show_summary) self.data_dir = param_manager.load_param('data_dir', data_dir) self.inspect_arguments = param_manager.load_param('inspect_arguments', inspect_arguments) @@ -434,10 +440,14 @@ def _load_configuration( if get_version(pydantic.__version__) < get_version('2.5.0'): # pragma: no cover raise RuntimeError('The Pydantic plugin requires Pydantic 2.5.0 or newer.') - if isinstance(tail_sampling, dict): + if isinstance(sampling, dict): # This is particularly for deserializing from a dict as in executors.py - tail_sampling = TailSamplingOptions(**tail_sampling) # type: ignore - self.tail_sampling = tail_sampling + sampling = SamplingOptions(**sampling) # type: ignore + elif sampling is None: + sampling = SamplingOptions( + head=param_manager.load_param('trace_sample_rate'), + ) + self.sampling = sampling self.fast_shutdown = fast_shutdown @@ -462,7 +472,6 @@ def __init__( token: str | None = None, service_name: str | None = None, service_version: str | None = None, - trace_sample_rate: float | None = None, console: ConsoleOptions | Literal[False] | None = None, show_summary: bool | None = None, config_dir: Path | None = None, @@ -475,7 +484,7 @@ def __init__( fast_shutdown: bool = False, scrubbing: ScrubbingOptions | Literal[False] | None = None, inspect_arguments: bool | None = None, - tail_sampling: TailSamplingOptions | None = None, + sampling: SamplingOptions | None = None, ) -> None: """Create a new LogfireConfig. @@ -491,7 +500,6 @@ def __init__( token=token, service_name=service_name, service_version=service_version, - trace_sample_rate=trace_sample_rate, console=console, show_summary=show_summary, config_dir=config_dir, @@ -504,7 +512,7 @@ def __init__( fast_shutdown=fast_shutdown, scrubbing=scrubbing, inspect_arguments=inspect_arguments, - tail_sampling=tail_sampling, + sampling=sampling, ) # initialize with no-ops so that we don't impact OTEL's global config just because logfire is installed # that is, we defer setting logfire as the otel global config until `configure` is called @@ -524,7 +532,6 @@ def configure( token: str | None, service_name: str | None, service_version: str | None, - trace_sample_rate: float | None, console: ConsoleOptions | Literal[False] | None, show_summary: bool | None, config_dir: Path | None, @@ -537,7 +544,7 @@ def configure( fast_shutdown: bool, scrubbing: ScrubbingOptions | Literal[False] | None, inspect_arguments: bool | None, - tail_sampling: TailSamplingOptions | None, + sampling: SamplingOptions | None, ) -> None: with self._lock: self._initialized = False @@ -547,7 +554,6 @@ def configure( token, service_name, service_version, - trace_sample_rate, console, show_summary, config_dir, @@ -560,7 +566,7 @@ def configure( fast_shutdown, scrubbing, inspect_arguments, - tail_sampling, + sampling, ) self.initialize() @@ -601,13 +607,13 @@ def _initialize(self) -> ProxyTracerProvider: # Both recommend generating a UUID. resource = Resource({ResourceAttributes.SERVICE_INSTANCE_ID: uuid4().hex}).merge(resource) - # Avoid using the usual sampler if we're using tail-based sampling. 
- # The TailSamplingProcessor will handle the random sampling part as well. - sampler = ( - ParentBasedTraceIdRatio(self.trace_sample_rate) - if self.trace_sample_rate < 1 and self.tail_sampling is None - else None - ) + head = self.sampling.head + sampler: Sampler | None = None + if isinstance(head, (int, float)): + if head < 1: + sampler = ParentBasedTraceIdRatio(head) + else: + sampler = head tracer_provider = SDKTracerProvider( sampler=sampler, resource=resource, @@ -628,15 +634,8 @@ def add_span_processor(span_processor: SpanProcessor) -> None: (TestExporter, RemovePendingSpansExporter, SimpleConsoleSpanExporter), ) - if self.tail_sampling: - span_processor = TailSamplingProcessor( - span_processor, - self.tail_sampling, - # If self.trace_sample_rate < 1 then that ratio of spans should be included randomly by this. - # In that case the tracer provider doesn't need to do any sampling, see above. - # Otherwise we're not using any random sampling, so 0% of spans should be included 'randomly'. - self.trace_sample_rate if self.trace_sample_rate < 1 else 0, - ) + if self.sampling.tail: + span_processor = TailSamplingProcessor(span_processor, self.sampling.tail) span_processor = MainSpanProcessorWrapper(span_processor, self.scrubber) tracer_provider.add_span_processor(span_processor) if has_pending: diff --git a/logfire/_internal/config_params.py b/logfire/_internal/config_params.py index d8929023f..4cbd05417 100644 --- a/logfire/_internal/config_params.py +++ b/logfire/_internal/config_params.py @@ -91,7 +91,7 @@ class _DefaultCallback: PYDANTIC_PLUGIN_EXCLUDE = ConfigParam(env_vars=['LOGFIRE_PYDANTIC_PLUGIN_EXCLUDE'], allow_file_config=True, default=set(), tp=Set[str]) """Set of items that should be excluded from Logfire Pydantic plugin instrumentation.""" TRACE_SAMPLE_RATE = ConfigParam(env_vars=['LOGFIRE_TRACE_SAMPLE_RATE', 'OTEL_TRACES_SAMPLER_ARG'], allow_file_config=True, default=1.0, tp=float) -"""Default sampling ratio for traces. Can be overridden by the `logfire.sample_rate` attribute of a span.""" +"""Head sampling rate for traces.""" INSPECT_ARGUMENTS = ConfigParam(env_vars=['LOGFIRE_INSPECT_ARGUMENTS'], allow_file_config=True, default=sys.version_info[:2] >= (3, 11), tp=bool) """Whether to enable the f-string magic feature. 
On by default for Python 3.11 and above.""" IGNORE_NO_CONFIG = ConfigParam(env_vars=['LOGFIRE_IGNORE_NO_CONFIG'], allow_file_config=True, default=False, tp=bool) diff --git a/logfire/_internal/constants.py b/logfire/_internal/constants.py index 6eeb74a34..36d689f5f 100644 --- a/logfire/_internal/constants.py +++ b/logfire/_internal/constants.py @@ -12,7 +12,7 @@ LevelName = Literal['trace', 'debug', 'info', 'notice', 'warn', 'warning', 'error', 'fatal'] """Level names for records.""" -LEVEL_NUMBERS = { +LEVEL_NUMBERS: dict[LevelName, int] = { 'trace': 1, 'debug': 5, 'info': 9, @@ -24,9 +24,9 @@ 'fatal': 21, } -NUMBER_TO_LEVEL = {v: k for k, v in LEVEL_NUMBERS.items()} +NUMBER_TO_LEVEL: dict[int, LevelName] = {v: k for k, v in LEVEL_NUMBERS.items()} -LOGGING_TO_OTEL_LEVEL_NUMBERS = { +LOGGING_TO_OTEL_LEVEL_NUMBERS: dict[int, int] = { 0: 9, # logging.NOTSET: default to info 1: 1, # OTEL trace 2: 1, diff --git a/logfire/_internal/exporters/tail_sampling.py b/logfire/_internal/exporters/tail_sampling.py deleted file mode 100644 index 19c0ee928..000000000 --- a/logfire/_internal/exporters/tail_sampling.py +++ /dev/null @@ -1,145 +0,0 @@ -from __future__ import annotations - -import random -from dataclasses import dataclass -from functools import cached_property -from threading import Lock - -from opentelemetry import context -from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor - -from logfire._internal.constants import ( - ATTRIBUTES_LOG_LEVEL_NUM_KEY, - LEVEL_NUMBERS, - ONE_SECOND_IN_NANOSECONDS, - LevelName, -) -from logfire._internal.exporters.wrapper import WrapperSpanProcessor - - -@dataclass -class TailSamplingOptions: - level: LevelName | None = 'notice' - """ - Include all spans/logs with level greater than or equal to this level. - If None, spans are not included based on level. - """ - - duration: float | None = 1.0 - """ - Include all spans/logs with duration greater than this duration in seconds. - If None, spans are not included based on duration. - """ - - -@dataclass -class TraceBuffer: - """Arguments of `on_start` and `on_end` for spans in a single trace.""" - - started: list[tuple[Span, context.Context | None]] - ended: list[ReadableSpan] - - @cached_property - def first_span(self) -> Span: - return self.started[0][0] - - -class TailSamplingProcessor(WrapperSpanProcessor): - """Passes spans to the wrapped processor if any span in a trace meets the sampling criteria.""" - - def __init__(self, processor: SpanProcessor, options: TailSamplingOptions, random_rate: float) -> None: - super().__init__(processor) - self.duration: float = ( - float('inf') if options.duration is None else options.duration * ONE_SECOND_IN_NANOSECONDS - ) - self.level: float | int = float('inf') if options.level is None else LEVEL_NUMBERS[options.level] - self.random_rate = random_rate - - # A TraceBuffer is typically created for each new trace. - # If a span meets the sampling criteria, the buffer is dropped and all spans within are pushed - # to the wrapped processor. - # So when more spans arrive and there's no buffer, they get passed through immediately. - self.traces: dict[int, TraceBuffer] = {} - - # Code that touches self.traces and its contents should be protected by this lock. - self.lock = Lock() - - def on_start(self, span: Span, parent_context: context.Context | None = None) -> None: - dropped = False - buffer = None - - with self.lock: - # span.context could supposedly be None, not sure how. 
- if span.context: # pragma: no branch - trace_id = span.context.trace_id - # If span.parent is None, it's the root span of a trace. - # If random.random() <= self.random_rate, immediately include this trace, - # meaning no buffer for it. - if span.parent is None and random.random() > self.random_rate: - self.traces[trace_id] = TraceBuffer([], []) - - buffer = self.traces.get(trace_id) - if buffer is not None: - # This trace's spans haven't met the criteria yet, so add this span to the buffer. - buffer.started.append((span, parent_context)) - dropped = self.check_span(span, buffer) - # The opposite case is handled outside the lock since it may take some time. - - # This code may take longer since it calls the wrapped processor which might do anything. - # It shouldn't be inside the lock to avoid blocking other threads. - # Since it's not in the lock, it shouldn't touch self.traces or its contents. - if buffer is None: - super().on_start(span, parent_context) - elif dropped: - self.push_buffer(buffer) - - def on_end(self, span: ReadableSpan) -> None: - # This has a very similar structure and reasoning to on_start. - - dropped = False - buffer = None - - with self.lock: - if span.context: # pragma: no branch - trace_id = span.context.trace_id - buffer = self.traces.get(trace_id) - if buffer is not None: - buffer.ended.append(span) - dropped = self.check_span(span, buffer) - if span.parent is None: - # This is the root span, so the trace is hopefully complete. - # Delete the buffer to save memory. - self.traces.pop(trace_id, None) - - if buffer is None: - super().on_end(span) - elif dropped: - self.push_buffer(buffer) - - def check_span(self, span: ReadableSpan, buffer: TraceBuffer) -> bool: - """If the span meets the sampling criteria, drop the buffer and return True. Otherwise, return False.""" - # span.end_time and span.start_time are in nanoseconds and can be None. - if (span.end_time or span.start_time or 0) - (buffer.first_span.start_time or float('inf')) > self.duration: - self.drop_buffer(buffer) - return True - - attributes = span.attributes or {} - level = attributes.get(ATTRIBUTES_LOG_LEVEL_NUM_KEY) - if not isinstance(level, int): - level = LEVEL_NUMBERS['info'] - if level >= self.level: - self.drop_buffer(buffer) - return True - - return False - - def drop_buffer(self, buffer: TraceBuffer) -> None: - span_context = buffer.first_span.context - assert span_context is not None - del self.traces[span_context.trace_id] - - def push_buffer(self, buffer: TraceBuffer) -> None: - for started in buffer.started: - super().on_start(*started) - for span in buffer.ended: - super().on_end(span) diff --git a/logfire/sampling/__init__.py b/logfire/sampling/__init__.py new file mode 100644 index 000000000..cd5d2a471 --- /dev/null +++ b/logfire/sampling/__init__.py @@ -0,0 +1,9 @@ +"""Types for configuring sampling. 
See the [sampling guide](https://docs.pydantic.dev/logfire/guides/advanced/sampling/).""" + +from ._tail_sampling import SamplingOptions, SpanLevel, TailSamplingSpanInfo + +__all__ = [ + 'SamplingOptions', + 'SpanLevel', + 'TailSamplingSpanInfo', +] diff --git a/logfire/sampling/_tail_sampling.py b/logfire/sampling/_tail_sampling.py new file mode 100644 index 000000000..2cd39cc67 --- /dev/null +++ b/logfire/sampling/_tail_sampling.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import cached_property +from threading import Lock +from typing import Callable, Literal + +from opentelemetry import context +from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor +from opentelemetry.sdk.trace.sampling import Sampler, TraceIdRatioBased +from typing_extensions import Self + +from logfire._internal.constants import ( + ATTRIBUTES_LOG_LEVEL_NUM_KEY, + LEVEL_NUMBERS, + NUMBER_TO_LEVEL, + ONE_SECOND_IN_NANOSECONDS, + LevelName, +) +from logfire._internal.exporters.wrapper import WrapperSpanProcessor + + +@dataclass +class SpanLevel: + """A convenience class for comparing span/log levels. + + Can be compared to log level names (strings) such as 'info' or 'error' using + `<`, `>`, `<=`, or `>=`, so e.g. `level >= 'error'` is valid. + + Will raise an exception if compared to a non-string or an invalid level name. + """ + + number: int + """ + The raw numeric value of the level. Higher values are more severe. + """ + + @property + def name(self) -> LevelName | None: + """The human-readable name of the level, or `None` if the number is invalid.""" + return NUMBER_TO_LEVEL.get(self.number) + + def __eq__(self, other: object): + if isinstance(other, int): + return self.number == other + if isinstance(other, str): + return self.name == other + if isinstance(other, SpanLevel): + return self.number == other.number + return NotImplemented + + def __hash__(self): + return hash(self.number) + + def __lt__(self, other: LevelName): + return self.number < LEVEL_NUMBERS[other] + + def __gt__(self, other: LevelName): + return self.number > LEVEL_NUMBERS[other] + + def __ge__(self, other: LevelName): + return self.number >= LEVEL_NUMBERS[other] + + def __le__(self, other: LevelName): + return self.number <= LEVEL_NUMBERS[other] + + +@dataclass +class TraceBuffer: + """Arguments of `SpanProcessor.on_start` and `SpanProcessor.on_end` for spans in a single trace. + + These are stored until either the trace is included by tail sampling or it's completed and discarded. + """ + + started: list[tuple[Span, context.Context | None]] + ended: list[ReadableSpan] + + @cached_property + def first_span(self) -> Span: + return self.started[0][0] + + @cached_property + def trace_id(self) -> int: + span_context = self.first_span.context + assert span_context is not None + return span_context.trace_id + + +@dataclass +class TailSamplingSpanInfo: + """Argument passed to the [`SamplingOptions.tail`][logfire.sampling.SamplingOptions.tail] callback.""" + + span: ReadableSpan + """ + Raw span object being started or ended. + """ + + context: context.Context | None + """ + Second argument of [`SpanProcessor.on_start`][opentelemetry.sdk.trace.SpanProcessor.on_start] or `None` for `SpanProcessor.on_end`. + """ + + event: Literal['start', 'end'] + """ + `'start'` if the span is being started, `'end'` if it's being ended. + """ + + # Intentionally undocumented for now. 
+    buffer: TraceBuffer
+
+    @property
+    def level(self) -> SpanLevel:
+        """The log level of the span."""
+        attributes = self.span.attributes or {}
+        level = attributes.get(ATTRIBUTES_LOG_LEVEL_NUM_KEY)
+        if not isinstance(level, int):
+            level = LEVEL_NUMBERS['info']
+        return SpanLevel(level)
+
+    @property
+    def duration(self) -> float:
+        """The time in seconds between the start of the trace and the start/end of this span."""
+        # span.end_time and span.start_time are in nanoseconds and can be None.
+        return (
+            (self.span.end_time or self.span.start_time or 0) - (self.buffer.first_span.start_time or float('inf'))
+        ) / ONE_SECOND_IN_NANOSECONDS
+
+
+@dataclass
+class SamplingOptions:
+    """Options for [`logfire.configure(sampling=...)`][logfire.configure(sampling)].
+
+    See the [sampling guide](https://docs.pydantic.dev/logfire/guides/advanced/sampling/).
+    """
+
+    head: float | Sampler = 1.0
+    """
+    Head sampling options.
+
+    If it's a float, it should be a number between 0.0 and 1.0.
+    This is the probability that an entire trace will be randomly included.
+
+    Alternatively you can pass a custom
+    [OpenTelemetry `Sampler`](https://opentelemetry-python.readthedocs.io/en/latest/sdk/trace.sampling.html).
+    """
+
+    tail: Callable[[TailSamplingSpanInfo], float] | None = None
+    """
+    An optional tail sampling callback which will be called for every span.
+
+    It should return a number between 0.0 and 1.0, the probability that the entire trace will be included.
+    Use [`SamplingOptions.level_or_duration`][logfire.sampling.SamplingOptions.level_or_duration]
+    for a common use case.
+
+    Every span in a trace will be stored in memory until either the trace is included by tail sampling
+    or it's completed and discarded, so large traces may consume a lot of memory.
+    """
+
+    @classmethod
+    def level_or_duration(
+        cls,
+        *,
+        head: float | Sampler = 1.0,
+        level_threshold: LevelName | None = 'notice',
+        duration_threshold: float | None = 5.0,
+        background_rate: float = 0.0,
+    ) -> Self:
+        """Returns a `SamplingOptions` instance that tail samples traces based on their log level and duration.
+
+        If a trace has at least one span/log that has a log level greater than or equal to `level_threshold`,
+        or if the duration of the whole trace is greater than `duration_threshold` seconds,
+        then the whole trace will be included.
+        Otherwise, the probability is `background_rate`.
+
+        The `head` parameter is the same as in the `SamplingOptions` constructor.
+        """
+        head_sample_rate = head if isinstance(head, (float, int)) else 1.0
+
+        if not (0.0 <= background_rate <= head_sample_rate <= 1.0):
+            raise ValueError('Invalid sampling rates, must be 0.0 <= background_rate <= head <= 1.0')
+
+        def get_tail_sample_rate(span_info: TailSamplingSpanInfo) -> float:
+            if duration_threshold is not None and span_info.duration > duration_threshold:
+                return 1.0
+
+            if level_threshold is not None and span_info.level >= level_threshold:
+                return 1.0
+
+            return background_rate
+
+        return cls(head=head, tail=get_tail_sample_rate)
+
+
+def check_trace_id_ratio(trace_id: int, rate: float) -> bool:
+    # Based on TraceIdRatioBased.should_sample.
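+    # A trace ID is effectively a uniform random number, so comparing its lower
+    # 64 bits against a bound proportional to `rate` includes that fraction of
+    # traces. Deriving the decision from the trace ID rather than a fresh random
+    # number also makes repeated checks consistent: a trace that passes at one
+    # rate passes at any higher rate, which is why the head and background rates
+    # are inclusion probabilities directly rather than being multiplied together.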
+ return (trace_id & TraceIdRatioBased.TRACE_ID_LIMIT) < TraceIdRatioBased.get_bound_for_rate(rate) + + +class TailSamplingProcessor(WrapperSpanProcessor): + """Passes spans to the wrapped processor if any span in a trace meets the sampling criteria.""" + + def __init__(self, processor: SpanProcessor, get_tail_sample_rate: Callable[[TailSamplingSpanInfo], float]) -> None: + super().__init__(processor) + self.get_tail_sample_rate = get_tail_sample_rate + + # A TraceBuffer is typically created for each new trace. + # If a span meets the sampling criteria, the buffer is dropped and all spans within are pushed + # to the wrapped processor. + # So when more spans arrive and there's no buffer, they get passed through immediately. + self.traces: dict[int, TraceBuffer] = {} + + # Code that touches self.traces and its contents should be protected by this lock. + self.lock = Lock() + + def on_start(self, span: Span, parent_context: context.Context | None = None) -> None: + dropped = False + buffer = None + + with self.lock: + # span.context could supposedly be None, not sure how. + if span.context: # pragma: no branch + trace_id = span.context.trace_id + # If span.parent is None, it's the root span of a trace. + if span.parent is None: + self.traces[trace_id] = TraceBuffer([], []) + + buffer = self.traces.get(trace_id) + if buffer is not None: + # This trace's spans haven't met the criteria yet, so add this span to the buffer. + buffer.started.append((span, parent_context)) + dropped = self.check_span(TailSamplingSpanInfo(span, parent_context, 'start', buffer)) + # The opposite case is handled outside the lock since it may take some time. + + # This code may take longer since it calls the wrapped processor which might do anything. + # It shouldn't be inside the lock to avoid blocking other threads. + # Since it's not in the lock, it shouldn't touch self.traces or its contents. + if buffer is None: + super().on_start(span, parent_context) + elif dropped: + self.push_buffer(buffer) + + def on_end(self, span: ReadableSpan) -> None: + # This has a very similar structure and reasoning to on_start. + + dropped = False + buffer = None + + with self.lock: + if span.context: # pragma: no branch + trace_id = span.context.trace_id + buffer = self.traces.get(trace_id) + if buffer is not None: + buffer.ended.append(span) + dropped = self.check_span(TailSamplingSpanInfo(span, None, 'end', buffer)) + if span.parent is None: + # This is the root span, so the trace is hopefully complete. + # Delete the buffer to save memory. + self.traces.pop(trace_id, None) + + if buffer is None: + super().on_end(span) + elif dropped: + self.push_buffer(buffer) + + def check_span(self, span_info: TailSamplingSpanInfo) -> bool: + """If the span meets the sampling criteria, drop the buffer and return True. 
Otherwise, return False.""" + sample_rate = self.get_tail_sample_rate(span_info) + if sampled := check_trace_id_ratio(span_info.buffer.trace_id, sample_rate): + self.drop_buffer(span_info.buffer) + + return sampled + + def drop_buffer(self, buffer: TraceBuffer) -> None: + del self.traces[buffer.trace_id] + + def push_buffer(self, buffer: TraceBuffer) -> None: + for started in buffer.started: + super().on_start(*started) + for span in buffer.ended: + super().on_end(span) diff --git a/mkdocs.yml b/mkdocs.yml index c3ab17b23..f7af46edc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -136,6 +136,7 @@ nav: - SDK API: - Logfire: api/logfire.md - Testing: api/testing.md + - Sampling: api/sampling.md - Propagate: api/propagate.md - Exceptions: api/exceptions.md - Integrations: diff --git a/tests/test_configure.py b/tests/test_configure.py index fc61a9382..17f0acd79 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -466,6 +466,7 @@ def test_read_config_from_pyproject_toml(tmp_path: Path) -> None: pydantic_plugin_record = "metrics" pydantic_plugin_include = " test1, test2" pydantic_plugin_exclude = "test3 ,test4" + trace_sample_rate = "0.123" """ ) @@ -483,6 +484,7 @@ def test_read_config_from_pyproject_toml(tmp_path: Path) -> None: assert GLOBAL_CONFIG.pydantic_plugin.record == 'metrics' assert GLOBAL_CONFIG.pydantic_plugin.include == {'test1', 'test2'} assert GLOBAL_CONFIG.pydantic_plugin.exclude == {'test3', 'test4'} + assert GLOBAL_CONFIG.sampling.head == 0.123 def test_logfire_invalid_config_dir(tmp_path: Path): @@ -850,7 +852,7 @@ def test_config_serializable(): send_to_logfire=False, pydantic_plugin=logfire.PydanticPlugin(record='all'), console=logfire.ConsoleOptions(verbose=True), - tail_sampling=logfire.TailSamplingOptions(), + sampling=logfire.SamplingOptions(), scrubbing=logfire.ScrubbingOptions(), ) @@ -858,7 +860,7 @@ def test_config_serializable(): # Check that the full set of dataclass fields is known. # If a new field appears here, make sure it gets deserialized properly in configure, and tested here. 
         assert dataclasses.is_dataclass(getattr(GLOBAL_CONFIG, field.name)) == (
-            field.name in ['pydantic_plugin', 'console', 'tail_sampling', 'scrubbing']
+            field.name in ['pydantic_plugin', 'console', 'sampling', 'scrubbing']
         )

     serialized = serialize_config()
@@ -875,7 +877,7 @@ def normalize(s: dict[str, Any]) -> dict[str, Any]:

     assert isinstance(GLOBAL_CONFIG.pydantic_plugin, logfire.PydanticPlugin)
     assert isinstance(GLOBAL_CONFIG.console, logfire.ConsoleOptions)
-    assert isinstance(GLOBAL_CONFIG.tail_sampling, logfire.TailSamplingOptions)
+    assert isinstance(GLOBAL_CONFIG.sampling, logfire.SamplingOptions)
     assert isinstance(GLOBAL_CONFIG.scrubbing, logfire.ScrubbingOptions)

diff --git a/tests/test_logfire.py b/tests/test_logfire.py
index dd6443b39..a2b0d29ef 100644
--- a/tests/test_logfire.py
+++ b/tests/test_logfire.py
@@ -30,6 +30,7 @@
     ATTRIBUTES_TAGS_KEY,
     LEVEL_NUMBERS,
     NULL_ARGS_KEY,
+    LevelName,
 )
 from logfire._internal.formatter import FormattingFailedWarning, InspectArgumentsFailedWarning
 from logfire._internal.main import NoopSpan
@@ -476,7 +477,7 @@ def test_span_end_on_exit_false(exporter: TestExporter) -> None:

 @pytest.mark.parametrize('level', ('fatal', 'debug', 'error', 'info', 'notice', 'warn', 'trace'))
-def test_log(exporter: TestExporter, level: str):
+def test_log(exporter: TestExporter, level: LevelName):
     getattr(logfire, level)('test {name} {number} {none}', name='foo', number=2, none=None)

     s = exporter.exported_spans[0]
diff --git a/tests/test_logfire_api.py b/tests/test_logfire_api.py
index 36a23d3ff..aa8137802 100644
--- a/tests/test_logfire_api.py
+++ b/tests/test_logfire_api.py
@@ -156,9 +156,9 @@ def func() -> None: ...
         logfire_api.StructlogProcessor()
         logfire__all__.remove('StructlogProcessor')

-        assert hasattr(logfire_api, 'TailSamplingOptions')
-        logfire_api.TailSamplingOptions()
-        logfire__all__.remove('TailSamplingOptions')
+        assert hasattr(logfire_api, 'SamplingOptions')
+        logfire_api.SamplingOptions()
+        logfire__all__.remove('SamplingOptions')

         assert hasattr(logfire_api, 'ScrubbingOptions')
         logfire_api.ScrubbingOptions()
diff --git a/tests/test_pydantic_plugin.py b/tests/test_pydantic_plugin.py
index 65436e974..3d090cc37 100644
--- a/tests/test_pydantic_plugin.py
+++ b/tests/test_pydantic_plugin.py
@@ -627,14 +627,12 @@ class MyDataclass:
     )

-def test_pydantic_plugin_sample_rate_config(exporter: TestExporter) -> None:
-    logfire.configure(
-        send_to_logfire=False,
-        trace_sample_rate=0.1,
-        additional_span_processors=[SimpleSpanProcessor(exporter)],
+def test_pydantic_plugin_sample_rate_config(exporter: TestExporter, config_kwargs: dict[str, Any]) -> None:
+    config_kwargs.update(
+        sampling=logfire.SamplingOptions(head=0.1),
         id_generator=SeededRandomIdGenerator(),
-        additional_metric_readers=[InMemoryMetricReader()],
     )
+    logfire.configure(**config_kwargs)

     class MyModel(BaseModel, plugin_settings={'logfire': {'record': 'all'}}):
         x: int
diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index fa9ab4def..2fffc608d 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -56,25 +56,21 @@ def test_invalid_sample_rate(sample_rate: float) -> None:  # pragma: no cover
         logfire.DEFAULT_LOGFIRE_INSTANCE.with_trace_sample_rate(sample_rate)

-def test_sample_rate_config() -> None:
-    exporter = TestExporter()
-
-    logfire.configure(
-        send_to_logfire=False,
-        trace_sample_rate=0.05,
-        additional_span_processors=[SimpleSpanProcessor(exporter)],
+def test_sample_rate_config(exporter: TestExporter, config_kwargs: dict[str, Any]) -> None:
+    config_kwargs.update(
+        sampling=logfire.SamplingOptions(head=0.3),
         id_generator=SeededRandomIdGenerator(),
-        additional_metric_readers=[InMemoryMetricReader()],
     )
+    logfire.configure(**config_kwargs)

-    for _ in range(100):
+    for _ in range(1000):
         with logfire.span('outer'):
             with logfire.span('inner'):
                 pass

-    # 100 iterations of 2 spans -> 200 spans
-    # 5% sampling -> 10 spans (approximately)
-    assert len(exporter.exported_spans_as_dict()) == 6
+    # 1000 iterations of 2 spans -> 2000 spans
+    # 30% sampling -> 600 spans (approximately)
+    assert len(exporter.exported_spans_as_dict()) == 634

 @pytest.mark.skipif(
@@ -85,7 +81,6 @@ def test_sample_rate_runtime() -> None:  # pragma: no cover

     logfire.configure(
         send_to_logfire=False,
-        trace_sample_rate=1,
         additional_span_processors=[SimpleSpanProcessor(exporter)],
         id_generator=SeededRandomIdGenerator(),
         additional_metric_readers=[InMemoryMetricReader()],
@@ -109,7 +104,6 @@ def test_outer_sampled_inner_not() -> None:  # pragma: no cover

     logfire.configure(
         send_to_logfire=False,
-        trace_sample_rate=1,
         id_generator=SeededRandomIdGenerator(),
         ns_timestamp_generator=TimeGenerator(),
         additional_span_processors=[SimpleSpanProcessor(exporter)],
@@ -138,7 +132,6 @@ def test_outer_and_inner_sampled() -> None:  # pragma: no cover

     logfire.configure(
         send_to_logfire=False,
-        trace_sample_rate=1,
         id_generator=SeededRandomIdGenerator(),
         ns_timestamp_generator=TimeGenerator(),
         additional_span_processors=[SimpleSpanProcessor(exporter)],
@@ -173,7 +166,6 @@ def test_sampling_rate_does_not_get_overwritten() -> None:  # pragma: no cover

     logfire.configure(
         send_to_logfire=False,
-        trace_sample_rate=1,
         id_generator=SeededRandomIdGenerator(),
         ns_timestamp_generator=TimeGenerator(),
         additional_span_processors=[SimpleSpanProcessor(exporter)],
diff --git a/tests/test_tail_sampling.py b/tests/test_tail_sampling.py
index 5e24d1eb7..502ab4c69 100644
--- a/tests/test_tail_sampling.py
+++ b/tests/test_tail_sampling.py
@@ -1,20 +1,23 @@
 from __future__ import annotations

-import itertools
 from typing import Any
-from unittest.mock import Mock

+import inline_snapshot.extra
 import pytest
 from inline_snapshot import snapshot
+from opentelemetry.context import Context
+from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, ALWAYS_ON, Sampler, SamplingResult

 import logfire
-from logfire.testing import TestExporter
+from logfire._internal.constants import LEVEL_NUMBERS
+from logfire.sampling import SpanLevel, TailSamplingSpanInfo
+from logfire.testing import SeededRandomIdGenerator, TestExporter

-def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
+def test_level_threshold(config_kwargs: dict[str, Any], exporter: TestExporter):
     # Use the default TailSamplingOptions.level of 'notice'.
     # Set duration to None to not include spans with a long duration.
-    logfire.configure(**config_kwargs, tail_sampling=logfire.TailSamplingOptions(duration=None))
+    logfire.configure(**config_kwargs, sampling=logfire.SamplingOptions.level_or_duration(duration_threshold=None))

     with logfire.span('ignored span'):
         logfire.debug('ignored debug')
@@ -46,7 +49,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                     'logfire.msg_template': 'notice',
                     'logfire.msg': 'notice',
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                 },
             },
@@ -62,7 +65,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                     'logfire.msg_template': 'warn',
                     'logfire.msg': 'warn',
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                 },
             },
@@ -74,7 +77,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                 'end_time': 9000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span',
                     'logfire.msg': 'span',
@@ -90,7 +93,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                 'end_time': 10000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span2',
                     'logfire.msg': 'span2',
@@ -110,7 +113,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                     'logfire.msg_template': 'error',
                     'logfire.msg': 'error',
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                 },
             },
@@ -122,7 +125,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                 'end_time': 12000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span2',
                     'logfire.msg': 'span2',
@@ -137,7 +140,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                 'end_time': 13000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span',
                     'logfire.msg': 'span',
@@ -152,7 +155,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                 'end_time': 14000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span3',
                     'logfire.msg': 'span3',
@@ -173,7 +176,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                     'logfire.msg_template': 'trace',
                     'logfire.msg': 'trace',
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                 },
             },
@@ -185,7 +188,7 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
                 'end_time': 16000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_level_sampling',
+                    'code.function': 'test_level_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span3',
                    'logfire.msg': 'span3',
@@ -197,9 +200,11 @@ def test_level_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
     )

-def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
+def test_duration_threshold(config_kwargs: dict[str, Any], exporter: TestExporter):
     # Set level to None to not include spans merely based on a high level.
-    logfire.configure(**config_kwargs, tail_sampling=logfire.TailSamplingOptions(level=None, duration=3))
+    logfire.configure(
+        **config_kwargs, sampling=logfire.SamplingOptions.level_or_duration(level_threshold=None, duration_threshold=3)
+    )

     logfire.error('short1')
     with logfire.span('span'):
@@ -225,7 +230,7 @@ def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter
                 'end_time': 9000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_duration_sampling',
+                    'code.function': 'test_duration_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span4',
                     'logfire.msg': 'span4',
@@ -241,7 +246,7 @@ def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter
                 'end_time': 10000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_duration_sampling',
+                    'code.function': 'test_duration_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span5',
                     'logfire.msg': 'span5',
@@ -261,7 +266,7 @@ def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter
                     'logfire.msg_template': 'long1',
                     'logfire.msg': 'long1',
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_duration_sampling',
+                    'code.function': 'test_duration_threshold',
                     'code.lineno': 123,
                 },
             },
@@ -273,7 +278,7 @@ def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter
                 'end_time': 12000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_duration_sampling',
+                    'code.function': 'test_duration_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span5',
                     'logfire.msg': 'span5',
@@ -288,7 +293,7 @@ def test_duration_sampling(config_kwargs: dict[str, Any], exporter: TestExporter
                 'end_time': 13000000000,
                 'attributes': {
                     'code.filepath': 'test_tail_sampling.py',
-                    'code.function': 'test_duration_sampling',
+                    'code.function': 'test_duration_threshold',
                     'code.lineno': 123,
                     'logfire.msg_template': 'span4',
                     'logfire.msg': 'span4',
     )

@@ -299,22 +304,187 @@
-def test_random_sampling(config_kwargs: dict[str, Any], exporter: TestExporter, monkeypatch: pytest.MonkeyPatch):
+def test_background_rate(config_kwargs: dict[str, Any], exporter: TestExporter):
+    config_kwargs.update(
+        sampling=logfire.SamplingOptions.level_or_duration(background_rate=0.3),
+        id_generator=SeededRandomIdGenerator(seed=1),
+    )
+    logfire.configure(**config_kwargs)
+    # These spans should all be included because the level is above the default.
+    for _ in range(100):
+        logfire.error('error')
+    assert len(exporter.exported_spans) == 100
+
+    # About 30% of these spans (i.e. an extra ~300) should be included because of the background_rate.
+    # None of them meet the tail sampling criteria.
+    for _ in range(1000):
+        logfire.info('info')
+    assert len(exporter.exported_spans) == 100 + 321
+
+
+class TestSampler(Sampler):
+    def should_sample(
+        self,
+        parent_context: Context | None,
+        trace_id: int,
+        name: str,
+        *args: Any,
+        **kwargs: Any,
+    ) -> SamplingResult:
+        if name == 'exclude':
+            sampler = ALWAYS_OFF
+        else:
+            sampler = ALWAYS_ON
+        return sampler.should_sample(parent_context, trace_id, name, *args, **kwargs)
+
+    def get_description(self) -> str:  # pragma: no cover
+        return 'MySampler'
+
+
+def test_raw_head_sampler_without_tail_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
     logfire.configure(
         **config_kwargs,
-        tail_sampling=logfire.TailSamplingOptions(),
-        trace_sample_rate=0.3,
+        sampling=logfire.SamplingOptions(head=TestSampler()),
     )

-    # <0.3 a third of the time.
-    monkeypatch.setattr('random.random', Mock(side_effect=itertools.cycle([0.1, 0.6, 0.9])))
-    # These spans should all be included because the level is above the default.
-    for _ in range(10):
+    # These spans should all be excluded by the head sampler.
+    for _ in range(100):
+        logfire.error('exclude')
+    assert len(exporter.exported_spans) == 0
+
+    for _ in range(100):
         logfire.error('error')
-    assert len(exporter.exported_spans) == 10
+    assert len(exporter.exported_spans) == 100
+
+
+def test_raw_head_sampler_with_tail_sampling(config_kwargs: dict[str, Any], exporter: TestExporter):
+    config_kwargs.update(
+        sampling=logfire.SamplingOptions.level_or_duration(head=TestSampler(), background_rate=0.3),
+        id_generator=SeededRandomIdGenerator(seed=1),
+    )
+    logfire.configure(**config_kwargs)

-    # A third of these spans (i.e. an extra 10) should be included because of the trace_sample_rate.
+    # These spans should all be excluded by the head sampler,
+    # so it doesn't matter that they have a high level.
+    for _ in range(100):
+        logfire.error('exclude')
+    assert len(exporter.exported_spans) == 0
+
+    # These spans should all be included because the level is above the default,
+    # and the head sampler doesn't exclude them.
+    for _ in range(100):
+        logfire.error('error')
+    assert len(exporter.exported_spans) == 100
+
+    # About 30% of these spans (i.e. an extra ~300) should be included because of the background_rate.
     # None of them meet the tail sampling criteria.
- for _ in range(30): + for _ in range(1000): logfire.info('info') - assert len(exporter.exported_spans) == 20 + assert len(exporter.exported_spans) == 100 + 315 + + +def test_custom_head_and_tail(config_kwargs: dict[str, Any], exporter: TestExporter): + span_counts = {'start': 0, 'end': 0} + + def get_tail_sample_rate(span_info: TailSamplingSpanInfo) -> float: + span_counts[span_info.event] += 1 + if span_info.duration >= 1: + return 0.5 + if span_info.level > 'warn': + return 0.3 + return 0.1 + + config_kwargs.update( + sampling=logfire.SamplingOptions( + head=0.7, + tail=get_tail_sample_rate, + ), + id_generator=SeededRandomIdGenerator(seed=3), + ) + + logfire.configure(**config_kwargs) + + for _ in range(1000): + logfire.warn('warn') + assert span_counts == snapshot({'start': 720, 'end': 617}) + assert len(exporter.exported_spans) == snapshot(103) + assert span_counts['end'] + len(exporter.exported_spans) == span_counts['start'] + + exporter.clear() + for _ in range(1000): + with logfire.span('span'): + pass + assert len(exporter.exported_spans_as_dict()) == snapshot(506) + + exporter.clear() + for _ in range(1000): + logfire.error('error') + assert len(exporter.exported_spans) == snapshot(291) + + +def test_span_levels(): + warn = SpanLevel(LEVEL_NUMBERS['warn']) + + assert 'warn' <= warn <= 'warn' + assert 'debug' <= warn <= 'error' + assert 'info' < warn < 'fatal' + assert not (warn < 'warn') + assert not (warn > 'warn') + assert not (warn > 'fatal') + assert not (warn >= 'fatal') + assert not (warn < 'debug') + assert not (warn <= 'debug') + + assert warn == 'warn' + assert warn != 'error' + assert not (warn != 'warn') + assert not (warn == 'error') + assert warn == warn.number + assert warn != 123 + assert warn == warn + assert not (warn != warn) + assert warn != SpanLevel(LEVEL_NUMBERS['error']) + assert not (warn == SpanLevel(LEVEL_NUMBERS['error'])) + assert warn == SpanLevel(LEVEL_NUMBERS['warn']) + assert not (warn != SpanLevel(LEVEL_NUMBERS['warn'])) + + assert warn.number == LEVEL_NUMBERS['warn'] == 13 + assert warn.name == 'warn' + + assert warn != 'foo' + assert warn != [1, 2, 3] + assert not (warn == [1, 2, 3]) + assert not ([1, 2, 3] == warn) + + assert {warn, SpanLevel(LEVEL_NUMBERS['warn'])} == {warn} + + +def test_invalid_rates(): + with inline_snapshot.extra.raises( + snapshot('ValueError: Invalid sampling rates, ' 'must be 0.0 <= background_rate <= head <= 1.0') + ): + logfire.SamplingOptions.level_or_duration(background_rate=-1) + with pytest.raises(ValueError): + logfire.SamplingOptions.level_or_duration(background_rate=0.5, head=0.3) + with pytest.raises(ValueError): + logfire.SamplingOptions.level_or_duration(head=2) + + +def test_trace_sample_rate(config_kwargs: dict[str, Any]): + with pytest.warns(UserWarning) as warnings: + logfire.configure(trace_sample_rate=0.123, **config_kwargs) # type: ignore + assert logfire.DEFAULT_LOGFIRE_INSTANCE.config.sampling.head == 0.123 + assert len(warnings) == 1 + assert str(warnings[0].message) == snapshot( + 'The `trace_sample_rate` argument is deprecated. ' 'Use `sampling=logfire.SamplingOptions(head=...)` instead.' + ) + + +def test_both_trace_and_head(): + with inline_snapshot.extra.raises( + snapshot( + 'ValueError: Cannot specify both `trace_sample_rate` and `sampling`. ' + 'Use `sampling.head` instead of `trace_sample_rate`.' + ) + ): + logfire.configure(trace_sample_rate=0.5, sampling=logfire.SamplingOptions()) # type: ignore
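
To make the new API concrete, here is a usage sketch of the `tail` hook, mirroring `get_tail_sample_rate` from `test_custom_head_and_tail` above. The specific rates and the `'slow operation'` span are illustrative, not prescriptive:

```python
import time

import logfire
from logfire.sampling import TailSamplingSpanInfo


def get_tail_sample_rate(span_info: TailSamplingSpanInfo) -> float:
    # Keep half of the traces containing a span that ran for at least 1 second.
    if span_info.duration >= 1:
        return 0.5
    # Keep 30% of traces containing something above 'warn', e.g. an error.
    if span_info.level > 'warn':
        return 0.3
    # Keep 10% of everything else.
    return 0.1


logfire.configure(
    sampling=logfire.SamplingOptions(
        head=0.7,  # discard 30% of traces up front, before any buffering
        tail=get_tail_sample_rate,
    )
)

with logfire.span('slow operation'):
    time.sleep(1.1)  # long enough to hit the duration branch above
```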
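The `check_trace_id_ratio` comparison used by the processor can also be exercised standalone. A minimal sketch, assuming only OTel's `TraceIdRatioBased` constants (the trace ID below is arbitrary): because the bound grows monotonically with the rate, a trace kept at a low rate is kept at every higher rate, which is what lets the head and background rates act as direct probabilities without further calculation.

```python
from opentelemetry.sdk.trace.sampling import TraceIdRatioBased


def check_trace_id_ratio(trace_id: int, rate: float) -> bool:
    # Deterministic: the same trace ID always gives the same answer for a given rate.
    return (trace_id & TraceIdRatioBased.TRACE_ID_LIMIT) < TraceIdRatioBased.get_bound_for_rate(rate)


trace_id = 0x6E3F2A1B9C8D7E5F4A3B2C1D0E9F8A7B  # arbitrary example ID
assert check_trace_id_ratio(trace_id, 1.0)  # a rate of 1.0 keeps every trace
if check_trace_id_ratio(trace_id, 0.1):
    # Monotonic: kept at rate 0.1 implies kept at any higher rate.
    assert check_trace_id_ratio(trace_id, 0.5)
```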
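Finally, a condensed view of the `SpanLevel` semantics pinned down by `test_span_levels`: comparisons are by severity number rather than alphabetically, and level names can appear on either side of the operator.

```python
from logfire._internal.constants import LEVEL_NUMBERS
from logfire.sampling import SpanLevel

warn = SpanLevel(LEVEL_NUMBERS['warn'])
assert warn.name == 'warn' and warn.number == 13
assert 'info' < warn < 'error'  # ordered by severity, not alphabetically
assert warn == 'warn' and warn != 'error'
```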