diff --git a/src/propagation.cpp b/src/propagation.cpp index 471e7d11..7dfa5d0f 100644 --- a/src/propagation.cpp +++ b/src/propagation.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "sample.h" @@ -119,6 +120,24 @@ std::unique_ptr>> enforce_tag_pres return nullptr; } +// Interpret the specified `text` as a non-negative integer formatted in the +// specified `base` (e.g. base 10 for decimal, base 16 for hexadecimal), +// possibly surrounded by whitespace, and return the integer. Throw an +// exception derived from `std::logic_error` if an error occurs. +uint64_t parse_uint64(const std::string &text, int base) { + std::size_t end_index; + const uint64_t result = std::stoull(text, &end_index, base); + + // If any of the remaining characters are not whitespace, then `text` + // contains something other than a base-`base` integer. + if (std::any_of(text.begin() + end_index, text.end(), + [](unsigned char ch) { return !std::isspace(ch); })) { + throw std::invalid_argument("integer text field has a trailing non-whitespace character"); + } + + return result; +} + } // namespace std::vector getPropagationHeaderNames(const std::set &styles, @@ -416,8 +435,8 @@ ot::expected> SpanContext::deserialize( std::string trace_id_str = j[json_trace_id_key]; std::string parent_id_str = j[json_parent_id_key]; - trace_id = std::stoull(trace_id_str); - parent_id = std::stoull(parent_id_str); + trace_id = parse_uint64(trace_id_str, 10); + parent_id = parse_uint64(parent_id_str, 10); if (j.find(json_sampling_priority_key) != j.end()) { sampling_priority = asSamplingPriority(j[json_sampling_priority_key]); @@ -483,10 +502,10 @@ ot::expected> SpanContext::deserialize( reader.ForeachKey([&](ot::string_view key, ot::string_view value) -> ot::expected { try { if (equals_ignore_case(key, headers_impl.trace_id_header)) { - trace_id = std::stoull(value, nullptr, headers_impl.base); + trace_id = parse_uint64(value, headers_impl.base); trace_id_set = true; } else if (equals_ignore_case(key, headers_impl.span_id_header)) { - parent_id = std::stoull(value, nullptr, headers_impl.base); + parent_id = parse_uint64(value, headers_impl.base); parent_id_set = true; } else if (equals_ignore_case(key, headers_impl.sampling_priority_header)) { sampling_priority = asSamplingPriority(std::stoi(value)); diff --git a/test/propagation_test.cpp b/test/propagation_test.cpp index 737deb3b..f5b35ce6 100644 --- a/test/propagation_test.cpp +++ b/test/propagation_test.cpp @@ -97,6 +97,33 @@ TEST_CASE("SpanContext") { REQUIRE(*received_context2 == *received_context); } } + + SECTION("even with leading whitespace in integer fields") { + carrier.Set("x-datadog-trace-id", " 123"); + auto sc = SpanContext::deserialize(logger, carrier, propagation_styles); + REQUIRE(sc); + auto received_context = dynamic_cast(sc->get()); + REQUIRE(received_context); + REQUIRE(received_context->traceId() == 123); + } + + SECTION("even with trailing whitespace in integer fields") { + carrier.Set("x-datadog-trace-id", "123 "); + auto sc = SpanContext::deserialize(logger, carrier, propagation_styles); + REQUIRE(sc); + auto received_context = dynamic_cast(sc->get()); + REQUIRE(received_context); + REQUIRE(received_context->traceId() == 123); + } + + SECTION("even with whitespace surrounding integer fields") { + carrier.Set("x-datadog-trace-id", " 123 "); + auto sc = SpanContext::deserialize(logger, carrier, propagation_styles); + REQUIRE(sc); + auto received_context = dynamic_cast(sc->get()); + REQUIRE(received_context); + REQUIRE(received_context->traceId() == 123); + } } SECTION("can access ids") { REQUIRE(context.ToTraceID() == "123"); @@ -196,6 +223,12 @@ TEST_CASE("deserialize fails") { REQUIRE(!err); REQUIRE(err.error() == ot::span_context_corrupted_error); } + + SECTION("when decimal integer IDs start decimal but have hex characters") { + carrier.Set(test_case.x_datadog_trace_id, "123deadbeef"); + auto err = SpanContext::deserialize(logger, carrier, test_case.styles); + REQUIRE(!err); + } } TEST_CASE("SamplingPriority values are clamped apropriately for b3") {