diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a4e061bb62..16839f2ccb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - unit annotations (enclosed in curly braces like `{requests}`) are stripped away - units with slash are converted e.g. `m/s` -> `meters_per_second`. - The exporter's API is not changed +- Fix RandomIdGenerator can generate invalid Span/Trace Ids + ([#3921](https://github.com/open-telemetry/opentelemetry-python/issues/3921)) ## Version 1.24.0/0.45b0 (2024-03-28) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py index 62b12a94921..63ff1fb2483 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py @@ -15,6 +15,8 @@ import abc import random +from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID + class IdGenerator(abc.ABC): @abc.abstractmethod @@ -46,7 +48,13 @@ class RandomIdGenerator(IdGenerator): """ def generate_span_id(self) -> int: - return random.getrandbits(64) + span_id = random.getrandbits(64) + while span_id == INVALID_SPAN_ID: + span_id = random.getrandbits(64) + return span_id def generate_trace_id(self) -> int: - return random.getrandbits(128) + trace_id = random.getrandbits(128) + while trace_id == INVALID_TRACE_ID: + trace_id = random.getrandbits(128) + return trace_id diff --git a/opentelemetry-sdk/tests/trace/test_trace.py b/opentelemetry-sdk/tests/trace/test_trace.py index 30f4f0e2731..05f6ad56843 100644 --- a/opentelemetry-sdk/tests/trace/test_trace.py +++ b/opentelemetry-sdk/tests/trace/test_trace.py @@ -62,6 +62,7 @@ StatusCode, get_tracer, set_tracer_provider, + span, ) @@ -2061,3 +2062,28 @@ def test_tracer_provider_init_default(self, resource_patch, sample_patch): sample_patch.assert_called_once() self.assertIsNotNone(tracer_provider._span_limits) self.assertIsNotNone(tracer_provider._atexit_handler) + + +class TestRandomIdGenerator(unittest.TestCase): + _TRACE_ID_MAX_VALUE = 2 ** 128 - 1 + _SPAN_ID_MAX_VALUE = 2 ** 64 - 1 + + @patch('random.getrandbits', side_effect=[span.INVALID_SPAN_ID, 0x00000000DEADBEF0]) + def test_generate_span_id_avoids_invalid(self, mock_getrandbits): + generator = RandomIdGenerator() + span_id = generator.generate_span_id() + + self.assertNotEqual(span_id, span.INVALID_SPAN_ID) + self.assertGreater(span_id, span.INVALID_SPAN_ID) + self.assertLessEqual(span_id, self._SPAN_ID_MAX_VALUE) + self.assertEqual(mock_getrandbits.call_count, 2) # Ensure exactly two calls + + @patch('random.getrandbits', side_effect=[span.INVALID_TRACE_ID, 0x000000000000000000000000DEADBEEF]) + def test_generate_trace_id_avoids_invalid(self, mock_getrandbits): + generator = RandomIdGenerator() + trace_id = generator.generate_trace_id() + + self.assertNotEqual(trace_id, span.INVALID_TRACE_ID) + self.assertGreater(trace_id, span.INVALID_TRACE_ID) + self.assertLessEqual(trace_id, self._TRACE_ID_MAX_VALUE) + self.assertEqual(mock_getrandbits.call_count, 2) # Ensure exactly two calls