diff --git a/CHANGELOG.md b/CHANGELOG.md index b22da9c9b95..6dda7ba4153 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#1721](https://github.com/open-telemetry/opentelemetry-python/pull/1721)) - Update bootstrap cmd to use exact version when installing instrumentation packages. ([#1722](https://github.com/open-telemetry/opentelemetry-python/pull/1722)) +- Fix B3 propagator to never return None. + ([#1750](https://github.com/open-telemetry/opentelemetry-python/pull/1750)) ## [1.0.0](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.0.0) - 2021-03-26 diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index df8803d6171..2d50ea88dc5 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -97,6 +97,8 @@ def extract( or self._trace_id_regex.fullmatch(trace_id) is None or self._span_id_regex.fullmatch(span_id) is None ): + if context is None: + return trace.set_span_in_context(trace.INVALID_SPAN, context) return context trace_id = int(trace_id, 16) @@ -119,7 +121,8 @@ def extract( trace_flags=trace.TraceFlags(options), trace_state=trace.TraceState(), ) - ) + ), + context, ) def inject( diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index fd0a9a4029a..6ee0be2ce1c 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -51,6 +51,8 @@ def get_child_parent_new_carrier(old_carrier): class TestB3Format(unittest.TestCase): + # pylint: disable=too-many-public-methods + @classmethod def setUpClass(cls): generator = id_generator.RandomIdGenerator() @@ -215,6 +217,31 @@ def test_flags_and_sampling(self): self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + def test_derived_ctx_is_returned_for_success(self): + """Ensure returned context is derived from the given context.""" + old_ctx = {"k1": "v1"} + new_ctx = FORMAT.extract( + { + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: self.serialized_span_id, + FORMAT.FLAGS_KEY: "1", + }, + old_ctx, + ) + self.assertIn("current-span", new_ctx) + for key, value in old_ctx.items(): + self.assertIn(key, new_ctx) + self.assertEqual(new_ctx[key], value) + + def test_derived_ctx_is_returned_for_failure(self): + """Ensure returned context is derived from the given context.""" + old_ctx = {"k2": "v2"} + new_ctx = FORMAT.extract({}, old_ctx) + self.assertNotIn("current-span", new_ctx) + for key, value in old_ctx.items(): + self.assertIn(key, new_ctx) + self.assertEqual(new_ctx[key], value) + def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" trace_id_64_bit = self.serialized_trace_id[:16] @@ -334,3 +361,12 @@ def test_fields(self): inject_fields.add(call[1][1]) self.assertEqual(FORMAT.fields, inject_fields) + + def test_extract_none_context(self): + """Given no trace ID, do not modify context""" + old_ctx = None + + carrier = {} + new_ctx = FORMAT.extract(carrier, old_ctx) + self.assertIsNotNone(new_ctx) + self.assertEqual(new_ctx["current-span"], trace_api.INVALID_SPAN)