Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zhihali committed Jun 4, 2024
1 parent d73593d commit c05ebc1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- unit annotations (enclosed in curly braces like `{requests}`) are stripped away
- units with slash are converted e.g. `m/s` -> `meters_per_second`.
- The exporter's API is not changed
- Fix RandomIdGenerator can generate invalid Span/Trace Ids
([#3921](https://github.com/open-telemetry/opentelemetry-python/issues/3921))

## Version 1.24.0/0.45b0 (2024-03-28)

Expand Down
12 changes: 10 additions & 2 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import abc
import random

from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID


class IdGenerator(abc.ABC):
@abc.abstractmethod
Expand Down Expand Up @@ -46,7 +48,13 @@ class RandomIdGenerator(IdGenerator):
"""

def generate_span_id(self) -> int:
return random.getrandbits(64)
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id

def generate_trace_id(self) -> int:
return random.getrandbits(128)
trace_id = random.getrandbits(128)
while trace_id == INVALID_TRACE_ID:
trace_id = random.getrandbits(128)
return trace_id
26 changes: 26 additions & 0 deletions opentelemetry-sdk/tests/trace/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
StatusCode,
get_tracer,
set_tracer_provider,
span,
)


Expand Down Expand Up @@ -2061,3 +2062,28 @@ def test_tracer_provider_init_default(self, resource_patch, sample_patch):
sample_patch.assert_called_once()
self.assertIsNotNone(tracer_provider._span_limits)
self.assertIsNotNone(tracer_provider._atexit_handler)


class TestRandomIdGenerator(unittest.TestCase):
_TRACE_ID_MAX_VALUE = 2 ** 128 - 1
_SPAN_ID_MAX_VALUE = 2 ** 64 - 1

@patch('random.getrandbits', side_effect=[span.INVALID_SPAN_ID, 0x00000000DEADBEF0])
def test_generate_span_id_avoids_invalid(self, mock_getrandbits):
generator = RandomIdGenerator()
span_id = generator.generate_span_id()

self.assertNotEqual(span_id, span.INVALID_SPAN_ID)
self.assertGreater(span_id, span.INVALID_SPAN_ID)
self.assertLessEqual(span_id, self._SPAN_ID_MAX_VALUE)
self.assertEqual(mock_getrandbits.call_count, 2) # Ensure exactly two calls

@patch('random.getrandbits', side_effect=[span.INVALID_TRACE_ID, 0x000000000000000000000000DEADBEEF])
def test_generate_trace_id_avoids_invalid(self, mock_getrandbits):
generator = RandomIdGenerator()
trace_id = generator.generate_trace_id()

self.assertNotEqual(trace_id, span.INVALID_TRACE_ID)
self.assertGreater(trace_id, span.INVALID_TRACE_ID)
self.assertLessEqual(trace_id, self._TRACE_ID_MAX_VALUE)
self.assertEqual(mock_getrandbits.call_count, 2) # Ensure exactly two calls

0 comments on commit c05ebc1

Please sign in to comment.