Skip to content

Commit

Permalink
Fix TraceState to adhere to specs (#1502)
Browse files Browse the repository at this point in the history
  • Loading branch information
srikanthccv authored Jan 20, 2021
1 parent c750109 commit 1d39f7f
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 96 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ env:
# Otherwise, set variable to the commit of your branch on
# opentelemetry-python-contrib which is compatible with these Core repo
# changes.
CONTRIB_REPO_SHA: 32cac7a9ff6c831aa0e9514bb38c430fce819141
CONTRIB_REPO_SHA: 1e319dbaf21df7573f15f35773b8272579dd1030

jobs:
build:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#1535](https://github.com/open-telemetry/opentelemetry-python/pull/1535))
- `opentelemetry-sdk` Remove rate property setter from TraceIdRatioBasedSampler
([#1536](https://github.com/open-telemetry/opentelemetry-python/pull/1536))
- Fix TraceState to adhere to specs
([#1502](https://github.com/open-telemetry/opentelemetry-python/pull/1502))

### Removed
- `opentelemetry-api` Remove ThreadLocalRuntimeContext since python3.4 is not supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_translate_to_collector(self):
span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=trace_api.TraceState({"testKey": "testValue"}),
trace_state=trace_api.TraceState([("testkey", "testvalue")]),
)
parent_span_context = trace_api.SpanContext(
trace_id, parent_id, is_remote=False
Expand Down Expand Up @@ -200,9 +200,9 @@ def test_translate_to_collector(self):
)
self.assertEqual(output_spans[0].status.message, "test description")
self.assertEqual(len(output_spans[0].tracestate.entries), 1)
self.assertEqual(output_spans[0].tracestate.entries[0].key, "testKey")
self.assertEqual(output_spans[0].tracestate.entries[0].key, "testkey")
self.assertEqual(
output_spans[0].tracestate.entries[0].value, "testValue"
output_spans[0].tracestate.entries[0].value, "testvalue"
)

self.assertEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,7 @@
import opentelemetry.trace as trace
from opentelemetry.context.context import Context
from opentelemetry.trace.propagation import textmap

# Keys and values are strings of up to 256 printable US-ASCII characters.
# Implementations should conform to the `W3C Trace Context - Tracestate`_
# spec, which describes additional restrictions on valid field values.
#
# .. _W3C Trace Context - Tracestate:
# https://www.w3.org/TR/trace-context/#tracestate-field

_KEY_WITHOUT_VENDOR_FORMAT = r"[a-z][_0-9a-z\-\*\/]{0,255}"
_KEY_WITH_VENDOR_FORMAT = (
r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}"
)

_KEY_FORMAT = _KEY_WITHOUT_VENDOR_FORMAT + "|" + _KEY_WITH_VENDOR_FORMAT
_VALUE_FORMAT = (
r"[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]"
)

_DELIMITER_FORMAT = "[ \t]*,[ \t]*"
_MEMBER_FORMAT = "({})(=)({})[ \t]*".format(_KEY_FORMAT, _VALUE_FORMAT)

_DELIMITER_FORMAT_RE = re.compile(_DELIMITER_FORMAT)
_MEMBER_FORMAT_RE = re.compile(_MEMBER_FORMAT)

_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS = 32
from opentelemetry.trace.span import TraceState


class TraceContextTextMapPropagator(textmap.TextMapPropagator):
Expand Down Expand Up @@ -94,7 +70,7 @@ def extract(
if tracestate_headers is None:
tracestate = None
else:
tracestate = _parse_tracestate(tracestate_headers)
tracestate = TraceState.from_header(tracestate_headers)

span_context = trace.SpanContext(
trace_id=int(trace_id, 16),
Expand Down Expand Up @@ -130,7 +106,7 @@ def inject(
carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string
)
if span_context.trace_state:
tracestate_string = _format_tracestate(span_context.trace_state)
tracestate_string = span_context.trace_state.to_header()
set_in_carrier(
carrier, self._TRACESTATE_HEADER_NAME, tracestate_string
)
Expand All @@ -143,57 +119,3 @@ def fields(self) -> typing.Set[str]:
`opentelemetry.trace.propagation.textmap.TextMapPropagator.fields`
"""
return {self._TRACEPARENT_HEADER_NAME, self._TRACESTATE_HEADER_NAME}


def _parse_tracestate(header_list: typing.List[str]) -> trace.TraceState:
    """Build a TraceState from one or more w3c tracestate header values.

    Args:
        header_list: the tracestate header values to parse.

    Returns:
        A TraceState holding the key/value members found in the headers.
        An empty TraceState is returned instead when any member does not
        match the member format, when a key appears more than once, or
        when more than ``_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS`` members
        are present.
    """
    parsed = trace.TraceState()
    member_total = 0
    for header_value in header_list:
        for entry in re.split(_DELIMITER_FORMAT_RE, header_value):
            # An empty member is legal; there is nothing to record for it.
            if not entry:
                continue
            matched = _MEMBER_FORMAT_RE.fullmatch(entry)
            if not matched:
                # TODO: log this?
                return trace.TraceState()
            key, _eq, value = matched.groups()
            # A repeated key makes the whole header illegal, so everything
            # collected so far is thrown away.
            if key in parsed:  # pylint:disable=E1135
                return trace.TraceState()
            # typing.Dict's item assignment is not recognized by pylint:
            # https://github.com/PyCQA/pylint/issues/2420
            parsed[key] = value  # pylint:disable=E1137
            member_total += 1
            if member_total > _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS:
                return trace.TraceState()
    return parsed


def _format_tracestate(tracestate: trace.TraceState) -> str:
    """Serialize a TraceState into a w3c tracestate header value.

    Args:
        tracestate: the TraceState to write out.

    Returns:
        The ``key=value`` members of ``tracestate`` joined with commas,
        in the iteration order of ``tracestate.items()``.
    """
    members = [key + "=" + value for key, value in tracestate.items()]
    return ",".join(members)
186 changes: 185 additions & 1 deletion opentelemetry-api/src/opentelemetry/trace/span.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import abc
import logging
import re
import types as python_types
import typing
from collections import OrderedDict

from opentelemetry.trace.status import Status
from opentelemetry.util import types
from opentelemetry.util.tracestate import (
_DELIMITER_PATTERN,
_MEMBER_PATTERN,
_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS,
_is_valid_pair,
)

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,7 +143,7 @@ def sampled(self) -> bool:
DEFAULT_TRACE_OPTIONS = TraceFlags.get_default()


class TraceState(typing.Dict[str, str]):
class TraceState(typing.Mapping[str, str]):
"""A list of key-value pairs representing vendor-specific trace info.
Keys and values are strings of up to 256 printable US-ASCII characters.
Expand All @@ -146,10 +154,186 @@ class TraceState(typing.Dict[str, str]):
https://www.w3.org/TR/trace-context/#tracestate-field
"""

def __init__(
    self,
    entries: typing.Optional[
        typing.Sequence[typing.Tuple[str, str]]
    ] = None,
) -> None:
    """Create a TraceState from an optional sequence of key/value pairs.

    Args:
        entries: ``(key, value)`` pairs to store, in order. When the
            sequence holds more than
            ``_TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS`` pairs, a warning is
            logged and the state is left empty. Pairs rejected by
            ``_is_valid_pair`` and pairs whose key was already stored are
            skipped with a warning.
    """
    self._dict = OrderedDict()  # type: OrderedDict[str, str]
    if entries is None:
        return
    if len(entries) > _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS:
        _logger.warning(
            "There can't be more than %s key/value pairs.",
            _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS,
        )
        return

    for key, value in entries:
        # Validity is checked before duplication, so an invalid pair is
        # always reported as invalid even if its key is already stored.
        if not _is_valid_pair(key, value):
            _logger.warning(
                "Invalid key/value pair (%s, %s) found.", key, value
            )
            continue
        if key in self._dict:
            _logger.warning("Duplicate key: %s found.", key)
            continue
        self._dict[key] = value

def __getitem__(self, key: str) -> str:
    """Return the value stored for ``key``.

    Args:
        key: the tracestate key to look up.

    Returns:
        The value associated with ``key``.

    Raises:
        KeyError: if ``key`` is not present.
    """
    # The Mapping base class derives ``__contains__`` and ``get`` from this
    # method and relies on KeyError to detect a missing key. Returning
    # None instead would make ``key in state`` true for every key and
    # would make ``state.get(key, default)`` ignore its default.
    return self._dict[key]

def __iter__(self) -> typing.Iterator[str]:
    """Iterate over the stored keys in their stored order."""
    key_iterator = iter(self._dict)
    return key_iterator

def __len__(self) -> int:
    """Return the number of key/value pairs currently stored."""
    pair_count = len(self._dict)
    return pair_count

def __repr__(self) -> str:
    """Return the string form of a list of ``{key=..., value=...}`` items.

    One item is produced per stored pair, in stored order.
    """
    rendered = []
    for key, value in self._dict.items():
        rendered.append("{key=%s, value=%s}" % (key, value))
    return str(rendered)

def add(self, key: str, value: str) -> "TraceState":
    """Adds a key-value pair to tracestate. The provided pair should
    adhere to w3c tracestate identifiers format.

    Args:
        key: A valid tracestate key to add
        value: A valid tracestate value to add

    Returns:
        A new TraceState with the pair placed in front of the existing
        entries. If the pair is invalid, the key already exists, or the
        tracestate already holds the maximum number of pairs, a warning
        is logged and the same tracestate is returned unchanged.
    """
    if not _is_valid_pair(key, value):
        _logger.warning(
            "Invalid key/value pair (%s, %s) found.", key, value
        )
        return self
    # The limit is reported through the shared constant so this message
    # stays in step with the one emitted by the constructor.
    if len(self) >= _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS:
        _logger.warning(
            "There can't be more than %s key/value pairs.",
            _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS,
        )
        return self
    # Duplicate entries are not allowed
    if key in self._dict:
        _logger.warning("The provided key %s already exists.", key)
        return self
    # The new pair goes first; existing pairs keep their relative order.
    new_state = [(key, value)] + list(self._dict.items())
    return TraceState(new_state)

def update(self, key: str, value: str) -> "TraceState":
    """Updates a key-value pair in tracestate. The provided pair should
    adhere to w3c tracestate identifiers format.

    Args:
        key: A valid tracestate key to update
        value: A valid tracestate value to update for key

    Returns:
        A new TraceState in which ``key`` maps to ``value`` and is the
        first entry. If the pair is invalid, or ``key`` is not yet present
        and the tracestate already holds the maximum number of pairs, a
        warning is logged and the same tracestate is returned unchanged.
    """
    if not _is_valid_pair(key, value):
        _logger.warning(
            "Invalid key/value pair (%s, %s) found.", key, value
        )
        return self
    # Assigning an unknown key grows the state by one. Without this guard
    # a full state would be handed to the constructor with one pair too
    # many, and the constructor answers that by returning an empty state,
    # silently dropping every existing entry.
    if (
        key not in self._dict
        and len(self._dict) >= _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS
    ):
        _logger.warning(
            "There can't be more than %s key/value pairs.",
            _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS,
        )
        return self
    prev_state = self._dict.copy()
    prev_state[key] = value
    # The touched entry is moved to the front of the list.
    prev_state.move_to_end(key, last=False)
    return TraceState(list(prev_state.items()))

def delete(self, key: str) -> "TraceState":
    """Removes the entry stored under ``key`` from tracestate.

    Args:
        key: The tracestate key whose key-value pair should be removed

    Returns:
        A new TraceState holding every other entry in its existing order.
        If ``key`` is not present, a warning is logged and the same
        tracestate is returned.
    """
    if key in self._dict:
        remaining = [
            pair for pair in self._dict.items() if pair[0] != key
        ]
        return TraceState(remaining)
    _logger.warning("The provided key %s doesn't exist.", key)
    return self

def to_header(self) -> str:
    """Serializes this TraceState into a w3c tracestate header value.

    Returns:
        The stored pairs rendered as ``key=value`` and joined with
        commas, in stored order. An empty state yields an empty string.
    """
    members = [key + "=" + value for key, value in self._dict.items()]
    return ",".join(members)

@classmethod
def from_header(cls, header_list: typing.List[str]) -> "TraceState":
    """Builds a TraceState from one or more w3c tracestate header values.

    Args:
        header_list: one or more w3c tracestate headers.

    Returns:
        A TraceState built from the members found in the headers, in the
        order they were encountered. An empty TraceState is returned when
        a member does not match the member pattern (a warning is logged)
        or when a key appears more than once. The collected pairs are
        passed to the constructor, which applies its own checks,
        including the limit on the number of pairs.
    """
    parsed = OrderedDict()
    for header_value in header_list:
        for entry in re.split(_DELIMITER_PATTERN, header_value):
            # An empty member is legal; there is nothing to record for it.
            if not entry:
                continue
            matched = _MEMBER_PATTERN.fullmatch(entry)
            if not matched:
                _logger.warning(
                    "Member doesn't match the w3c identifiers format %s",
                    entry,
                )
                return cls()
            key, _eq, value = matched.groups()
            # A repeated key makes the whole header illegal.
            if key in parsed:
                return cls()
            parsed[key] = value
    return cls(list(parsed.items()))

@classmethod
def get_default(cls) -> "TraceState":
    """Returns a new instance created with no entries."""
    default_state = cls()
    return default_state

def keys(self) -> typing.KeysView[str]:
    """Returns a view of the stored keys, in stored order."""
    key_view = self._dict.keys()
    return key_view

def items(self) -> typing.ItemsView[str, str]:
    """Returns a view of the stored ``(key, value)`` pairs, in stored order."""
    pair_view = self._dict.items()
    return pair_view

def values(self) -> typing.ValuesView[str]:
    """Returns a view of the stored values, in stored order."""
    value_view = self._dict.values()
    return value_view


DEFAULT_TRACE_STATE = TraceState.get_default()

Expand Down
Loading

0 comments on commit 1d39f7f

Please sign in to comment.