diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg b/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg index 4aa890ca53..6cfbe7cb68 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg +++ b/sdk-extension/opentelemetry-sdk-extension-aws/setup.cfg @@ -39,7 +39,7 @@ package_dir= =src packages=find_namespace: install_requires = - opentelemetry-api == 0.15.b0 + opentelemetry-api == 0.16.dev0 [options.entry_points] opentelemetry_propagator = @@ -47,7 +47,7 @@ opentelemetry_propagator = [options.extras_require] test = - opentelemetry-test == 0.15.b0 + opentelemetry-test == 0.16.dev0 [options.packages.find] where = src diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py index 10c02a6e07..d10d11f5a2 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/propagation/aws_xray_format.py @@ -55,10 +55,6 @@ class AwsXRayFormat(TextMapPropagator): IS_SAMPLED = "1" NOT_SAMPLED = "0" - # pylint: disable=too-many-locals - # pylint: disable=too-many-return-statements - # pylint: disable=too-many-branches - # pylint: disable=too-many-statements def extract( self, getter: Getter[TextMapPropagatorT], @@ -79,53 +75,66 @@ def extract( trace.INVALID_SPAN, context=context ) + trace_id, span_id, sampled, err = self.extract_span_properties( + trace_header + ) + + if err is not None: + return trace.set_span_in_context( + trace.INVALID_SPAN, context=context + ) + + options = 0 + if sampled: + options |= trace.TraceFlags.SAMPLED + + span_context = trace.SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=True, + trace_flags=trace.TraceFlags(options), + trace_state=trace.TraceState(), + ) + + if not span_context.is_valid: + _logger.error( + "Invalid Span Extracted. Insertting INVALID span into provided context." + ) + return trace.set_span_in_context( + trace.INVALID_SPAN, context=context + ) + + return trace.set_span_in_context( + trace.DefaultSpan(span_context), context=context + ) + + def extract_span_properties(self, trace_header): trace_id = trace.INVALID_TRACE_ID span_id = trace.INVALID_SPAN_ID sampled = False - next_kv_pair_start = 0 + extract_err = None - while next_kv_pair_start < len(trace_header): - try: - kv_pair_delimiter_index = trace_header.index( - self.KV_PAIR_DELIMITER, next_kv_pair_start - ) - kv_pair_subset = trace_header[ - next_kv_pair_start:kv_pair_delimiter_index - ] - next_kv_pair_start = kv_pair_delimiter_index + 1 - except ValueError: - kv_pair_subset = trace_header[next_kv_pair_start:] - next_kv_pair_start = len(trace_header) - - stripped_kv_pair = kv_pair_subset.strip() + for kv_pair_str in trace_header.split(self.KV_PAIR_DELIMITER): + if extract_err: + break try: - key_and_value_delimiter_index = stripped_kv_pair.index( + key_str, value_str = kv_pair_str.split( self.KEY_AND_VALUE_DELIMITER ) + key, value = key_str.strip(), value_str.strip() except ValueError: _logger.error( ( "Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context.", - kv_pair_subset, + kv_pair_str, ) ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) + return trace_id, span_id, sampled, extract_err - value = stripped_kv_pair[key_and_value_delimiter_index + 1 :] - - if stripped_kv_pair.startswith(self.TRACE_ID_KEY): - if ( - len(value) != self.TRACE_ID_LENGTH - or not value.startswith(self.TRACE_ID_VERSION) - or value[self.TRACE_ID_DELIMITER_INDEX_1] - != self.TRACE_ID_DELIMITER - or value[self.TRACE_ID_DELIMITER_INDEX_2] - != self.TRACE_ID_DELIMITER - ): + if key == self.TRACE_ID_KEY: + if not self.validate_trace_id(value): _logger.error( ( "Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", @@ -133,19 +142,11 @@ def extract( trace_header, ) ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) + extract_err = True + break - timestamp_subset = value[ - self.TRACE_ID_DELIMITER_INDEX_1 - + 1 : self.TRACE_ID_DELIMITER_INDEX_2 - ] - unique_id_subset = value[ - self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH - ] try: - trace_id = int(timestamp_subset + unique_id_subset, 16) + trace_id = self.parse_trace_id(value) except ValueError: _logger.error( ( @@ -154,11 +155,9 @@ def extract( trace_header, ) ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - elif stripped_kv_pair.startswith(self.PARENT_ID_KEY): - if len(value) != self.PARENT_ID_LENGTH: + extract_err = True + elif key == self.PARENT_ID_KEY: + if not self.validate_span_id(value): _logger.error( ( "Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", @@ -166,12 +165,11 @@ def extract( trace_header, ) ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) + extract_err = True + break try: - span_id = int(value, 16) + span_id = AwsXRayFormat.parse_span_id(value) except ValueError: _logger.error( ( @@ -180,25 +178,9 @@ def extract( trace_header, ) ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) - elif stripped_kv_pair.startswith(self.SAMPLED_FLAG_KEY): - is_sampled_flag_valid = True - - if len(value) != self.SAMPLED_FLAG_LENGTH: - is_sampled_flag_valid = False - - if is_sampled_flag_valid: - sampled_flag = value[0] - if sampled_flag == self.IS_SAMPLED: - sampled = True - elif sampled_flag == self.NOT_SAMPLED: - sampled = False - else: - is_sampled_flag_valid = False - - if not is_sampled_flag_valid: + extract_err = True + elif key == self.SAMPLED_FLAG_KEY: + if not self.validate_sampled_flag(value): _logger.error( ( "Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context.", @@ -206,34 +188,51 @@ def extract( trace_header, ) ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) + extract_err = True + break - options = 0 - if sampled: - options |= trace.TraceFlags.SAMPLED + sampled = self.parse_sampled_flag(value) - span_context = trace.SpanContext( - trace_id=trace_id, - span_id=span_id, - is_remote=True, - trace_flags=trace.TraceFlags(options), - trace_state=trace.TraceState(), - ) + return trace_id, span_id, sampled, extract_err - if not span_context.is_valid: - _logger.error( - "Invalid Span Extracted. Insertting INVALID span into provided context." - ) - return trace.set_span_in_context( - trace.INVALID_SPAN, context=context - ) + def validate_trace_id(self, trace_id_str): + return ( + len(trace_id_str) == self.TRACE_ID_LENGTH + and trace_id_str.startswith(self.TRACE_ID_VERSION) + and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_1] + == self.TRACE_ID_DELIMITER + and trace_id_str[self.TRACE_ID_DELIMITER_INDEX_2] + == self.TRACE_ID_DELIMITER + ) - return trace.set_span_in_context( - trace.DefaultSpan(span_context), context=context + def parse_trace_id(self, trace_id_str): + timestamp_subset = trace_id_str[ + self.TRACE_ID_DELIMITER_INDEX_1 + + 1 : self.TRACE_ID_DELIMITER_INDEX_2 + ] + unique_id_subset = trace_id_str[ + self.TRACE_ID_DELIMITER_INDEX_2 + 1 : self.TRACE_ID_LENGTH + ] + return int(timestamp_subset + unique_id_subset, 16) + + def validate_span_id(self, span_id_str): + return len(span_id_str) == self.PARENT_ID_LENGTH + + @staticmethod + def parse_span_id(span_id_str): + return int(span_id_str, 16) + + def validate_sampled_flag(self, sampled_flag_str): + return len( + sampled_flag_str + ) == self.SAMPLED_FLAG_LENGTH and sampled_flag_str in ( + self.IS_SAMPLED, + self.NOT_SAMPLED, ) + def parse_sampled_flag(self, sampled_flag_str): + return sampled_flag_str[0] == self.IS_SAMPLED + def inject( self, set_in_carrier: Setter[TextMapPropagatorT], diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py index fc0946baba..0ea84ef217 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/propagation/test_aws_xray_format.py @@ -247,6 +247,41 @@ def test_extract_with_additional_fields(self): get_extracted_span_context(build_test_context()), ) + def test_extract_with_extra_whitespace(self): + default_xray_trace_header_dict = build_dict_with_xray_trace_header() + trace_header_components = default_xray_trace_header_dict[ + AwsXRayFormat.TRACE_HEADER_KEY + ].split(AwsXRayFormat.KV_PAIR_DELIMITER) + xray_trace_header_dict_with_extra_whitespace = CaseInsensitiveDict( + { + AwsXRayFormat.TRACE_HEADER_KEY: AwsXRayFormat.KV_PAIR_DELIMITER.join( + [ + AwsXRayFormat.KEY_AND_VALUE_DELIMITER.join( + [ + " " + key + " ", + " " + value + " ", + ] + ) + for kv_pair_str in trace_header_components + for key, value in [ + kv_pair_str.split( + AwsXRayFormat.KEY_AND_VALUE_DELIMITER + ) + ] + ] + ) + } + ) + actual_context_encompassing_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract( + AwsXRayPropagatorTest.carrier_getter, + xray_trace_header_dict_with_extra_whitespace, + ) + + self.assertEqual( + get_extracted_span_context(actual_context_encompassing_extracted), + get_extracted_span_context(build_test_context()), + ) + def test_extract_invalid_xray_trace_header(self): actual_context_encompassing_extracted = AwsXRayPropagatorTest.XRAY_PROPAGATOR.extract( AwsXRayPropagatorTest.carrier_getter,