diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py index 0960bc73bad..ec505135bec 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py @@ -34,11 +34,12 @@ def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]: or more values from the carrier. In the case that the value does not exist, returns an empty list. - Args: carrier: and object which contains values that are used to - construct a Context. This object must be paired with an appropriate - getter which understands how to extract a value from it. - key: key of a field in carrier. Returns: first value of the - propagation key or an empty list if the key doesn't exist. + Args: + carrier: An object which contains values that are used to + construct a Context. + key: key of a field in carrier. + Returns: first value of the propagation key or an empty list if the + key doesn't exist. """ raise NotImplementedError() @@ -46,10 +47,8 @@ def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: """Function that can retrieve all the keys in a carrier object. Args: - carrier: and object which contains values that are - used to construct a Context. This object - must be paired with an appropriate getter - which understands how to extract a value from it. + carrier: An object which contains values that are + used to construct a Context. Returns: list of keys from the carrier. """ @@ -60,31 +59,15 @@ class DictGetter(Getter[typing.Dict[str, CarrierValT]]): def get( self, carrier: typing.Dict[str, CarrierValT], key: str ) -> typing.List[str]: - val = carrier.get(key, None) - if not val: - return [] - return val if isinstance(val, typing.List) else [val] + val = carrier.get(key, []) + if isinstance(val, typing.Iterable) and not isinstance(val, str): + return list(val) + return [val] def keys(self, carrier: typing.Dict[str, CarrierValT]) -> typing.List[str]: return list(carrier.keys()) -class CustomGetter(Getter[TextMapPropagatorT]): - def __init__( - self, - get: typing.Callable[[TextMapPropagatorT, str], typing.List[str]], - keys: typing.Callable[[TextMapPropagatorT], typing.List[str]], - ): - self._get = get - self._keys = keys - - def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]: - return self._get(carrier, key) - - def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: - return self._keys(carrier) - - class TextMapPropagator(abc.ABC): """This class provides an interface that enables extracting and injecting context into headers of HTTP requests. HTTP frameworks and clients diff --git a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py index d68ebe20d77..79c4618aee7 100644 --- a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py @@ -19,7 +19,7 @@ import opentelemetry.sdk.trace.propagation.b3_format as b3_format import opentelemetry.trace as trace_api from opentelemetry.context import get_current -from opentelemetry.trace.propagation.textmap import CustomGetter, DictGetter +from opentelemetry.trace.propagation.textmap import DictGetter FORMAT = b3_format.B3Format() @@ -320,13 +320,12 @@ def test_inject_empty_context(): def test_default_span(): """Make sure propagator does not crash when working with DefaultSpan""" - def default_span_getter(carrier, key): - return carrier.get(key, None) + class CarrierGetter(DictGetter): + def get(self, carrier, key): + return carrier.get(key, None) def setter(carrier, key, value): carrier[key] = value - ctx = FORMAT.extract( - CustomGetter(default_span_getter, DictGetter().keys), {} - ) + ctx = FORMAT.extract(CarrierGetter(), {}) FORMAT.inject(setter, {}, ctx)