Skip to content

Commit

Permalink
Remove double record exception (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu authored Jan 3, 2025
1 parent 80779f3 commit 1382f17
Showing 3 changed files with 65 additions and 38 deletions.
47 changes: 10 additions & 37 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
@@ -2109,8 +2109,7 @@ def __enter__(self) -> FastLogfireSpan:
@handle_internal_errors()
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
context_api.detach(self._token)
_exit_span(self._span, exc_value)
self._span.end()
self._span.__exit__(exc_type, exc_value, traceback)


# Changes to this class may need to be reflected in `FastLogfireSpan` and `NoopSpan` as well.
@@ -2146,6 +2145,7 @@ def __enter__(self) -> LogfireSpan:
attributes=self._otlp_attributes,
links=self._links,
)
self._span.__enter__()
if self._token is None: # pragma: no branch
self._token = context_api.attach(trace_api.set_span_in_context(self._span))

@@ -2155,14 +2155,17 @@ def __enter__(self) -> LogfireSpan:
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
if self._token is None: # pragma: no cover
return
assert self._span is not None

context_api.detach(self._token)
self._token = None

assert self._span is not None
_exit_span(self._span, exc_value)

self.end()
if self._span.is_recording():
with handle_internal_errors():
if self._added_attributes:
self._span.set_attribute(
ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties)
)
self._span.__exit__(exc_type, exc_value, traceback)

@property
def message_template(self) -> str | None: # pragma: no cover
@@ -2188,26 +2191,6 @@ def message(self) -> str:
def message(self, message: str):
self._set_attribute(ATTRIBUTES_MESSAGE_KEY, message)

def end(self, end_time: int | None = None) -> None:
"""Sets the current time as the span's end time.
The span's end time is the wall time at which the operation finished.
Only the first call to this method is recorded, further calls are ignored so you
can call this within the span's context manager to end it before the context manager
exits.
"""
if self._span is None: # pragma: no cover
raise RuntimeError('Span has not been started')
if self._span.is_recording():
with handle_internal_errors():
if self._added_attributes:
self._span.set_attribute(
ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties)
)

self._span.end(end_time)

@handle_internal_errors()
def set_attribute(self, key: str, value: Any) -> None:
"""Sets an attribute on the span.
@@ -2339,16 +2322,6 @@ def is_recording(self) -> bool:
return False


def _exit_span(span: trace_api.Span, exception: BaseException | None) -> None:
if not span.is_recording():
return

# record exception if present
# isinstance is to ignore BaseException
if isinstance(exception, Exception):
record_exception(span, exception, escaped=True)


AttributesValueType = TypeVar('AttributesValueType', bound=Union[Any, otel_types.AttributeValue])


14 changes: 13 additions & 1 deletion logfire/_internal/tracer.py
Original file line number Diff line number Diff line change
@@ -116,7 +116,13 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:

@dataclass(eq=False)
class _LogfireWrappedSpan(trace_api.Span, ReadableSpan):
"""Span that overrides end() to use a timestamp generator if one was provided."""
"""A span that wraps another span and overrides some behaviors in a logfire-specific way.
In particular:
* Stores a reference to itself in `OPEN_SPANS`, used to close open spans when the program exits
* Adds some logfire-specific tweaks to the exception recording behavior
* Overrides end() to use a timestamp generator if one was provided
"""

span: Span
ns_timestamp_generator: Callable[[], int]
@@ -171,6 +177,12 @@ def record_exception(
timestamp = timestamp or self.ns_timestamp_generator()
record_exception(self.span, exception, attributes=attributes, timestamp=timestamp, escaped=escaped)

def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
if self.is_recording():
if isinstance(exc_value, BaseException):
self.record_exception(exc_value, escaped=True)
self.end()

if not TYPE_CHECKING: # pragma: no branch
# for ReadableSpan
def __getattr__(self, name: str) -> Any:
42 changes: 42 additions & 0 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from functools import partial
from logging import getLogger
from typing import Any, Callable
from unittest.mock import patch

import pytest
from dirty_equals import IsInt, IsJson, IsStr
@@ -36,6 +37,7 @@
)
from logfire._internal.formatter import FormattingFailedWarning, InspectArgumentsFailedWarning
from logfire._internal.main import NoopSpan
from logfire._internal.tracer import record_exception
from logfire._internal.utils import is_instrumentation_suppressed
from logfire.integrations.logging import LogfireLoggingHandler
from logfire.testing import TestExporter
@@ -3171,3 +3173,43 @@ def test_suppress_scopes(exporter: TestExporter, metrics_reader: InMemoryMetricR
}
]
)


def test_logfire_span_records_exceptions_once():
n_calls_to_record_exception = 0

def patched_record_exception(*args: Any, **kwargs: Any) -> Any:
nonlocal n_calls_to_record_exception
n_calls_to_record_exception += 1

return record_exception(*args, **kwargs)

with patch('logfire._internal.tracer.record_exception', patched_record_exception), patch(
'logfire._internal.main.record_exception', patched_record_exception
):
with pytest.raises(RuntimeError):
with logfire.span('foo'):
raise RuntimeError('error')

assert n_calls_to_record_exception == 1


def test_exit_ended_span(exporter: TestExporter):
tracer = get_tracer(__name__)

with tracer.start_span('test') as span:
# Ensure that doing this does not emit a warning about calling end() twice when the block exits.
span.end()

assert exporter.exported_spans_as_dict(_strip_function_qualname=False) == snapshot(
[
{
'attributes': {'logfire.msg': 'test', 'logfire.span_type': 'span'},
'context': {'is_remote': False, 'span_id': 1, 'trace_id': 1},
'end_time': 2000000000,
'name': 'test',
'parent': None,
'start_time': 1000000000,
}
]
)

0 comments on commit 1382f17

Please sign in to comment.