Skip to content

Commit

Permalink
Return none for Getter if key does not exist (open-telemetry#1449)
Browse files Browse the repository at this point in the history
  • Loading branch information
lzchen authored and Alex Boten committed Dec 8, 2020
1 parent 7500f73 commit 444161d
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _format_baggage(baggage_entries: typing.Mapping[str, object]) -> str:


def _extract_first_element(
items: typing.Iterable[textmap.TextMapPropagatorT],
items: typing.Optional[typing.Iterable[textmap.TextMapPropagatorT]],
) -> typing.Optional[textmap.TextMapPropagatorT]:
if items is None:
return None
Expand Down
14 changes: 9 additions & 5 deletions opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class Getter(typing.Generic[TextMapPropagatorT]):
"""

def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
def get(
self, carrier: TextMapPropagatorT, key: str
) -> typing.Optional[typing.List[str]]:
"""Function that can retrieve zero
or more values from the carrier. In the case that
the value does not exist, returns an empty list.
Expand All @@ -38,8 +40,8 @@ def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
carrier: An object which contains values that are used to
construct a Context.
key: key of a field in carrier.
Returns: first value of the propagation key or an empty list if the
key doesn't exist.
Returns: first value of the propagation key or None if the key doesn't
exist.
"""
raise NotImplementedError()

Expand All @@ -58,8 +60,10 @@ def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]:
class DictGetter(Getter[typing.Dict[str, CarrierValT]]):
def get(
self, carrier: typing.Dict[str, CarrierValT], key: str
) -> typing.List[str]:
val = carrier.get(key, [])
) -> typing.Optional[typing.List[str]]:
val = carrier.get(key, None)
if val is None:
return None
if isinstance(val, typing.Iterable) and not isinstance(val, str):
return list(val)
return [val]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def extract(
return trace.set_span_in_context(trace.INVALID_SPAN, context)

tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME)
tracestate = _parse_tracestate(tracestate_headers)
if tracestate_headers is None:
tracestate = None
else:
tracestate = _parse_tracestate(tracestate_headers)

span_context = trace.SpanContext(
trace_id=int(trace_id, 16),
Expand Down
4 changes: 2 additions & 2 deletions opentelemetry-api/src/opentelemetry/trace/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def __new__(
trace_id: int,
span_id: int,
is_remote: bool,
trace_flags: "TraceFlags" = DEFAULT_TRACE_OPTIONS,
trace_state: "TraceState" = DEFAULT_TRACE_STATE,
trace_flags: typing.Optional["TraceFlags"] = DEFAULT_TRACE_OPTIONS,
trace_state: typing.Optional["TraceState"] = DEFAULT_TRACE_STATE,
) -> "SpanContext":
if trace_flags is None:
trace_flags = DEFAULT_TRACE_OPTIONS
Expand Down
42 changes: 42 additions & 0 deletions opentelemetry-api/tests/trace/propagation/test_textmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from opentelemetry.trace.propagation.textmap import DictGetter


class TestDictGetter(unittest.TestCase):
def test_get_none(self):
getter = DictGetter()
carrier = {}
val = getter.get(carrier, "test")
self.assertIsNone(val)

def test_get_str(self):
getter = DictGetter()
carrier = {"test": "val"}
val = getter.get(carrier, "test")
self.assertEqual(val, ["val"])

def test_get_iter(self):
getter = DictGetter()
carrier = {"test": ["val"]}
val = getter.get(carrier, "test")
self.assertEqual(val, ["val"])

def test_keys(self):
getter = DictGetter()
keys = getter.keys({"test": "val"})
self.assertEqual(keys, ["test"])

0 comments on commit 444161d

Please sign in to comment.