diff --git a/CHANGELOG.md b/CHANGELOG.md index 17eec0a6b28..4c0a919e386 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.0.0...HEAD) +### Added - Added `py.typed` file to every package. This should resolve a bunch of mypy errors for users. ([#1720](https://github.com/open-telemetry/opentelemetry-python/pull/1720)) +### Changed +- Adjust `B3Format` propagator to be spec compliant by not modifying context + when propagation headers are not present/invalid/empty + ([#1728](https://github.com/open-telemetry/opentelemetry-python/pull/1728)) + ## [1.0.0](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.0.0) - 2021-03-26 ### Added diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index be478b05ec0..df8803d6171 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -50,8 +50,8 @@ def extract( context: typing.Optional[Context] = None, getter: Getter = default_getter, ) -> Context: - trace_id = format_trace_id(trace.INVALID_TRACE_ID) - span_id = format_span_id(trace.INVALID_SPAN_ID) + trace_id = trace.INVALID_TRACE_ID + span_id = trace.INVALID_SPAN_ID sampled = "0" flags = None @@ -73,8 +73,6 @@ def extract( trace_id, span_id, sampled = fields elif len(fields) == 4: trace_id, span_id, sampled, _ = fields - else: - return trace.set_span_in_context(trace.INVALID_SPAN) else: trace_id = ( _extract_first_element(getter.get(carrier, self.TRACE_ID_KEY)) @@ -94,18 +92,15 @@ def extract( ) if ( - self._trace_id_regex.fullmatch(trace_id) is None + trace_id == trace.INVALID_TRACE_ID + or span_id == trace.INVALID_SPAN_ID + or self._trace_id_regex.fullmatch(trace_id) is None or self._span_id_regex.fullmatch(span_id) is None ): - id_generator = trace.get_tracer_provider().id_generator - trace_id = id_generator.generate_trace_id() - span_id = id_generator.generate_span_id() - sampled = "0" - - else: - trace_id = int(trace_id, 16) - span_id = int(span_id, 16) + return context + trace_id = int(trace_id, 16) + span_id = int(span_id, 16) options = 0 # The b3 spec provides no defined behavior for both sample and # flag values set. Since the setting of at least one implies diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index d1d96a269f0..fd0a9a4029a 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock import opentelemetry.propagators.b3 as b3_format # pylint: disable=no-name-in-module,import-error import opentelemetry.sdk.trace as trace @@ -231,89 +231,73 @@ def test_64bit_trace_id(self): new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit ) - def test_invalid_single_header(self): - """If an invalid single header is passed, return an - invalid SpanContext. - """ + def test_extract_invalid_single_header(self): + """Given unparsable header, do not modify context""" + old_ctx = {} + carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - ctx = FORMAT.extract(carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) + new_ctx = FORMAT.extract(carrier, old_ctx) + + self.assertDictEqual(new_ctx, old_ctx) + + def test_extract_missing_trace_id(self): + """Given no trace ID, do not modify context""" + old_ctx = {} - def test_missing_trace_id(self): - """If a trace id is missing, populate an invalid trace id.""" carrier = { FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } + new_ctx = FORMAT.extract(carrier, old_ctx) - ctx = FORMAT.extract(carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) - - @patch( - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_trace_id" - ) - @patch( - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_span_id" - ) - def test_invalid_trace_id( - self, mock_generate_span_id, mock_generate_trace_id - ): - """If a trace id is invalid, generate a trace id.""" + self.assertDictEqual(new_ctx, old_ctx) - mock_generate_trace_id.configure_mock(return_value=1) - mock_generate_span_id.configure_mock(return_value=2) + def test_extract_invalid_trace_id(self): + """Given invalid trace ID, do not modify context""" + old_ctx = {} carrier = { FORMAT.TRACE_ID_KEY: "abc123", FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } + new_ctx = FORMAT.extract(carrier, old_ctx) - ctx = FORMAT.extract(carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() + self.assertDictEqual(new_ctx, old_ctx) - self.assertEqual(span_context.trace_id, 1) - self.assertEqual(span_context.span_id, 2) - - @patch( - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_trace_id" - ) - @patch( - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_span_id" - ) - def test_invalid_span_id( - self, mock_generate_span_id, mock_generate_trace_id - ): - """If a span id is invalid, generate a trace id.""" - - mock_generate_trace_id.configure_mock(return_value=1) - mock_generate_span_id.configure_mock(return_value=2) + def test_extract_invalid_span_id(self): + """Given invalid span ID, do not modify context""" + old_ctx = {} carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.SPAN_ID_KEY: "abc123", FORMAT.FLAGS_KEY: "1", } + new_ctx = FORMAT.extract(carrier, old_ctx) - ctx = FORMAT.extract(carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() + self.assertDictEqual(new_ctx, old_ctx) - self.assertEqual(span_context.trace_id, 1) - self.assertEqual(span_context.span_id, 2) + def test_extract_missing_span_id(self): + """Given no span ID, do not modify context""" + old_ctx = {} - def test_missing_span_id(self): - """If a trace id is missing, populate an invalid trace id.""" carrier = { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.FLAGS_KEY: "1", } + new_ctx = FORMAT.extract(carrier, old_ctx) + + self.assertDictEqual(new_ctx, old_ctx) + + def test_extract_empty_carrier(self): + """Given no headers at all, do not modify context""" + old_ctx = {} + + carrier = {} + new_ctx = FORMAT.extract(carrier, old_ctx) - ctx = FORMAT.extract(carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) + self.assertDictEqual(new_ctx, old_ctx) @staticmethod def test_inject_empty_context():