Skip to content

Commit

Permalink
converted textmap propagator getter to a class and added keys method
Browse files Browse the repository at this point in the history
  • Loading branch information
nprajilesh committed Oct 2, 2020
1 parent 14fad78 commit 4101d81
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,24 @@ class DatadogFormat(TextMapPropagator):

def extract(
self,
get_from_carrier: Getter[TextMapPropagatorT],
get_from_carrier: Getter,
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
trace_id = extract_first_element(
get_from_carrier(carrier, self.TRACE_ID_KEY)
get_from_carrier.get(carrier, self.TRACE_ID_KEY)
)

span_id = extract_first_element(
get_from_carrier(carrier, self.PARENT_ID_KEY)
get_from_carrier.get(carrier, self.PARENT_ID_KEY)
)

sampled = extract_first_element(
get_from_carrier(carrier, self.SAMPLING_PRIORITY_KEY)
get_from_carrier.get(carrier, self.SAMPLING_PRIORITY_KEY)
)

origin = extract_first_element(
get_from_carrier(carrier, self.ORIGIN_KEY)
get_from_carrier.get(carrier, self.ORIGIN_KEY)
)

trace_flags = trace.TraceFlags()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@
FORMAT = propagator.DatadogFormat()


def get_as_list(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []
class Getter:
@staticmethod
def get(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []

@staticmethod
def keys(dict_object):
return dict_object.keys()


get_as_list = Getter()


class TestDatadogFormat(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BaggagePropagator(textmap.TextMapPropagator):

def extract(
self,
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand All @@ -43,7 +43,7 @@ def extract(
context = get_current()

header = _extract_first_element(
get_from_carrier(carrier, self._BAGGAGE_HEADER_NAME)
get_from_carrier.get(carrier, self._BAGGAGE_HEADER_NAME)
)

if not header or len(header) > self.MAX_HEADER_LENGTH:
Expand Down
8 changes: 4 additions & 4 deletions opentelemetry-api/src/opentelemetry/propagators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ def example_route():


def extract(
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
""" Uses the configured propagator to extract a Context from the carrier.
Args:
get_from_carrier: a function that can retrieve zero
or more values from the carrier. In the case that
the value does not exist, return an empty list.
get_from_carrier: an object which contains a get function that can retrieve zero
or more values from the carrier and a keys function that can get all the keys
from carrier.
carrier: and object which contains values that are
used to construct a Context. This object
must be paired with an appropriate get_from_carrier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(

def extract(
self,
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand Down
41 changes: 39 additions & 2 deletions opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,44 @@
TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT")

Setter = typing.Callable[[TextMapPropagatorT, str, str], None]
Getter = typing.Callable[[TextMapPropagatorT, str], typing.List[str]]


class Getter(abc.ABC):
"""This class provides an interface that enables extracting propagated
fields from a carrier
"""

@abc.abstractmethod
@staticmethod
def get(carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
"""Function that can retrieve zero
or more values from the carrier. In the case that
the value does not exist, returns an empty list.
Args:
carrier: and object which contains values that are
used to construct a Context. This object
must be paired with an appropriate get_from_carrier
which understands how to extract a value from it.
key: key of a field in carrier.
Returns:
first value of the propagation key or an empty list if the key doesn't exist.
"""

@abc.abstractmethod
@staticmethod
def keys(carrier: TextMapPropagatorT) -> typing.List[str]:
"""Function that can retrieve all the keys in a carrier object.
Args:
carrier: and object which contains values that are
used to construct a Context. This object
must be paired with an appropriate get_from_carrier
which understands how to extract a value from it.
Returns:
list of keys from the carrier.
"""


class TextMapPropagator(abc.ABC):
Expand All @@ -35,7 +72,7 @@ class TextMapPropagator(abc.ABC):
@abc.abstractmethod
def extract(
self,
get_from_carrier: Getter[TextMapPropagatorT],
get_from_carrier: Getter,
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ class TraceContextTextMapPropagator(textmap.TextMapPropagator):

def extract(
self,
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
"""Extracts SpanContext from the carrier.
See `opentelemetry.trace.propagation.textmap.TextMapPropagator.extract`
"""
header = get_from_carrier(carrier, self._TRACEPARENT_HEADER_NAME)
header = get_from_carrier.get(carrier, self._TRACEPARENT_HEADER_NAME)

if not header:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
Expand All @@ -91,7 +91,7 @@ def extract(
if version == "ff":
return trace.set_span_in_context(trace.INVALID_SPAN, context)

tracestate_headers = get_from_carrier(
tracestate_headers = get_from_carrier.get(
carrier, self._TRACESTATE_HEADER_NAME
)
tracestate = _parse_tracestate(tracestate_headers)
Expand Down
15 changes: 11 additions & 4 deletions opentelemetry-api/tests/baggage/test_baggage_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,17 @@
from opentelemetry.context import get_current


def get_as_list(
dict_object: typing.Dict[str, typing.List[str]], key: str
) -> typing.List[str]:
return dict_object.get(key, [])
class Getter:
@staticmethod
def get(dict_object, key):
return dict_object.get(key, [])

@staticmethod
def keys(dict_object):
return dict_object.keys()


get_as_list = Getter()


class TestBaggagePropagation(unittest.TestCase):
Expand Down
17 changes: 12 additions & 5 deletions opentelemetry-api/tests/propagators/test_global_httptextformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@
from opentelemetry.trace import get_current_span, set_span_in_context


def get_as_list(
dict_object: typing.Dict[str, typing.List[str]], key: str
) -> typing.List[str]:
value = dict_object.get(key)
return value if value is not None else []
class Getter:
@staticmethod
def get(dict_object, key):
value = dict_object.get(key)
return value if value is not None else []

@staticmethod
def keys(dict_object):
return dict_object.keys()


get_as_list = Getter()


class TestDefaultGlobalPropagator(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@
FORMAT = tracecontext.TraceContextTextMapPropagator()


def get_as_list(
dict_object: typing.Dict[str, typing.List[str]], key: str
) -> typing.List[str]:
value = dict_object.get(key)
return value if value is not None else []
class Getter:
@staticmethod
def get(dict_object, key):
value = dict_object.get(key)
return value if value is not None else []

@staticmethod
def keys(dict_object):
return dict_object.keys()


get_as_list = Getter()


class TestTraceContextFormat(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class B3Format(TextMapPropagator):

def extract(
self,
get_from_carrier: Getter[TextMapPropagatorT],
get_from_carrier: Getter,
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand All @@ -53,7 +53,7 @@ def extract(
flags = None

single_header = _extract_first_element(
get_from_carrier(carrier, self.SINGLE_HEADER_KEY)
get_from_carrier.get(carrier, self.SINGLE_HEADER_KEY)
)
if single_header:
# The b3 spec calls for the sampling state to be
Expand All @@ -75,25 +75,25 @@ def extract(
else:
trace_id = (
_extract_first_element(
get_from_carrier(carrier, self.TRACE_ID_KEY)
get_from_carrier.get(carrier, self.TRACE_ID_KEY)
)
or trace_id
)
span_id = (
_extract_first_element(
get_from_carrier(carrier, self.SPAN_ID_KEY)
get_from_carrier.get(carrier, self.SPAN_ID_KEY)
)
or span_id
)
sampled = (
_extract_first_element(
get_from_carrier(carrier, self.SAMPLED_KEY)
get_from_carrier.get(carrier, self.SAMPLED_KEY)
)
or sampled
)
flags = (
_extract_first_element(
get_from_carrier(carrier, self.FLAGS_KEY)
get_from_carrier.get(carrier, self.FLAGS_KEY)
)
or flags
)
Expand Down
20 changes: 13 additions & 7 deletions opentelemetry-sdk/tests/trace/propagation/test_b3_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,18 @@
FORMAT = b3_format.B3Format()


def get_as_list(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []
class Getter:
@staticmethod
def get(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []

@staticmethod
def keys(dict_object):
return dict_object.keys()


get_as_list = Getter()


def get_child_parent_new_carrier(old_carrier):
Expand Down Expand Up @@ -321,11 +330,8 @@ def test_inject_empty_context():
def test_default_span():
"""Make sure propagator does not crash when working with DefaultSpan"""

def getter(carrier, key):
return carrier.get(key, None)

def setter(carrier, key, value):
carrier[key] = value

ctx = FORMAT.extract(getter, {})
ctx = FORMAT.extract(get_as_list, {})
FORMAT.inject(setter, {}, ctx)
8 changes: 4 additions & 4 deletions tests/util/src/opentelemetry/test/mock_textmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class NOOPTextMapPropagator(TextMapPropagator):

def extract(
self,
get_from_carrier: Getter[TextMapPropagatorT],
get_from_carrier: Getter,
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand All @@ -56,12 +56,12 @@ class MockTextMapPropagator(TextMapPropagator):

def extract(
self,
get_from_carrier: Getter[TextMapPropagatorT],
get_from_carrier: Getter,
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
trace_id_list = get_from_carrier(carrier, self.TRACE_ID_KEY)
span_id_list = get_from_carrier(carrier, self.SPAN_ID_KEY)
trace_id_list = get_from_carrier.get(carrier, self.TRACE_ID_KEY)
span_id_list = get_from_carrier.get(carrier, self.SPAN_ID_KEY)

if not trace_id_list or not span_id_list:
return trace.set_span_in_context(trace.INVALID_SPAN)
Expand Down

0 comments on commit 4101d81

Please sign in to comment.