From 25edfefd70f2d2fd8f6c6a90293d598a358fbb96 Mon Sep 17 00:00:00 2001 From: Diego Hurtado Date: Fri, 26 Mar 2021 11:20:20 -0600 Subject: [PATCH] Make setters and getters optional (#1690) Co-authored-by: alrex --- .github/workflows/test.yml | 2 +- CHANGELOG.md | 2 + docs/conf.py | 2 +- docs/examples/auto-instrumentation/README.rst | 3 +- docs/examples/auto-instrumentation/client.py | 2 +- .../server_instrumented.py | 3 +- docs/examples/datadog_exporter/client.py | 2 +- docs/examples/django/client.py | 2 +- .../baggage/propagation/__init__.py | 14 ++-- .../src/opentelemetry/propagate/__init__.py | 20 ++--- .../opentelemetry/propagators/composite.py | 12 +-- .../src/opentelemetry/propagators/textmap.py | 82 ++++++++++++++----- .../trace/propagation/tracecontext.py | 14 ++-- .../tests/baggage/test_baggage_propagation.py | 17 ++-- .../tests/propagators/test_composite.py | 30 +++---- .../propagators/test_global_httptextformat.py | 7 +- .../tests/trace/propagation/test_textmap.py | 12 +-- .../test_tracecontexthttptextformat.py | 24 ++---- .../opentelemetry/propagators/b3/__init__.py | 24 +++--- .../propagation/test_benchmark_b3_format.py | 3 - .../tests/test_b3_format.py | 36 ++++---- .../propagators/jaeger/__init__.py | 22 ++--- .../tests/test_jaeger_propagator.py | 28 +++---- .../shim/opentracing_shim/__init__.py | 6 +- .../src/opentelemetry/test/mock_textmap.py | 24 +++--- 25 files changed, 201 insertions(+), 192 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 908cf71caa3..34d58b2cdb7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ env: # Otherwise, set variable to the commit of your branch on # opentelemetry-python-contrib which is compatible with these Core repo # changes. - CONTRIB_REPO_SHA: 5bc0fa1611502be47a1f4eb550fe255e4b707ba1 + CONTRIB_REPO_SHA: 0d12fa39523212e268ef435825af2039a876fd75 jobs: build: diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a8e24366d1..4e7cbc03257 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v0.18b0...HEAD) +- Make setters and getters optional + ([#1690](https://github.com/open-telemetry/opentelemetry-python/pull/1690)) ### Added - Document how to work with fork process web server models(Gunicorn, uWSGI etc...) diff --git a/docs/conf.py b/docs/conf.py index d23cebfe96c..61dc1a82540 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -104,7 +104,7 @@ ("py:class", "opentelemetry.trace._LinkBase",), # TODO: Understand why sphinx is not able to find this local class ("py:class", "opentelemetry.propagators.textmap.TextMapPropagator",), - ("py:class", "opentelemetry.propagators.textmap.DictGetter",), + ("py:class", "opentelemetry.propagators.textmap.DefaultGetter",), ("any", "opentelemetry.propagators.textmap.TextMapPropagator.extract",), ("any", "opentelemetry.propagators.textmap.TextMapPropagator.inject",), ] diff --git a/docs/examples/auto-instrumentation/README.rst b/docs/examples/auto-instrumentation/README.rst index 607aa1b44b7..48c59951b4b 100644 --- a/docs/examples/auto-instrumentation/README.rst +++ b/docs/examples/auto-instrumentation/README.rst @@ -37,8 +37,7 @@ Manually instrumented server def server_request(): with tracer.start_as_current_span( "server_request", - context=propagators.extract(DictGetter(), request.headers - ), + context=propagators.extract(request.headers), ): print(request.args.get("param")) return "served" diff --git a/docs/examples/auto-instrumentation/client.py b/docs/examples/auto-instrumentation/client.py index fefc1f67b98..cc948cc54b8 100644 --- a/docs/examples/auto-instrumentation/client.py +++ b/docs/examples/auto-instrumentation/client.py @@ -37,7 +37,7 @@ with tracer.start_as_current_span("client-server"): headers = {} - propagators.inject(dict.__setitem__, headers) + propagators.inject(headers) requested = get( "http://localhost:8082/server_request", params={"param": argv[1]}, diff --git a/docs/examples/auto-instrumentation/server_instrumented.py b/docs/examples/auto-instrumentation/server_instrumented.py index 1ac1bd6b71b..652358e3a2e 100644 --- a/docs/examples/auto-instrumentation/server_instrumented.py +++ b/docs/examples/auto-instrumentation/server_instrumented.py @@ -17,7 +17,6 @@ from opentelemetry import trace from opentelemetry.instrumentation.wsgi import collect_request_attributes from opentelemetry.propagate import extract -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ( ConsoleSpanExporter, @@ -38,7 +37,7 @@ def server_request(): with tracer.start_as_current_span( "server_request", - context=extract(DictGetter(), request.headers), + context=extract(request.headers), kind=trace.SpanKind.SERVER, attributes=collect_request_attributes(request.environ), ): diff --git a/docs/examples/datadog_exporter/client.py b/docs/examples/datadog_exporter/client.py index 6b4b5d00ec1..7c6196ad4ab 100644 --- a/docs/examples/datadog_exporter/client.py +++ b/docs/examples/datadog_exporter/client.py @@ -47,7 +47,7 @@ with tracer.start_as_current_span("client-server"): headers = {} - inject(dict.__setitem__, headers) + inject(headers) requested = get( "http://localhost:8082/server_request", params={"param": argv[1]}, diff --git a/docs/examples/django/client.py b/docs/examples/django/client.py index bc3606cbe76..3ae0cb6e1cf 100644 --- a/docs/examples/django/client.py +++ b/docs/examples/django/client.py @@ -36,7 +36,7 @@ with tracer.start_as_current_span("client-server"): headers = {} - inject(dict.__setitem__, headers) + inject(headers) requested = get( "http://localhost:8000", params={"param": argv[1]}, diff --git a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py index e6d1c4207bc..04d896baa36 100644 --- a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py @@ -31,9 +31,9 @@ class W3CBaggagePropagator(textmap.TextMapPropagator): def extract( self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + getter: textmap.Getter = textmap.default_getter, ) -> Context: """Extract Baggage from the carrier. @@ -73,9 +73,9 @@ def extract( def inject( self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + setter: textmap.Setter = textmap.default_setter, ) -> None: """Injects Baggage into the carrier. @@ -87,7 +87,7 @@ def inject( return baggage_string = _format_baggage(baggage_entries) - set_in_carrier(carrier, self._BAGGAGE_HEADER_NAME, baggage_string) + setter.set(carrier, self._BAGGAGE_HEADER_NAME, baggage_string) @property def fields(self) -> typing.Set[str]: @@ -103,8 +103,8 @@ def _format_baggage(baggage_entries: typing.Mapping[str, object]) -> str: def _extract_first_element( - items: typing.Optional[typing.Iterable[textmap.TextMapPropagatorT]], -) -> typing.Optional[textmap.TextMapPropagatorT]: + items: typing.Optional[typing.Iterable[textmap.CarrierT]], +) -> typing.Optional[textmap.CarrierT]: if items is None: return None return next(iter(items), None) diff --git a/opentelemetry-api/src/opentelemetry/propagate/__init__.py b/opentelemetry-api/src/opentelemetry/propagate/__init__.py index 44f9897a532..091b5b8d44e 100644 --- a/opentelemetry-api/src/opentelemetry/propagate/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagate/__init__.py @@ -82,9 +82,9 @@ def example_route(): def extract( - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + getter: textmap.Getter = textmap.default_getter, ) -> Context: """Uses the configured propagator to extract a Context from the carrier. @@ -99,26 +99,26 @@ def extract( context: an optional Context to use. Defaults to current context if not set. """ - return get_global_textmap().extract(getter, carrier, context) + return get_global_textmap().extract(carrier, context, getter=getter) def inject( - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + setter: textmap.Setter = textmap.default_setter, ) -> None: """Uses the configured propagator to inject a Context into the carrier. Args: - set_in_carrier: A setter function that can set values - on the carrier. carrier: An object that contains a representation of HTTP - headers. Should be paired with set_in_carrier, which + headers. Should be paired with setter, which should know how to set header values on the carrier. - context: an optional Context to use. Defaults to current + context: An optional Context to use. Defaults to current context if not set. + setter: An optional `Setter` object that can set values + on the carrier. """ - get_global_textmap().inject(set_in_carrier, carrier, context) + get_global_textmap().inject(carrier, context=context, setter=setter) try: diff --git a/opentelemetry-api/src/opentelemetry/propagators/composite.py b/opentelemetry-api/src/opentelemetry/propagators/composite.py index 92dc6b8a380..c027f638dcd 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/composite.py +++ b/opentelemetry-api/src/opentelemetry/propagators/composite.py @@ -35,9 +35,9 @@ def __init__( def extract( self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + getter: textmap.Getter = textmap.default_getter, ) -> Context: """Run each of the configured propagators with the given context and carrier. Propagators are run in the order they are configured, if multiple @@ -47,14 +47,14 @@ def extract( See `opentelemetry.propagators.textmap.TextMapPropagator.extract` """ for propagator in self._propagators: - context = propagator.extract(getter, carrier, context) + context = propagator.extract(carrier, context, getter=getter) return context # type: ignore def inject( self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + setter: textmap.Setter = textmap.default_setter, ) -> None: """Run each of the configured propagators with the given context and carrier. Propagators are run in the order they are configured, if multiple @@ -64,7 +64,7 @@ def inject( See `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ for propagator in self._propagators: - propagator.inject(set_in_carrier, carrier, context) + propagator.inject(carrier, context, setter=setter) @property def fields(self) -> typing.Set[str]: diff --git a/opentelemetry-api/src/opentelemetry/propagators/textmap.py b/opentelemetry-api/src/opentelemetry/propagators/textmap.py index cf93d1d6319..45c2308f661 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/textmap.py +++ b/opentelemetry-api/src/opentelemetry/propagators/textmap.py @@ -17,19 +17,18 @@ from opentelemetry.context.context import Context -TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT") +CarrierT = typing.TypeVar("CarrierT") CarrierValT = typing.Union[typing.List[str], str] -Setter = typing.Callable[[TextMapPropagatorT, str, str], None] - -class Getter(typing.Generic[TextMapPropagatorT]): +class Getter(abc.ABC): """This class implements a Getter that enables extracting propagated fields from a carrier. """ + @abc.abstractmethod def get( - self, carrier: TextMapPropagatorT, key: str + self, carrier: CarrierT, key: str ) -> typing.Optional[typing.List[str]]: """Function that can retrieve zero or more values from the carrier. In the case that @@ -42,9 +41,9 @@ def get( Returns: first value of the propagation key or None if the key doesn't exist. """ - raise NotImplementedError() - def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: + @abc.abstractmethod + def keys(self, carrier: CarrierT) -> typing.List[str]: """Function that can retrieve all the keys in a carrier object. Args: @@ -53,17 +52,33 @@ def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: Returns: list of keys from the carrier. """ - raise NotImplementedError() -class DictGetter(Getter[typing.Dict[str, CarrierValT]]): - def get( - self, carrier: typing.Dict[str, CarrierValT], key: str +class Setter(abc.ABC): + """This class implements a Setter that enables injecting propagated + fields into a carrier. + """ + + @abc.abstractmethod + def set(self, carrier: CarrierT, key: str, value: str) -> None: + """Function that can set a value into a carrier"" + + Args: + carrier: An object which contains values that are used to + construct a Context. + key: key of a field in carrier. + value: value for a field in carrier. + """ + + +class DefaultGetter(Getter): + def get( # type: ignore + self, carrier: typing.Mapping[str, CarrierValT], key: str ) -> typing.Optional[typing.List[str]]: """Getter implementation to retrieve a value from a dictionary. Args: - carrier: dictionary in which header + carrier: dictionary in which to get value key: the key used to get the value Returns: A list with a single string with the value if it exists, else None. @@ -75,11 +90,36 @@ def get( return list(val) return [val] - def keys(self, carrier: typing.Dict[str, CarrierValT]) -> typing.List[str]: + def keys( # type: ignore + self, carrier: typing.Dict[str, CarrierValT] + ) -> typing.List[str]: """Keys implementation that returns all keys from a dictionary.""" return list(carrier.keys()) +default_getter = DefaultGetter() + + +class DefaultSetter(Setter): + def set( # type: ignore + self, + carrier: typing.MutableMapping[str, CarrierValT], + key: str, + value: CarrierValT, + ) -> None: + """Setter implementation to set a value into a dictionary. + + Args: + carrier: dictionary in which to set value + key: the key used to set the value + value: the value to set + """ + carrier[key] = value + + +default_setter = DefaultSetter() + + class TextMapPropagator(abc.ABC): """This class provides an interface that enables extracting and injecting context into headers of HTTP requests. HTTP frameworks and clients @@ -92,9 +132,9 @@ class TextMapPropagator(abc.ABC): @abc.abstractmethod def extract( self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + getter: Getter = default_getter, ) -> Context: """Create a Context from values in the carrier. @@ -120,25 +160,25 @@ def extract( @abc.abstractmethod def inject( self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + setter: Setter = default_setter, ) -> None: """Inject values from a Context into a carrier. inject enables the propagation of values into HTTP clients or other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the + should use the `Setter` 's set method to set values on the carrier. Args: - set_in_carrier: A setter function that can set values - on the carrier. carrier: An object that a place to define HTTP headers. - Should be paired with set_in_carrier, which should + Should be paired with setter, which should know how to set header values on the carrier. context: an optional Context to use. Defaults to current context if not set. + setter: An optional `Setter` object that can set values + on the carrier. """ diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py index 480e716bf78..9fc5cfed242 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py @@ -35,9 +35,9 @@ class TraceContextTextMapPropagator(textmap.TextMapPropagator): def extract( self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + getter: textmap.Getter = textmap.default_getter, ) -> Context: """Extracts SpanContext from the carrier. @@ -85,9 +85,9 @@ def extract( def inject( self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, + carrier: textmap.CarrierT, context: typing.Optional[Context] = None, + setter: textmap.Setter = textmap.default_setter, ) -> None: """Injects SpanContext into the carrier. @@ -102,12 +102,10 @@ def inject( trace_id=format_trace_id(span_context.trace_id), span_id=format_span_id(span_context.span_id), ) - set_in_carrier( - carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string - ) + setter.set(carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string) if span_context.trace_state: tracestate_string = span_context.trace_state.to_header() - set_in_carrier( + setter.set( carrier, self._TRACESTATE_HEADER_NAME, tracestate_string ) diff --git a/opentelemetry-api/tests/baggage/test_baggage_propagation.py b/opentelemetry-api/tests/baggage/test_baggage_propagation.py index 3047ddbbe46..9084bb778e0 100644 --- a/opentelemetry-api/tests/baggage/test_baggage_propagation.py +++ b/opentelemetry-api/tests/baggage/test_baggage_propagation.py @@ -20,9 +20,6 @@ from opentelemetry import baggage from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.context import get_current -from opentelemetry.propagators.textmap import DictGetter - -carrier_getter = DictGetter() class TestBaggagePropagation(unittest.TestCase): @@ -32,7 +29,7 @@ def setUp(self): def _extract(self, header_value): """Test helper""" header = {"baggage": [header_value]} - return baggage.get_all(self.propagator.extract(carrier_getter, header)) + return baggage.get_all(self.propagator.extract(header)) def _inject(self, values): """Test helper""" @@ -40,13 +37,11 @@ def _inject(self, values): for k, v in values.items(): ctx = baggage.set_baggage(k, v, context=ctx) output = {} - self.propagator.inject(dict.__setitem__, output, context=ctx) + self.propagator.inject(output, context=ctx) return output.get("baggage") def test_no_context_header(self): - baggage_entries = baggage.get_all( - self.propagator.extract(carrier_getter, {}) - ) + baggage_entries = baggage.get_all(self.propagator.extract({})) self.assertEqual(baggage_entries, {}) def test_empty_context_header(self): @@ -149,13 +144,13 @@ def test_inject_non_string_values(self): @patch("opentelemetry.baggage.propagation._format_baggage") def test_fields(self, mock_format_baggage, mock_baggage): - mock_set_in_carrier = Mock() + mock_setter = Mock() - self.propagator.inject(mock_set_in_carrier, {}) + self.propagator.inject({}, setter=mock_setter) inject_fields = set() - for mock_call in mock_set_in_carrier.mock_calls: + for mock_call in mock_setter.mock_calls: inject_fields.add(mock_call[1][1]) self.assertEqual(inject_fields, self.propagator.fields) diff --git a/opentelemetry-api/tests/propagators/test_composite.py b/opentelemetry-api/tests/propagators/test_composite.py index e33649bbdd8..ef9fae2a1ac 100644 --- a/opentelemetry-api/tests/propagators/test_composite.py +++ b/opentelemetry-api/tests/propagators/test_composite.py @@ -26,16 +26,16 @@ def get_as_list(dict_object, key): def mock_inject(name, value="data"): - def wrapped(setter, carrier=None, context=None): + def wrapped(carrier=None, context=None, setter=None): carrier[name] = value - setter({}, "inject_field_{}_0".format(name), None) - setter({}, "inject_field_{}_1".format(name), None) + setter.set({}, "inject_field_{}_0".format(name), None) + setter.set({}, "inject_field_{}_1".format(name), None) return wrapped def mock_extract(name, value="context"): - def wrapped(getter, carrier=None, context=None): + def wrapped(carrier=None, context=None, getter=None): new_context = context.copy() new_context[name] = value return new_context @@ -69,11 +69,11 @@ def setUpClass(cls): def test_no_propagators(self): propagator = CompositeHTTPPropagator([]) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(new_carrier) self.assertEqual(new_carrier, {}) context = propagator.extract( - get_as_list, carrier=new_carrier, context={} + carrier=new_carrier, context={}, getter=get_as_list ) self.assertEqual(context, {}) @@ -81,11 +81,11 @@ def test_single_propagator(self): propagator = CompositeHTTPPropagator([self.mock_propagator_0]) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(new_carrier) self.assertEqual(new_carrier, {"mock-0": "data"}) context = propagator.extract( - get_as_list, carrier=new_carrier, context={} + carrier=new_carrier, context={}, getter=get_as_list ) self.assertEqual(context, {"mock-0": "context"}) @@ -95,11 +95,11 @@ def test_multiple_propagators(self): ) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(new_carrier) self.assertEqual(new_carrier, {"mock-0": "data", "mock-1": "data"}) context = propagator.extract( - get_as_list, carrier=new_carrier, context={} + carrier=new_carrier, context={}, getter=get_as_list ) self.assertEqual(context, {"mock-0": "context", "mock-1": "context"}) @@ -111,11 +111,11 @@ def test_multiple_propagators_same_key(self): ) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(new_carrier) self.assertEqual(new_carrier, {"mock-0": "data2"}) context = propagator.extract( - get_as_list, carrier=new_carrier, context={} + carrier=new_carrier, context={}, getter=get_as_list ) self.assertEqual(context, {"mock-0": "context2"}) @@ -128,13 +128,13 @@ def test_fields(self): ] ) - mock_set_in_carrier = Mock() + mock_setter = Mock() - propagator.inject(mock_set_in_carrier, {}) + propagator.inject({}, setter=mock_setter) inject_fields = set() - for mock_call in mock_set_in_carrier.mock_calls: + for mock_call in mock_setter.mock_calls: inject_fields.add(mock_call[1][1]) self.assertEqual(inject_fields, propagator.fields) diff --git a/opentelemetry-api/tests/propagators/test_global_httptextformat.py b/opentelemetry-api/tests/propagators/test_global_httptextformat.py index 6ba32e46183..466ce6895f8 100644 --- a/opentelemetry-api/tests/propagators/test_global_httptextformat.py +++ b/opentelemetry-api/tests/propagators/test_global_httptextformat.py @@ -18,12 +18,9 @@ from opentelemetry import baggage, trace from opentelemetry.propagate import extract, inject -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.trace import get_current_span, set_span_in_context from opentelemetry.trace.span import format_span_id, format_trace_id -carrier_getter = DictGetter() - class TestDefaultGlobalPropagator(unittest.TestCase): """Test ensures the default global composite propagator works as intended""" @@ -42,7 +39,7 @@ def test_propagation(self): "traceparent": [traceparent_value], "tracestate": [tracestate_value], } - ctx = extract(carrier_getter, headers) + ctx = extract(headers) baggage_entries = baggage.get_all(context=ctx) expected = {"key1": "val1", "key2": "val2"} self.assertEqual(baggage_entries, expected) @@ -56,7 +53,7 @@ def test_propagation(self): ctx = baggage.set_baggage("key4", "val4", context=ctx) ctx = set_span_in_context(span, context=ctx) output = {} - inject(dict.__setitem__, output, context=ctx) + inject(output, context=ctx) self.assertEqual(traceparent_value, output["traceparent"]) self.assertIn("key3=val3", output["baggage"]) self.assertIn("key4=val4", output["baggage"]) diff --git a/opentelemetry-api/tests/trace/propagation/test_textmap.py b/opentelemetry-api/tests/trace/propagation/test_textmap.py index e47a0d22cb4..6b22d46f88e 100644 --- a/opentelemetry-api/tests/trace/propagation/test_textmap.py +++ b/opentelemetry-api/tests/trace/propagation/test_textmap.py @@ -16,29 +16,29 @@ import unittest -from opentelemetry.propagators.textmap import DictGetter +from opentelemetry.propagators.textmap import DefaultGetter -class TestDictGetter(unittest.TestCase): +class TestDefaultGetter(unittest.TestCase): def test_get_none(self): - getter = DictGetter() + getter = DefaultGetter() carrier = {} val = getter.get(carrier, "test") self.assertIsNone(val) def test_get_str(self): - getter = DictGetter() + getter = DefaultGetter() carrier = {"test": "val"} val = getter.get(carrier, "test") self.assertEqual(val, ["val"]) def test_get_iter(self): - getter = DictGetter() + getter = DefaultGetter() carrier = {"test": ["val"]} val = getter.get(carrier, "test") self.assertEqual(val, ["val"]) def test_keys(self): - getter = DictGetter() + getter = DefaultGetter() keys = getter.keys({"test": "val"}) self.assertEqual(keys, ["test"]) diff --git a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index 79683d43d94..98ca50610b9 100644 --- a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -19,14 +19,11 @@ from unittest.mock import Mock, patch from opentelemetry import trace -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.trace.propagation import tracecontext from opentelemetry.trace.span import TraceState FORMAT = tracecontext.TraceContextTextMapPropagator() -carrier_getter = DictGetter() - class TestTraceContextFormat(unittest.TestCase): TRACE_ID = int("12345678901234567890123456789012", 16) # type:int @@ -42,7 +39,7 @@ def test_no_traceparent_header(self): trace-id and parent-id that represents the current request. """ output = {} # type:typing.Dict[str, typing.List[str]] - span = trace.get_current_span(FORMAT.extract(carrier_getter, output)) + span = trace.get_current_span(FORMAT.extract(output)) self.assertIsInstance(span.get_span_context(), trace.SpanContext) def test_headers_with_tracestate(self): @@ -56,7 +53,6 @@ def test_headers_with_tracestate(self): tracestate_value = "foo=1,bar=2,baz=3" span_context = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [traceparent_value], "tracestate": [tracestate_value], @@ -73,7 +69,7 @@ def test_headers_with_tracestate(self): span = trace.NonRecordingSpan(span_context) ctx = trace.set_span_in_context(span) - FORMAT.inject(dict.__setitem__, output, ctx) + FORMAT.inject(output, context=ctx) self.assertEqual(output["traceparent"], traceparent_value) for pair in ["foo=1", "bar=2", "baz=3"]: self.assertIn(pair, output["tracestate"]) @@ -100,7 +96,6 @@ def test_invalid_trace_id(self): """ span = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-00000000000000000000000000000000-1234567890123456-00" @@ -131,7 +126,6 @@ def test_invalid_parent_id(self): """ span = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-00000000000000000000000000000000-0000000000000000-00" @@ -155,7 +149,7 @@ def test_no_send_empty_tracestate(self): trace.SpanContext(self.TRACE_ID, self.SPAN_ID, is_remote=False) ) ctx = trace.set_span_in_context(span) - FORMAT.inject(dict.__setitem__, output, ctx) + FORMAT.inject(output, context=ctx) self.assertTrue("traceparent" in output) self.assertFalse("tracestate" in output) @@ -169,7 +163,6 @@ def test_format_not_supported(self): """ span = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-" @@ -185,14 +178,13 @@ def test_propagate_invalid_context(self): """Do not propagate invalid trace context.""" output = {} # type:typing.Dict[str, str] ctx = trace.set_span_in_context(trace.INVALID_SPAN) - FORMAT.inject(dict.__setitem__, output, context=ctx) + FORMAT.inject(output, context=ctx) self.assertFalse("traceparent" in output) def test_tracestate_empty_header(self): """Test tracestate with an additional empty header (should be ignored)""" span = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-1234567890123456-00" @@ -207,7 +199,6 @@ def test_tracestate_header_with_trailing_comma(self): """Do not propagate invalid trace context.""" span = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-1234567890123456-00" @@ -230,7 +221,6 @@ def test_tracestate_keys(self): ) span = trace.get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-" @@ -270,13 +260,13 @@ def test_fields(self, mock_get_current_span, mock_invalid_span_context): ) ) - mock_set_in_carrier = Mock() + mock_setter = Mock() - FORMAT.inject(mock_set_in_carrier, {}) + FORMAT.inject({}, setter=mock_setter) inject_fields = set() - for mock_call in mock_set_in_carrier.mock_calls: + for mock_call in mock_setter.mock_calls: inject_fields.add(mock_call[1][1]) self.assertEqual(inject_fields, FORMAT.fields) diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index 01abcc7c879..be478b05ec0 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -18,10 +18,12 @@ import opentelemetry.trace as trace from opentelemetry.context import Context from opentelemetry.propagators.textmap import ( + CarrierT, Getter, Setter, TextMapPropagator, - TextMapPropagatorT, + default_getter, + default_setter, ) from opentelemetry.trace import format_span_id, format_trace_id @@ -44,9 +46,9 @@ class B3Format(TextMapPropagator): def extract( self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + getter: Getter = default_getter, ) -> Context: trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) @@ -127,9 +129,9 @@ def extract( def inject( self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + setter: Setter = default_setter, ) -> None: span = trace.get_current_span(context=context) @@ -138,20 +140,20 @@ def inject( return sampled = (trace.TraceFlags.SAMPLED & span_context.trace_flags) != 0 - set_in_carrier( + setter.set( carrier, self.TRACE_ID_KEY, format_trace_id(span_context.trace_id), ) - set_in_carrier( + setter.set( carrier, self.SPAN_ID_KEY, format_span_id(span_context.span_id) ) span_parent = getattr(span, "parent", None) if span_parent is not None: - set_in_carrier( + setter.set( carrier, self.PARENT_SPAN_ID_KEY, format_span_id(span_parent.span_id), ) - set_in_carrier(carrier, self.SAMPLED_KEY, "1" if sampled else "0") + setter.set(carrier, self.SAMPLED_KEY, "1" if sampled else "0") @property def fields(self) -> typing.Set[str]: @@ -164,8 +166,8 @@ def fields(self) -> typing.Set[str]: def _extract_first_element( - items: typing.Iterable[TextMapPropagatorT], -) -> typing.Optional[TextMapPropagatorT]: + items: typing.Iterable[CarrierT], +) -> typing.Optional[CarrierT]: if items is None: return None return next(iter(items), None) diff --git a/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py index 5048f495f06..3a7a251ad88 100644 --- a/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py @@ -14,7 +14,6 @@ import opentelemetry.propagators.b3 as b3_format import opentelemetry.sdk.trace as trace -from opentelemetry.propagators.textmap import DictGetter FORMAT = b3_format.B3Format() @@ -22,7 +21,6 @@ def test_extract_single_header(benchmark): benchmark( FORMAT.extract, - DictGetter(), { FORMAT.SINGLE_HEADER_KEY: "bdb5b63237ed38aea578af665aa5aa60-c32d953d73ad2251-1-11fd79a30b0896cd285b396ae102dd76" }, @@ -35,7 +33,6 @@ def test_inject_empty_context(benchmark): with tracer.start_as_current_span("Child Span"): benchmark( FORMAT.inject, - dict.__setitem__, { FORMAT.TRACE_ID_KEY: "bdb5b63237ed38aea578af665aa5aa60", FORMAT.SPAN_ID_KEY: "00000000000000000c32d953d73ad225", diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index f9d3bce1adb..d1d96a269f0 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -20,17 +20,14 @@ import opentelemetry.sdk.trace.id_generator as id_generator import opentelemetry.trace as trace_api from opentelemetry.context import get_current -from opentelemetry.propagators.textmap import DictGetter +from opentelemetry.propagators.textmap import DefaultGetter FORMAT = b3_format.B3Format() -carrier_getter = DictGetter() - - def get_child_parent_new_carrier(old_carrier): - ctx = FORMAT.extract(carrier_getter, old_carrier) + ctx = FORMAT.extract(old_carrier) parent_span_context = trace_api.get_current_span(ctx).get_span_context() parent = trace._Span("parent", parent_span_context) @@ -48,7 +45,7 @@ def get_child_parent_new_carrier(old_carrier): new_carrier = {} ctx = trace_api.set_span_in_context(child) - FORMAT.inject(dict.__setitem__, new_carrier, context=ctx) + FORMAT.inject(new_carrier, context=ctx) return child, parent, new_carrier @@ -239,7 +236,7 @@ def test_invalid_single_header(self): invalid SpanContext. """ carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - ctx = FORMAT.extract(carrier_getter, carrier) + ctx = FORMAT.extract(carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -251,7 +248,7 @@ def test_missing_trace_id(self): FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) + ctx = FORMAT.extract(carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) @@ -275,7 +272,7 @@ def test_invalid_trace_id( FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) + ctx = FORMAT.extract(carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) @@ -301,7 +298,7 @@ def test_invalid_span_id( FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) + ctx = FORMAT.extract(carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) @@ -314,7 +311,7 @@ def test_missing_span_id(self): FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) + ctx = FORMAT.extract(carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -322,37 +319,34 @@ def test_missing_span_id(self): def test_inject_empty_context(): """If the current context has no span, don't add headers""" new_carrier = {} - FORMAT.inject(dict.__setitem__, new_carrier, get_current()) + FORMAT.inject(new_carrier, get_current()) assert len(new_carrier) == 0 @staticmethod def test_default_span(): """Make sure propagator does not crash when working with NonRecordingSpan""" - class CarrierGetter(DictGetter): + class CarrierGetter(DefaultGetter): def get(self, carrier, key): return carrier.get(key, None) - def setter(carrier, key, value): - carrier[key] = value - - ctx = FORMAT.extract(CarrierGetter(), {}) - FORMAT.inject(setter, {}, ctx) + ctx = FORMAT.extract({}, getter=CarrierGetter()) + FORMAT.inject({}, context=ctx) def test_fields(self): """Make sure the fields attribute returns the fields used in inject""" tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider") - mock_set_in_carrier = Mock() + mock_setter = Mock() with tracer.start_as_current_span("parent"): with tracer.start_as_current_span("child"): - FORMAT.inject(mock_set_in_carrier, {}) + FORMAT.inject({}, setter=mock_setter) inject_fields = set() - for call in mock_set_in_carrier.mock_calls: + for call in mock_setter.mock_calls: inject_fields.add(call[1][1]) self.assertEqual(FORMAT.fields, inject_fields) diff --git a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py index 8e7fe5f69ff..47f438531fb 100644 --- a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py +++ b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py @@ -19,10 +19,12 @@ from opentelemetry import baggage from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import ( + CarrierT, Getter, Setter, TextMapPropagator, - TextMapPropagatorT, + default_getter, + default_setter, ) from opentelemetry.trace import format_span_id, format_trace_id @@ -39,9 +41,9 @@ class JaegerPropagator(TextMapPropagator): def extract( self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + getter: Getter = default_getter, ) -> Context: if context is None: @@ -76,9 +78,9 @@ def extract( def inject( self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + setter: Setter = default_setter, ) -> None: span = trace.get_current_span(context=context) span_context = span.get_span_context() @@ -91,7 +93,7 @@ def inject( trace_flags |= self.DEBUG_FLAG # set span identity - set_in_carrier( + setter.set( carrier, self.TRACE_ID_KEY, _format_uber_trace_id( @@ -108,9 +110,7 @@ def inject( return for key, value in baggage_entries.items(): baggage_key = self.BAGGAGE_PREFIX + key - set_in_carrier( - carrier, baggage_key, urllib.parse.quote(str(value)) - ) + setter.set(carrier, baggage_key, urllib.parse.quote(str(value))) @property def fields(self) -> typing.Set[str]: @@ -142,8 +142,8 @@ def _format_uber_trace_id(trace_id, span_id, parent_span_id, flags): def _extract_first_element( - items: typing.Iterable[TextMapPropagatorT], -) -> typing.Optional[TextMapPropagatorT]: + items: typing.Iterable[CarrierT], +) -> typing.Optional[CarrierT]: if items is None: return None return next(iter(items), None) diff --git a/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py b/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py index da8a855edbe..003bbfeb49d 100644 --- a/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py +++ b/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py @@ -22,17 +22,13 @@ from opentelemetry.propagators import ( # pylint: disable=no-name-in-module jaeger, ) -from opentelemetry.propagators.textmap import DictGetter FORMAT = jaeger.JaegerPropagator() -carrier_getter = DictGetter() - - def get_context_new_carrier(old_carrier, carrier_baggage=None): - ctx = FORMAT.extract(carrier_getter, old_carrier) + ctx = FORMAT.extract(old_carrier) if carrier_baggage: for key, value in carrier_baggage.items(): ctx = baggage.set_baggage(key, value, ctx) @@ -54,7 +50,7 @@ def get_context_new_carrier(old_carrier, carrier_baggage=None): new_carrier = {} ctx = trace_api.set_span_in_context(child, ctx) - FORMAT.inject(dict.__setitem__, new_carrier, context=ctx) + FORMAT.inject(new_carrier, context=ctx) return ctx, new_carrier @@ -72,14 +68,14 @@ def setUpClass(cls): def test_extract_valid_span(self): old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} - ctx = FORMAT.extract(carrier_getter, old_carrier) + ctx = FORMAT.extract(old_carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, self.trace_id) self.assertEqual(span_context.span_id, self.span_id) def test_missing_carrier(self): old_carrier = {} - ctx = FORMAT.extract(carrier_getter, old_carrier) + ctx = FORMAT.extract(old_carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -132,7 +128,7 @@ def test_baggage(self): old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} input_baggage = {"key1": "value1"} _, new_carrier = get_context_new_carrier(old_carrier, input_baggage) - ctx = FORMAT.extract(carrier_getter, new_carrier) + ctx = FORMAT.extract(new_carrier) self.assertDictEqual(input_baggage, ctx["baggage"]) def test_non_string_baggage(self): @@ -140,7 +136,7 @@ def test_non_string_baggage(self): input_baggage = {"key1": 1, "key2": True} formatted_baggage = {"key1": "1", "key2": "True"} _, new_carrier = get_context_new_carrier(old_carrier, input_baggage) - ctx = FORMAT.extract(carrier_getter, new_carrier) + ctx = FORMAT.extract(new_carrier) self.assertDictEqual(formatted_baggage, ctx["baggage"]) def test_extract_invalid_uber_trace_id(self): @@ -149,7 +145,7 @@ def test_extract_invalid_uber_trace_id(self): "uberctx-key1": "value1", } formatted_baggage = {"key1": "value1"} - context = FORMAT.extract(carrier_getter, old_carrier) + context = FORMAT.extract(old_carrier) span_context = trace_api.get_current_span(context).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) self.assertDictEqual(formatted_baggage, context["baggage"]) @@ -160,7 +156,7 @@ def test_extract_invalid_trace_id(self): "uberctx-key1": "value1", } formatted_baggage = {"key1": "value1"} - context = FORMAT.extract(carrier_getter, old_carrier) + context = FORMAT.extract(old_carrier) span_context = trace_api.get_current_span(context).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertDictEqual(formatted_baggage, context["baggage"]) @@ -171,18 +167,18 @@ def test_extract_invalid_span_id(self): "uberctx-key1": "value1", } formatted_baggage = {"key1": "value1"} - context = FORMAT.extract(carrier_getter, old_carrier) + context = FORMAT.extract(old_carrier) span_context = trace_api.get_current_span(context).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) self.assertDictEqual(formatted_baggage, context["baggage"]) def test_fields(self): tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider") - mock_set_in_carrier = Mock() + mock_setter = Mock() with tracer.start_as_current_span("parent"): with tracer.start_as_current_span("child"): - FORMAT.inject(mock_set_in_carrier, {}) + FORMAT.inject({}, setter=mock_setter) inject_fields = set() - for call in mock_set_in_carrier.mock_calls: + for call in mock_setter.mock_calls: inject_fields.add(call[1][1]) self.assertEqual(FORMAT.fields, inject_fields) diff --git a/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py b/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py index b7a365302f9..2327abbfae1 100644 --- a/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py +++ b/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py @@ -102,7 +102,6 @@ from opentelemetry.baggage import get_baggage, set_baggage from opentelemetry.context import Context, attach, detach, get_value, set_value from opentelemetry.propagate import get_global_textmap -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.shim.opentracing_shim import util from opentelemetry.shim.opentracing_shim.version import __version__ from opentelemetry.trace import INVALID_SPAN_CONTEXT, Link, NonRecordingSpan @@ -527,7 +526,6 @@ def __init__(self, tracer: OtelTracer): Format.TEXT_MAP, Format.HTTP_HEADERS, ) - self._carrier_getter = DictGetter() def unwrap(self): """Returns the :class:`opentelemetry.trace.Tracer` object that is @@ -684,7 +682,7 @@ def inject(self, span_context, format: object, carrier: object): propagator = get_global_textmap() ctx = set_span_in_context(NonRecordingSpan(span_context.unwrap())) - propagator.inject(type(carrier).__setitem__, carrier, context=ctx) + propagator.inject(carrier, context=ctx) def extract(self, format: object, carrier: object): """Returns an ``opentracing.SpanContext`` instance extracted from a @@ -712,7 +710,7 @@ def extract(self, format: object, carrier: object): raise UnsupportedFormatException propagator = get_global_textmap() - ctx = propagator.extract(self._carrier_getter, carrier) + ctx = propagator.extract(carrier) span = get_current_span(ctx) if span is not None: otel_context = span.get_span_context() diff --git a/tests/util/src/opentelemetry/test/mock_textmap.py b/tests/util/src/opentelemetry/test/mock_textmap.py index 1edd079042f..4cdef447d6a 100644 --- a/tests/util/src/opentelemetry/test/mock_textmap.py +++ b/tests/util/src/opentelemetry/test/mock_textmap.py @@ -17,10 +17,12 @@ from opentelemetry import trace from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import ( + CarrierT, Getter, Setter, TextMapPropagator, - TextMapPropagatorT, + default_getter, + default_setter, ) @@ -33,17 +35,17 @@ class NOOPTextMapPropagator(TextMapPropagator): def extract( self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + getter: Getter = default_getter, ) -> Context: return get_current() def inject( self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + setter: Setter = default_setter, ) -> None: return None @@ -60,9 +62,9 @@ class MockTextMapPropagator(TextMapPropagator): def extract( self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + getter: Getter = default_getter, ) -> Context: trace_id_list = getter.get(carrier, self.TRACE_ID_KEY) span_id_list = getter.get(carrier, self.SPAN_ID_KEY) @@ -82,15 +84,15 @@ def extract( def inject( self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, + carrier: CarrierT, context: typing.Optional[Context] = None, + setter: Setter = default_setter, ) -> None: span = trace.get_current_span(context) - set_in_carrier( + setter.set( carrier, self.TRACE_ID_KEY, str(span.get_span_context().trace_id) ) - set_in_carrier( + setter.set( carrier, self.SPAN_ID_KEY, str(span.get_span_context().span_id) )