diff --git a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py index 11991413eb8..7b6249fa465 100644 --- a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py @@ -21,7 +21,7 @@ from opentelemetry.instrumentation.wsgi import ( add_response_attributes, collect_request_attributes, - get_header_from_environ, + getter, ) from opentelemetry.propagators import extract from opentelemetry.trace import SpanKind, get_tracer @@ -98,7 +98,7 @@ def process_request(self, request): environ = request.META - token = attach(extract(get_header_from_environ, environ)) + token = attach(extract(getter, environ)) tracer = get_tracer(__name__, __version__) diff --git a/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py b/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py index bfcd45a8b58..3a1323680b9 100644 --- a/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py @@ -114,9 +114,7 @@ def __call__(self, env, start_response): start_time = time_ns() - token = context.attach( - propagators.extract(otel_wsgi.get_header_from_environ, env) - ) + token = context.attach(propagators.extract(otel_wsgi.getter, env)) attributes = otel_wsgi.collect_request_attributes(env) span = self._tracer.start_span( otel_wsgi.get_default_span_name(env), diff --git a/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py b/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py index 90082dd850e..76457875304 100644 --- a/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py @@ -113,9 +113,7 @@ def _before_request(): span_name = flask.request.endpoint or otel_wsgi.get_default_span_name( environ ) - token = context.attach( - propagators.extract(otel_wsgi.get_header_from_environ, environ) - ) + token = context.attach(propagators.extract(otel_wsgi.getter, environ)) tracer = trace.get_tracer(__name__, __version__) diff --git a/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py b/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py index ada239b8e31..a854c55e2e1 100644 --- a/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py +++ b/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py @@ -69,9 +69,7 @@ def _before_traversal(event): start_time = environ.get(_ENVIRON_STARTTIME_KEY) - token = context.attach( - propagators.extract(otel_wsgi.get_header_from_environ, environ) - ) + token = context.attach(propagators.extract(otel_wsgi.getter, environ)) tracer = trace.get_tracer(__name__, __version__) if request.matched_route: diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py index 5b3b314195d..60ca5971190 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py @@ -22,6 +22,20 @@ Setter = typing.Callable[[TextMapPropagatorT, str, str], None] +def default_get(carrier: TextMapPropagatorT, key: str) -> typing.List[str]: + return [carrier.get(key)] if carrier.get(key) else [] + + +def default_keys(carrier: TextMapPropagatorT) -> typing.List[str]: + return list(carrier.keys()) + + +GetterGetFunction = typing.Callable[ + [TextMapPropagatorT, str], typing.List[str] +] +GetterKeysFunction = typing.Callable[[TextMapPropagatorT], typing.List[str]] + + class Getter: """This class implements a Getter that enables extracting propagated fields from a carrier @@ -30,10 +44,8 @@ class Getter: def __init__( self, - get=lambda carrier, key: [carrier.get(key)] - if carrier.get(key) - else [], - keys=lambda carrier: list(carrier.keys()), + get: GetterGetFunction[TextMapPropagatorT] = default_get, + keys: GetterKeysFunction[TextMapPropagatorT] = default_keys, ): self._get = get self._keys = keys