From 03dae3bece45fb0e45a035d66004017d1fc39af9 Mon Sep 17 00:00:00 2001 From: Prajilesh N Date: Sat, 3 Oct 2020 01:24:26 +0530 Subject: [PATCH] converted textmap propagator getter to a class and added keys method --- .../exporter/datadog/propagator.py | 10 ++-- .../tests/test_datadog_format.py | 15 ++++-- .../instrumentation/asgi/__init__.py | 38 +++++++++------ .../instrumentation/celery/__init__.py | 20 ++++++-- .../instrumentation/grpc/_server.py | 16 +++++-- .../opentracing_shim/__init__.py | 46 +++++++++++++++++-- .../instrumentation/tornado/__init__.py | 17 +++++-- .../instrumentation/wsgi/__init__.py | 36 ++++++++++----- .../baggage/propagation/__init__.py | 4 +- .../src/opentelemetry/propagators/__init__.py | 8 ++-- .../opentelemetry/propagators/composite.py | 2 +- .../trace/propagation/textmap.py | 41 ++++++++++++++++- .../trace/propagation/tracecontext.py | 6 +-- .../tests/baggage/test_baggage_propagation.py | 15 ++++-- .../propagators/test_global_httptextformat.py | 17 +++++-- .../test_tracecontexthttptextformat.py | 17 +++++-- opentelemetry-sdk/CHANGELOG.md | 2 + .../sdk/trace/propagation/b3_format.py | 12 ++--- .../tests/trace/propagation/test_b3_format.py | 20 +++++--- .../src/opentelemetry/test/mock_textmap.py | 8 ++-- 20 files changed, 256 insertions(+), 94 deletions(-) diff --git a/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py index d2e60476e68..df73b4f1664 100644 --- a/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py +++ b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py @@ -39,24 +39,24 @@ class DatadogFormat(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + get_from_carrier: Getter, carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: trace_id = extract_first_element( - get_from_carrier(carrier, self.TRACE_ID_KEY) + get_from_carrier.get(carrier, self.TRACE_ID_KEY) ) span_id = extract_first_element( - get_from_carrier(carrier, self.PARENT_ID_KEY) + get_from_carrier.get(carrier, self.PARENT_ID_KEY) ) sampled = extract_first_element( - get_from_carrier(carrier, self.SAMPLING_PRIORITY_KEY) + get_from_carrier.get(carrier, self.SAMPLING_PRIORITY_KEY) ) origin = extract_first_element( - get_from_carrier(carrier, self.ORIGIN_KEY) + get_from_carrier.get(carrier, self.ORIGIN_KEY) ) trace_flags = trace.TraceFlags() diff --git a/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py b/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py index 994cac2d602..6ff865ba815 100644 --- a/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py +++ b/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py @@ -22,9 +22,18 @@ FORMAT = propagator.DatadogFormat() -def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] +class Getter: + @staticmethod + def get(dict_object, key): + value = dict_object.get(key) + return [value] if value is not None else [] + + @staticmethod + def keys(dict_object): + return dict_object.keys() + + +get_as_list = Getter() class TestDatadogFormat(unittest.TestCase): diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 879662ddcfb..7443e76faba 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -32,19 +32,31 @@ from opentelemetry.trace.status import Status, StatusCanonicalCode -def get_header_from_scope(scope: dict, header_name: str) -> typing.List[str]: - """Retrieve a HTTP header value from the ASGI scope. +class Getter: + @staticmethod + def get(scope: dict, header_name: str) -> typing.List[str]: + """Retrieve a HTTP header value from the ASGI scope. + + Returns: + A list with a single string with the header value if it exists, else an empty list. + """ + headers = scope.get("headers") + return [ + value.decode("utf8") + for (key, value) in headers + if key.decode("utf8") == header_name] + + @staticmethod + def keys(scope: dict) -> typing.List[str]: + """Retrieve all the HTTP header keys for an ASGI scope.. + + Returns: + A list with all the keys in scope. + """ + return scope.keys() - Returns: - A list with a single string with the header value if it exists, else an empty list. - """ - headers = scope.get("headers") - return [ - value.decode("utf8") - for (key, value) in headers - if key.decode("utf8") == header_name - ] +get_header_from_scope = Getter() def collect_request_attributes(scope): """Collects HTTP request attributes from the ASGI scope and returns a @@ -72,10 +84,10 @@ def collect_request_attributes(scope): http_method = scope.get("method") if http_method: result["http.method"] = http_method - http_host_value = ",".join(get_header_from_scope(scope, "host")) + http_host_value = ",".join(get_header_from_scope.get(scope, "host")) if http_host_value: result["http.server_name"] = http_host_value - http_user_agent = get_header_from_scope(scope, "user-agent") + http_user_agent = get_header_from_scope.get(scope, "user-agent") if len(http_user_agent) > 0: result["http.user_agent"] = http_user_agent[0] diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py index 4768e93d18e..8c02c66783f 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py @@ -118,6 +118,7 @@ def _trace_prerun(self, *args, **kwargs): return request = task.request + carrier_extractor = Getter() tracectx = propagators.extract(carrier_extractor, request) or {} parent = get_current_span(tracectx) @@ -248,8 +249,17 @@ def _trace_retry(*args, **kwargs): span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason)) -def carrier_extractor(carrier, key): - value = getattr(carrier, key, []) - if isinstance(value, str) or not isinstance(value, Iterable): - value = (value,) - return value +class Getter: + @staticmethod + def get(carrier, key): + value = getattr(carrier, key, []) + if isinstance(value, str) or not isinstance(value, Iterable): + value = (value,) + return value + + @staticmethod + def keys(carrier): + return carrier.keys() + + + diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py index cb0e997d367..6769dba564a 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py @@ -109,6 +109,17 @@ def _check_error_code(span, servicer_context, rpc_info): if servicer_context.code != grpc.StatusCode.OK: rpc_info.error = servicer_context.code +class Getter: + @staticmethod + def get(metadata, key) -> List[str]: + md_dict = {md.key: md.value for md in metadata} + return [md_dict[key]] if key in md_dict else [] + + @staticmethod + def keys(metadata) -> List[str]: + md_dict = {md.key: md.value for md in metadata} + return md_dict.keys() + class OpenTelemetryServerInterceptor( grpcext.UnaryServerInterceptor, grpcext.StreamServerInterceptor @@ -121,11 +132,8 @@ def __init__(self, tracer): def _set_remote_context(self, servicer_context): metadata = servicer_context.invocation_metadata() if metadata: - md_dict = {md.key: md.value for md in metadata} - - def get_from_grpc_metadata(metadata, key) -> List[str]: - return [md_dict[key]] if key in md_dict else [] + get_from_grpc_metadata = Getter() # Update the context with the traceparent from the RPC metadata. ctx = propagators.extract(get_from_grpc_metadata, metadata) token = attach(ctx) diff --git a/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py b/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py index 6bb22130d8e..10ab5caf809 100644 --- a/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py @@ -503,6 +503,47 @@ def tracer(self) -> "TracerShim": return self._tracer +class Getter: + """This class provides an interface that enables extracting propagated + fields from a carrier + + """ + + @staticmethod + def get(carrier, key): + """Function that can retrieve zero + or more values from the carrier. In the case that + the value does not exist, returns an empty list. + + Args: + carrier: and object which contains values that are + used to construct a Context. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it. + key: key of a field in carrier. + Returns: + first value of the propagation key or an empty list if the key doesn't exist. + """ + + value = carrier.get(key) + return [value] if value is not None else [] + + @staticmethod + def keys(carrier): + """Function that can retrieve all the keys in a carrier object. + + Args: + carrier: and object which contains values that are + used to construct a Context. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it. + Returns: + list of keys from the carrier. + """ + + return carrier.keys() + + class TracerShim(Tracer): """Wraps a :class:`opentelemetry.trace.Tracer` object. @@ -706,10 +747,7 @@ def extract(self, format: object, carrier: object): if format not in self._supported_formats: raise UnsupportedFormatException - def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] - + get_as_list = Getter() propagator = propagators.get_global_textmap() ctx = propagator.extract(get_as_list, carrier) span = get_current_span(ctx) diff --git a/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py b/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py index 6379d841a03..3078bfbe3ac 100644 --- a/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py @@ -174,11 +174,17 @@ def _log_exception(tracer, func, handler, args, kwargs): return func(*args, **kwargs) -def _get_header_from_request_headers( - headers: dict, header_name: str -) -> typing.List[str]: - header = headers.get(header_name) - return [header] if header else [] +class Getter: + @staticmethod + def get( + headers: dict, header_name: str + ) -> typing.List[str]: + header = headers.get(header_name) + return [header] if header else [] + + @staticmethod + def keys(headers) -> typing.List[str]: + return headers.keys() def _get_attributes_from_request(request): @@ -206,6 +212,7 @@ def _get_operation_name(handler, request): def _start_span(tracer, handler, start_time) -> _TraceContext: + _get_header_from_request_headers = Getter() token = context.attach( propagators.extract( _get_header_from_request_headers, handler.request.headers, diff --git a/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py b/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py index 56c8b755c5c..505cc915575 100644 --- a/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py @@ -66,19 +66,30 @@ def hello(): _HTTP_VERSION_PREFIX = "HTTP/" -def get_header_from_environ( - environ: dict, header_name: str -) -> typing.List[str]: - """Retrieve a HTTP header value from the PEP3333-conforming WSGI environ. +class Getter: + @staticmethod + def get( + environ: dict, header_name: str + ) -> typing.List[str]: + """Retrieve a HTTP header value from the PEP3333-conforming WSGI environ. - Returns: - A list with a single string with the header value if it exists, else an empty list. - """ - environ_key = "HTTP_" + header_name.upper().replace("-", "_") - value = environ.get(environ_key) - if value is not None: - return [value] - return [] + Returns: + A list with a single string with the header value if it exists, else an empty list. + """ + environ_key = "HTTP_" + header_name.upper().replace("-", "_") + value = environ.get(environ_key) + if value is not None: + return [value] + return [] + + @staticmethod + def keys(environ: dict) -> typing.List[str]: + """Retrieve all the HTTP header keys for an PEP3333-conforming WSGI environ. + + Returns: + A list with all the keys in environ. + """ + return environ.keys() def setifnotnone(dic, key, value): @@ -195,6 +206,7 @@ def __call__(self, environ, start_response): start_response: The WSGI start_response callable. """ + get_header_from_environ = Getter() token = context.attach( propagators.extract(get_header_from_environ, environ) ) diff --git a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py index d0920e68590..7d589cb0c64 100644 --- a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py @@ -29,7 +29,7 @@ class BaggagePropagator(textmap.TextMapPropagator): def extract( self, - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + get_from_carrier: textmap.Getter, carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -43,7 +43,7 @@ def extract( context = get_current() header = _extract_first_element( - get_from_carrier(carrier, self._BAGGAGE_HEADER_NAME) + get_from_carrier.get(carrier, self._BAGGAGE_HEADER_NAME) ) if not header or len(header) > self.MAX_HEADER_LENGTH: diff --git a/opentelemetry-api/src/opentelemetry/propagators/__init__.py b/opentelemetry-api/src/opentelemetry/propagators/__init__.py index c274b19f8a0..1b3f3e84b94 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagators/__init__.py @@ -82,16 +82,16 @@ def example_route(): def extract( - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + get_from_carrier: textmap.Getter, carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: """ Uses the configured propagator to extract a Context from the carrier. Args: - get_from_carrier: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. + get_from_carrier: an object which contains a get function that can retrieve zero + or more values from the carrier and a keys function that can get all the keys + from carrier. carrier: and object which contains values that are used to construct a Context. This object must be paired with an appropriate get_from_carrier diff --git a/opentelemetry-api/src/opentelemetry/propagators/composite.py b/opentelemetry-api/src/opentelemetry/propagators/composite.py index 3499d2ea08a..c307479021e 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/composite.py +++ b/opentelemetry-api/src/opentelemetry/propagators/composite.py @@ -35,7 +35,7 @@ def __init__( def extract( self, - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + get_from_carrier: textmap.Getter, carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py index 6f9ed897e11..af9b6689651 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py @@ -20,7 +20,44 @@ TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT") Setter = typing.Callable[[TextMapPropagatorT, str, str], None] -Getter = typing.Callable[[TextMapPropagatorT, str], typing.List[str]] + + +class Getter(abc.ABC): + """This class provides an interface that enables extracting propagated + fields from a carrier + + """ + + @staticmethod + @abc.abstractmethod + def get(carrier: TextMapPropagatorT, key: str) -> typing.List[str]: + """Function that can retrieve zero + or more values from the carrier. In the case that + the value does not exist, returns an empty list. + + Args: + carrier: and object which contains values that are + used to construct a Context. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it. + key: key of a field in carrier. + Returns: + first value of the propagation key or an empty list if the key doesn't exist. + """ + + @staticmethod + @abc.abstractmethod + def keys(carrier: TextMapPropagatorT) -> typing.List[str]: + """Function that can retrieve all the keys in a carrier object. + + Args: + carrier: and object which contains values that are + used to construct a Context. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it. + Returns: + list of keys from the carrier. + """ class TextMapPropagator(abc.ABC): @@ -35,7 +72,7 @@ class TextMapPropagator(abc.ABC): @abc.abstractmethod def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + get_from_carrier: Getter, carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py index 8627b9a65cb..33385311f74 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py @@ -60,7 +60,7 @@ class TraceContextTextMapPropagator(textmap.TextMapPropagator): def extract( self, - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + get_from_carrier: textmap.Getter, carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -68,7 +68,7 @@ def extract( See `opentelemetry.trace.propagation.textmap.TextMapPropagator.extract` """ - header = get_from_carrier(carrier, self._TRACEPARENT_HEADER_NAME) + header = get_from_carrier.get(carrier, self._TRACEPARENT_HEADER_NAME) if not header: return trace.set_span_in_context(trace.INVALID_SPAN, context) @@ -91,7 +91,7 @@ def extract( if version == "ff": return trace.set_span_in_context(trace.INVALID_SPAN, context) - tracestate_headers = get_from_carrier( + tracestate_headers = get_from_carrier.get( carrier, self._TRACESTATE_HEADER_NAME ) tracestate = _parse_tracestate(tracestate_headers) diff --git a/opentelemetry-api/tests/baggage/test_baggage_propagation.py b/opentelemetry-api/tests/baggage/test_baggage_propagation.py index d5c16ead5d4..d69ce2e25f5 100644 --- a/opentelemetry-api/tests/baggage/test_baggage_propagation.py +++ b/opentelemetry-api/tests/baggage/test_baggage_propagation.py @@ -20,10 +20,17 @@ from opentelemetry.context import get_current -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - return dict_object.get(key, []) +class Getter: + @staticmethod + def get(dict_object, key): + return dict_object.get(key, []) + + @staticmethod + def keys(dict_object): + return dict_object.keys() + + +get_as_list = Getter() class TestBaggagePropagation(unittest.TestCase): diff --git a/opentelemetry-api/tests/propagators/test_global_httptextformat.py b/opentelemetry-api/tests/propagators/test_global_httptextformat.py index 9a97b281297..259ef3a5508 100644 --- a/opentelemetry-api/tests/propagators/test_global_httptextformat.py +++ b/opentelemetry-api/tests/propagators/test_global_httptextformat.py @@ -20,11 +20,18 @@ from opentelemetry.trace import get_current_span, set_span_in_context -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - value = dict_object.get(key) - return value if value is not None else [] +class Getter: + @staticmethod + def get(dict_object, key): + value = dict_object.get(key) + return value if value is not None else [] + + @staticmethod + def keys(dict_object): + return dict_object.keys() + + +get_as_list = Getter() class TestDefaultGlobalPropagator(unittest.TestCase): diff --git a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index 8abe4193873..839d0567761 100644 --- a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -21,11 +21,18 @@ FORMAT = tracecontext.TraceContextTextMapPropagator() -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - value = dict_object.get(key) - return value if value is not None else [] +class Getter: + @staticmethod + def get(dict_object, key): + value = dict_object.get(key) + return value if value is not None else [] + + @staticmethod + def keys(dict_object): + return dict_object.keys() + + +get_as_list = Getter() class TestTraceContextFormat(unittest.TestCase): diff --git a/opentelemetry-sdk/CHANGELOG.md b/opentelemetry-sdk/CHANGELOG.md index 71864282885..3b592ad9d3d 100644 --- a/opentelemetry-sdk/CHANGELOG.md +++ b/opentelemetry-sdk/CHANGELOG.md @@ -10,6 +10,8 @@ ([#1105](https://github.com/open-telemetry/opentelemetry-python/pull/1120)) - Allow for Custom Trace and Span IDs Generation - `IdsGenerator` for TracerProvider ([#1153](https://github.com/open-telemetry/opentelemetry-python/pull/1153)) +- Add keys method to TextMap propagator Getter + ([#1084](https://github.com/open-telemetry/opentelemetry-python/issues/1084)) ## Version 0.13b0 diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py index 813b6e85600..d6c880f9d54 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py @@ -43,7 +43,7 @@ class B3Format(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + get_from_carrier: Getter, carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -53,7 +53,7 @@ def extract( flags = None single_header = _extract_first_element( - get_from_carrier(carrier, self.SINGLE_HEADER_KEY) + get_from_carrier.get(carrier, self.SINGLE_HEADER_KEY) ) if single_header: # The b3 spec calls for the sampling state to be @@ -75,25 +75,25 @@ def extract( else: trace_id = ( _extract_first_element( - get_from_carrier(carrier, self.TRACE_ID_KEY) + get_from_carrier.get(carrier, self.TRACE_ID_KEY) ) or trace_id ) span_id = ( _extract_first_element( - get_from_carrier(carrier, self.SPAN_ID_KEY) + get_from_carrier.get(carrier, self.SPAN_ID_KEY) ) or span_id ) sampled = ( _extract_first_element( - get_from_carrier(carrier, self.SAMPLED_KEY) + get_from_carrier.get(carrier, self.SAMPLED_KEY) ) or sampled ) flags = ( _extract_first_element( - get_from_carrier(carrier, self.FLAGS_KEY) + get_from_carrier.get(carrier, self.FLAGS_KEY) ) or flags ) diff --git a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py index 07b3010087a..176651aa20a 100644 --- a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py @@ -23,9 +23,18 @@ FORMAT = b3_format.B3Format() -def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] +class Getter: + @staticmethod + def get(dict_object, key): + value = dict_object.get(key) + return [value] if value is not None else [] + + @staticmethod + def keys(dict_object): + return dict_object.keys() + + +get_as_list = Getter() def get_child_parent_new_carrier(old_carrier): @@ -321,11 +330,8 @@ def test_inject_empty_context(): def test_default_span(): """Make sure propagator does not crash when working with DefaultSpan""" - def getter(carrier, key): - return carrier.get(key, None) - def setter(carrier, key, value): carrier[key] = value - ctx = FORMAT.extract(getter, {}) + ctx = FORMAT.extract(get_as_list, {}) FORMAT.inject(setter, {}, ctx) diff --git a/tests/util/src/opentelemetry/test/mock_textmap.py b/tests/util/src/opentelemetry/test/mock_textmap.py index 92c0f21f0ec..90260c86883 100644 --- a/tests/util/src/opentelemetry/test/mock_textmap.py +++ b/tests/util/src/opentelemetry/test/mock_textmap.py @@ -33,7 +33,7 @@ class NOOPTextMapPropagator(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + get_from_carrier: Getter, carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -56,12 +56,12 @@ class MockTextMapPropagator(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + get_from_carrier: Getter, carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: - trace_id_list = get_from_carrier(carrier, self.TRACE_ID_KEY) - span_id_list = get_from_carrier(carrier, self.SPAN_ID_KEY) + trace_id_list = get_from_carrier.get(carrier, self.TRACE_ID_KEY) + span_id_list = get_from_carrier.get(carrier, self.SPAN_ID_KEY) if not trace_id_list or not span_id_list: return trace.set_span_in_context(trace.INVALID_SPAN)