diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 212db853d6682..4cdb5fe127dd5 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -184,10 +184,13 @@ set(ARROW_FLIGHT_SRCS "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc" client.cc client_cookie_middleware.cc + client_tracing_middleware.cc cookie_internal.cc + middleware.cc serialization_internal.cc server.cc server_auth.cc + server_tracing_middleware.cc transport.cc transport_server.cc # Bundle the gRPC impl with libarrow_flight diff --git a/cpp/src/arrow/flight/api.h b/cpp/src/arrow/flight/api.h index c58a9d48afa8e..61c475dc20473 100644 --- a/cpp/src/arrow/flight/api.h +++ b/cpp/src/arrow/flight/api.h @@ -20,8 +20,10 @@ #include "arrow/flight/client.h" #include "arrow/flight/client_auth.h" #include "arrow/flight/client_middleware.h" +#include "arrow/flight/client_tracing_middleware.h" #include "arrow/flight/middleware.h" #include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" #include "arrow/flight/server_middleware.h" +#include "arrow/flight/server_tracing_middleware.h" #include "arrow/flight/types.h" diff --git a/cpp/src/arrow/flight/client_tracing_middleware.cc b/cpp/src/arrow/flight/client_tracing_middleware.cc new file mode 100644 index 0000000000000..a45784bd31ecd --- /dev/null +++ b/cpp/src/arrow/flight/client_tracing_middleware.cc @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/client_tracing_middleware.h" + +#include +#include +#include +#include + +#include "arrow/util/tracing_internal.h" + +#ifdef ARROW_WITH_OPENTELEMETRY +#include +#include +#endif + +namespace arrow { +namespace flight { + +namespace { +#ifdef ARROW_WITH_OPENTELEMETRY +namespace otel = opentelemetry; +class FlightClientCarrier : public otel::context::propagation::TextMapCarrier { + public: + FlightClientCarrier() = default; + + otel::nostd::string_view Get(otel::nostd::string_view key) const noexcept override { + return ""; + } + + void Set(otel::nostd::string_view key, + otel::nostd::string_view value) noexcept override { + context_.emplace_back(key, value); + } + + std::vector> context_; +}; + +class TracingClientMiddleware : public ClientMiddleware { + public: + explicit TracingClientMiddleware(FlightClientCarrier carrier) + : carrier_(std::move(carrier)) {} + virtual ~TracingClientMiddleware() = default; + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + // The exact headers added are not arbitrary and are defined in + // the OpenTelemetry specification (see + // open-telemetry/opentelemetry-specification api-propagators.md) + for (const auto& pair : carrier_.context_) { + outgoing_headers->AddHeader(pair.first, pair.second); + } + } + void ReceivedHeaders(const CallHeaders&) override {} + void CallCompleted(const Status&) override {} + + private: + FlightClientCarrier carrier_; +}; + +class TracingClientMiddlewareFactory : public ClientMiddlewareFactory { + public: + virtual ~TracingClientMiddlewareFactory() = default; + void StartCall(const CallInfo& info, + std::unique_ptr* middleware) override { + FlightClientCarrier carrier; + auto context = otel::context::RuntimeContext::GetCurrent(); + auto propagator = + otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator(); + propagator->Inject(carrier, context); + *middleware = std::make_unique(std::move(carrier)); + } +}; +#else +class TracingClientMiddlewareFactory : public ClientMiddlewareFactory { + public: + virtual ~TracingClientMiddlewareFactory() = default; + void StartCall(const CallInfo&, std::unique_ptr*) override {} +}; +#endif +} // namespace + +std::shared_ptr MakeTracingClientMiddlewareFactory() { + return std::make_shared(); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/client_tracing_middleware.h b/cpp/src/arrow/flight/client_tracing_middleware.h new file mode 100644 index 0000000000000..3a8b665ed6c0f --- /dev/null +++ b/cpp/src/arrow/flight/client_tracing_middleware.h @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Middleware implementation for propagating OpenTelemetry spans. + +#pragma once + +#include + +#include "arrow/flight/client_middleware.h" + +namespace arrow { +namespace flight { + +/// \brief Returns a ClientMiddlewareFactory that handles sending OpenTelemetry spans. +ARROW_FLIGHT_EXPORT std::shared_ptr +MakeTracingClientMiddlewareFactory(); + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 217f910d640bc..db187013ec971 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -27,10 +27,13 @@ #include #include #include +#include #include #include #include "arrow/flight/api.h" +#include "arrow/flight/client_tracing_middleware.h" +#include "arrow/flight/server_tracing_middleware.h" #include "arrow/ipc/test_common.h" #include "arrow/status.h" #include "arrow/testing/generator.h" @@ -54,6 +57,26 @@ #include "arrow/flight/serialization_internal.h" #include "arrow/flight/test_definitions.h" #include "arrow/flight/test_util.h" +// OTel includes must come after any gRPC includes, and +// client_header_internal.h includes gRPC. See: +// https://github.com/open-telemetry/opentelemetry-cpp/blob/main/examples/otlp/README.md +// +// > gRPC internally uses a different version of Abseil than +// > OpenTelemetry C++ SDK. +// > ... +// > ...in case if you run into conflict between Abseil library and +// > OpenTelemetry C++ absl::variant implementation, please include +// > either grpcpp/grpcpp.h or +// > opentelemetry/exporters/otlp/otlp_grpc_exporter.h BEFORE any +// > other API headers. This approach efficiently avoids the conflict +// > between the two different versions of Abseil. +#include "arrow/util/tracing_internal.h" +#ifdef ARROW_WITH_OPENTELEMETRY +#include +#include +#include +#include +#endif namespace arrow { namespace flight { @@ -441,21 +464,21 @@ static thread_local std::string current_span_id = ""; // A server middleware that stores the current span ID, in an // emulation of OpenTracing style distributed tracing. -class TracingServerMiddleware : public ServerMiddleware { +class TracingTestServerMiddleware : public ServerMiddleware { public: - explicit TracingServerMiddleware(const std::string& current_span_id) + explicit TracingTestServerMiddleware(const std::string& current_span_id) : span_id(current_span_id) {} void SendingHeaders(AddCallHeaders* outgoing_headers) override {} void CallCompleted(const Status& status) override {} - std::string name() const override { return "TracingServerMiddleware"; } + std::string name() const override { return "TracingTestServerMiddleware"; } std::string span_id; }; -class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { +class TracingTestServerMiddlewareFactory : public ServerMiddlewareFactory { public: - TracingServerMiddlewareFactory() {} + TracingTestServerMiddlewareFactory() {} Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, std::shared_ptr* middleware) override { @@ -463,7 +486,7 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { incoming_headers.equal_range("x-tracing-span-id"); if (iter_pair.first != iter_pair.second) { const std::string_view& value = (*iter_pair.first).second; - *middleware = std::make_shared(std::string(value)); + *middleware = std::make_shared(std::string(value)); } return Status::OK(); } @@ -627,10 +650,10 @@ class ReportContextTestServer : public FlightServerBase { std::unique_ptr* result) override { std::shared_ptr buf; const ServerMiddleware* middleware = context.GetMiddleware("tracing"); - if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") { + if (middleware == nullptr || middleware->name() != "TracingTestServerMiddleware") { buf = Buffer::FromString(""); } else { - buf = Buffer::FromString(((const TracingServerMiddleware*)middleware)->span_id); + buf = Buffer::FromString(((const TracingTestServerMiddleware*)middleware)->span_id); } *result = std::make_unique(std::vector{Result{buf}}); return Status::OK(); @@ -658,10 +681,10 @@ class PropagatingTestServer : public FlightServerBase { Status DoAction(const ServerCallContext& context, const Action& action, std::unique_ptr* result) override { const ServerMiddleware* middleware = context.GetMiddleware("tracing"); - if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") { + if (middleware == nullptr || middleware->name() != "TracingTestServerMiddleware") { current_span_id = ""; } else { - current_span_id = ((const TracingServerMiddleware*)middleware)->span_id; + current_span_id = ((const TracingTestServerMiddleware*)middleware)->span_id; } return client_->DoAction(action).Value(result); @@ -728,7 +751,7 @@ class TestCountingServerMiddleware : public ::testing::Test { class TestPropagatingMiddleware : public ::testing::Test { public: void SetUp() { - server_middleware_ = std::make_shared(); + server_middleware_ = std::make_shared(); second_client_middleware_ = std::make_shared(); client_middleware_ = std::make_shared(); @@ -782,7 +805,7 @@ class TestPropagatingMiddleware : public ::testing::Test { std::unique_ptr client_; std::unique_ptr first_server_; std::unique_ptr second_server_; - std::shared_ptr server_middleware_; + std::shared_ptr server_middleware_; std::shared_ptr second_client_middleware_; std::shared_ptr client_middleware_; }; @@ -1528,5 +1551,139 @@ TEST_F(TestCancel, DoExchange) { ARROW_UNUSED(do_exchange_result.writer->Close()); } +class TracingTestServer : public FlightServerBase { + public: + Status DoAction(const ServerCallContext& call_context, const Action&, + std::unique_ptr* result) override { + std::vector results; + auto* middleware = + reinterpret_cast(call_context.GetMiddleware("tracing")); + if (!middleware) return Status::Invalid("Could not find middleware"); +#ifdef ARROW_WITH_OPENTELEMETRY + // Ensure the trace context is present (but the value is random so + // we cannot assert any particular value) + EXPECT_FALSE(middleware->GetTraceContext().empty()); + auto span = arrow::internal::tracing::GetTracer()->GetCurrentSpan(); + const auto context = span->GetContext(); + { + const auto& span_id = context.span_id(); + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(span_id.Id().size())); + std::memcpy(buffer->mutable_data(), span_id.Id().data(), span_id.Id().size()); + results.push_back({std::move(buffer)}); + } + { + const auto& trace_id = context.trace_id(); + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(trace_id.Id().size())); + std::memcpy(buffer->mutable_data(), trace_id.Id().data(), trace_id.Id().size()); + results.push_back({std::move(buffer)}); + } +#else + // Ensure the trace context is not present (as OpenTelemetry is not enabled) + EXPECT_TRUE(middleware->GetTraceContext().empty()); +#endif + *result = std::make_unique(std::move(results)); + return Status::OK(); + } +}; + +class TestTracing : public ::testing::Test { + public: + void SetUp() { +#ifdef ARROW_WITH_OPENTELEMETRY + // The default tracer always generates no-op spans which have no + // span/trace ID. Set up a different tracer. Note, this needs to + // be run before Arrow uses OTel as GetTracer() gets a tracer once + // and keeps it in a static. + std::vector> processors; + auto provider = + opentelemetry::nostd::shared_ptr( + new opentelemetry::sdk::trace::TracerProvider(std::move(processors))); + opentelemetry::trace::Provider::SetTracerProvider(std::move(provider)); + + opentelemetry::context::propagation::GlobalTextMapPropagator::SetGlobalPropagator( + opentelemetry::nostd::shared_ptr< + opentelemetry::context::propagation::TextMapPropagator>( + new opentelemetry::trace::propagation::HttpTraceContext())); +#endif + + ASSERT_OK(MakeServer( + &server_, &client_, + [](FlightServerOptions* options) { + options->middleware.emplace_back("tracing", + MakeTracingServerMiddlewareFactory()); + return Status::OK(); + }, + [](FlightClientOptions* options) { + options->middleware.push_back(MakeTracingClientMiddlewareFactory()); + return Status::OK(); + })); + } + void TearDown() { ASSERT_OK(server_->Shutdown()); } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +#ifdef ARROW_WITH_OPENTELEMETRY +// Must define it ourselves to avoid a linker error +constexpr size_t kSpanIdSize = opentelemetry::trace::SpanId::kSize; +constexpr size_t kTraceIdSize = opentelemetry::trace::TraceId::kSize; + +TEST_F(TestTracing, NoParentTrace) { + ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(Action{})); + + ASSERT_OK_AND_ASSIGN(auto result, results->Next()); + ASSERT_NE(result, nullptr); + ASSERT_NE(result->body, nullptr); + // Span ID should be a valid span ID, i.e. the server must have started a span + ASSERT_EQ(result->body->size(), kSpanIdSize); + opentelemetry::trace::SpanId span_id({result->body->data(), kSpanIdSize}); + ASSERT_TRUE(span_id.IsValid()); + + ASSERT_OK_AND_ASSIGN(result, results->Next()); + ASSERT_NE(result, nullptr); + ASSERT_NE(result->body, nullptr); + ASSERT_EQ(result->body->size(), kTraceIdSize); + opentelemetry::trace::TraceId trace_id({result->body->data(), kTraceIdSize}); + ASSERT_TRUE(trace_id.IsValid()); +} +TEST_F(TestTracing, WithParentTrace) { + auto* tracer = arrow::internal::tracing::GetTracer(); + auto span = tracer->StartSpan("test"); + auto scope = tracer->WithActiveSpan(span); + + auto span_context = span->GetContext(); + auto current_trace_id = span_context.trace_id().Id(); + + ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(Action{})); + + ASSERT_OK_AND_ASSIGN(auto result, results->Next()); + ASSERT_NE(result, nullptr); + ASSERT_NE(result->body, nullptr); + ASSERT_EQ(result->body->size(), kSpanIdSize); + opentelemetry::trace::SpanId span_id({result->body->data(), kSpanIdSize}); + ASSERT_TRUE(span_id.IsValid()); + + ASSERT_OK_AND_ASSIGN(result, results->Next()); + ASSERT_NE(result, nullptr); + ASSERT_NE(result->body, nullptr); + ASSERT_EQ(result->body->size(), kTraceIdSize); + opentelemetry::trace::TraceId trace_id({result->body->data(), kTraceIdSize}); + // The server span should have the same trace ID as the client span. + ASSERT_EQ(std::string_view(reinterpret_cast(trace_id.Id().data()), + trace_id.Id().size()), + std::string_view(reinterpret_cast(current_trace_id.data()), + current_trace_id.size())); +} +#else +TEST_F(TestTracing, NoOp) { + // The middleware should not cause any trouble when OTel is not enabled. + ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(Action{})); + ASSERT_OK_AND_ASSIGN(auto result, results->Next()); + ASSERT_EQ(result, nullptr); +} +#endif + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/middleware.cc b/cpp/src/arrow/flight/middleware.cc new file mode 100644 index 0000000000000..ffbcb6aad205c --- /dev/null +++ b/cpp/src/arrow/flight/middleware.cc @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/middleware.h" + +namespace arrow { +namespace flight { + +std::string ToString(FlightMethod method) { + // Technically, we can get this via Protobuf reflection, but in + // practice we'd have to hardcode the method names to look up the + // method descriptor... + switch (method) { + case FlightMethod::Handshake: + return "Handshake"; + case FlightMethod::ListFlights: + return "ListFlights"; + case FlightMethod::GetFlightInfo: + return "GetFlightInfo"; + case FlightMethod::GetSchema: + return "GetSchema"; + case FlightMethod::DoGet: + return "DoGet"; + case FlightMethod::DoPut: + return "DoPut"; + case FlightMethod::DoAction: + return "DoAction"; + case FlightMethod::ListActions: + return "ListActions"; + case FlightMethod::DoExchange: + return "DoExchange"; + case FlightMethod::Invalid: + default: + return "(unknown Flight method)"; + } +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h index b050e9cc6ed92..dc1ad24bc5c89 100644 --- a/cpp/src/arrow/flight/middleware.h +++ b/cpp/src/arrow/flight/middleware.h @@ -30,7 +30,6 @@ #include "arrow/status.h" namespace arrow { - namespace flight { /// \brief Headers sent from the client or server. @@ -66,6 +65,10 @@ enum class FlightMethod : char { DoExchange = 9, }; +/// \brief Get a human-readable name for a Flight method. +ARROW_FLIGHT_EXPORT +std::string ToString(FlightMethod method); + /// \brief Information about an instance of a Flight RPC. struct ARROW_FLIGHT_EXPORT CallInfo { public: @@ -74,5 +77,4 @@ struct ARROW_FLIGHT_EXPORT CallInfo { }; } // namespace flight - } // namespace arrow diff --git a/cpp/src/arrow/flight/server_tracing_middleware.cc b/cpp/src/arrow/flight/server_tracing_middleware.cc new file mode 100644 index 0000000000000..eac530efb8afa --- /dev/null +++ b/cpp/src/arrow/flight/server_tracing_middleware.cc @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/server_tracing_middleware.h" + +#include +#include +#include +#include + +#include "arrow/flight/transport/grpc/util_internal.h" +#include "arrow/util/tracing_internal.h" + +#ifdef ARROW_WITH_OPENTELEMETRY +#include +#include +#include +#include +#include +#endif + +namespace arrow { +namespace flight { + +#ifdef ARROW_WITH_OPENTELEMETRY +namespace otel = opentelemetry; +namespace { +class FlightServerCarrier : public otel::context::propagation::TextMapCarrier { + public: + explicit FlightServerCarrier(const CallHeaders& incoming_headers) + : incoming_headers_(incoming_headers) {} + + otel::nostd::string_view Get(otel::nostd::string_view key) const noexcept override { + std::string_view arrow_key(key.data(), key.size()); + auto it = incoming_headers_.find(arrow_key); + if (it == incoming_headers_.end()) return ""; + std::string_view result = it->second; + return {result.data(), result.size()}; + } + + void Set(otel::nostd::string_view, otel::nostd::string_view) noexcept override {} + + const CallHeaders& incoming_headers_; +}; +class KeyValueCarrier : public otel::context::propagation::TextMapCarrier { + public: + explicit KeyValueCarrier(std::vector* items) + : items_(items) {} + otel::nostd::string_view Get(otel::nostd::string_view key) const noexcept override { + return {}; + } + void Set(otel::nostd::string_view key, + otel::nostd::string_view value) noexcept override { + items_->push_back({std::string(key), std::string(value)}); + } + + private: + std::vector* items_; +}; +} // namespace + +class TracingServerMiddleware::Impl { + public: + Impl(otel::trace::Scope scope, otel::nostd::shared_ptr span) + : scope_(std::move(scope)), span_(std::move(span)) {} + void CallCompleted(const Status& status) { + if (!status.ok()) { + auto grpc_status = transport::grpc::ToGrpcStatus(status, /*ctx=*/nullptr); + span_->SetStatus(otel::trace::StatusCode::kError, status.ToString()); + span_->SetAttribute(OTEL_GET_TRACE_ATTR(AttrRpcGrpcStatusCode), + static_cast(grpc_status.error_code())); + } else { + span_->SetStatus(otel::trace::StatusCode::kOk, ""); + span_->SetAttribute(OTEL_GET_TRACE_ATTR(AttrRpcGrpcStatusCode), int32_t(0)); + } + span_->End(); + } + std::vector GetTraceContext() const { + std::vector result; + KeyValueCarrier carrier(&result); + auto context = otel::context::RuntimeContext::GetCurrent(); + otel::trace::propagation::HttpTraceContext propagator; + propagator.Inject(carrier, context); + return result; + } + + private: + otel::trace::Scope scope_; + otel::nostd::shared_ptr span_; +}; + +class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + virtual ~TracingServerMiddlewareFactory() = default; + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + constexpr char kRpcSystem[] = "grpc"; + constexpr char kServiceName[] = "arrow.flight.protocol.FlightService"; + + FlightServerCarrier carrier(incoming_headers); + auto context = otel::context::RuntimeContext::GetCurrent(); + auto propagator = + otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator(); + auto new_context = propagator->Extract(carrier, context); + + otel::trace::StartSpanOptions options; + options.kind = otel::trace::SpanKind::kServer; + options.parent = otel::trace::GetSpan(new_context)->GetContext(); + + auto* tracer = arrow::internal::tracing::GetTracer(); + auto method_name = ToString(info.method); + auto span = tracer->StartSpan( + method_name, + { + // Attributes from experimental trace semantic conventions spec + // https://github.com/open-telemetry/opentelemetry-specification/blob/main/semantic_conventions/trace/rpc.yaml + {OTEL_GET_TRACE_ATTR(AttrRpcSystem), kRpcSystem}, + {OTEL_GET_TRACE_ATTR(AttrRpcService), kServiceName}, + {OTEL_GET_TRACE_ATTR(AttrRpcMethod), method_name}, + }, + options); + auto scope = tracer->WithActiveSpan(span); + + std::unique_ptr impl( + new TracingServerMiddleware::Impl(std::move(scope), std::move(span))); + *middleware = std::shared_ptr( + new TracingServerMiddleware(std::move(impl))); + return Status::OK(); + } +}; +#else +class TracingServerMiddleware::Impl { + public: + void CallCompleted(const Status&) {} + std::vector GetTraceContext() const { return {}; } +}; +class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + virtual ~TracingServerMiddlewareFactory() = default; + Status StartCall(const CallInfo&, const CallHeaders&, + std::shared_ptr* middleware) override { + std::unique_ptr impl( + new TracingServerMiddleware::Impl()); + *middleware = std::shared_ptr( + new TracingServerMiddleware(std::move(impl))); + return Status::OK(); + } +}; +#endif + +TracingServerMiddleware::TracingServerMiddleware(std::unique_ptr impl) + : impl_(std::move(impl)) {} +TracingServerMiddleware::~TracingServerMiddleware() = default; +void TracingServerMiddleware::SendingHeaders(AddCallHeaders*) {} +void TracingServerMiddleware::CallCompleted(const Status& status) { + impl_->CallCompleted(status); +} +std::vector TracingServerMiddleware::GetTraceContext() + const { + return impl_->GetTraceContext(); +} +constexpr char const TracingServerMiddleware::kMiddlewareName[]; + +std::shared_ptr MakeTracingServerMiddlewareFactory() { + return std::make_shared(); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/server_tracing_middleware.h b/cpp/src/arrow/flight/server_tracing_middleware.h new file mode 100644 index 0000000000000..581c8354368cf --- /dev/null +++ b/cpp/src/arrow/flight/server_tracing_middleware.h @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Middleware implementation for propagating OpenTelemetry spans. + +#pragma once + +#include +#include +#include + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { + +/// \brief Returns a ServerMiddlewareFactory that handles receiving OpenTelemetry spans. +ARROW_FLIGHT_EXPORT std::shared_ptr +MakeTracingServerMiddlewareFactory(); + +/// \brief A server middleware that provides access to the +/// OpenTelemetry context, if present. +/// +/// Used to make the OpenTelemetry span available in Python. +class ARROW_FLIGHT_EXPORT TracingServerMiddleware : public ServerMiddleware { + public: + ~TracingServerMiddleware(); + + static constexpr char const kMiddlewareName[] = + "arrow::flight::TracingServerMiddleware"; + + std::string name() const override { return kMiddlewareName; } + void SendingHeaders(AddCallHeaders*) override; + void CallCompleted(const Status&) override; + + struct TraceKey { + std::string key; + std::string value; + }; + /// \brief Get the trace context. + std::vector GetTraceContext() const; + + private: + class Impl; + friend class TracingServerMiddlewareFactory; + + explicit TracingServerMiddleware(std::unique_ptr impl); + std::unique_ptr impl_; +}; + +} // namespace flight +} // namespace arrow diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 16e4aad5a00c5..b6c9177195a1c 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -1741,13 +1741,22 @@ cdef class ServerCallContext(_Weakrefable): CServerMiddleware* c_middleware = \ self.context.GetMiddleware(CPyServerMiddlewareName) CPyServerMiddleware* middleware + vector[CTracingServerMiddlewareTraceKey] c_trace_context + if c_middleware == NULL: + c_middleware = self.context.GetMiddleware(tobytes(key)) + if c_middleware == NULL: return None - if c_middleware.name() != CPyServerMiddlewareName: - return None - middleware = c_middleware - py_middleware = <_ServerMiddlewareWrapper> middleware.py_object() - return py_middleware.middleware.get(key) + elif c_middleware.name() == CPyServerMiddlewareName: + middleware = c_middleware + py_middleware = <_ServerMiddlewareWrapper> middleware.py_object() + return py_middleware.middleware.get(key) + elif c_middleware.name() == CTracingServerMiddlewareName: + c_trace_context = ( c_middleware + ).GetTraceContext() + trace_context = {pair.key: pair.value for pair in c_trace_context} + return TracingServerMiddleware(trace_context) + return None @staticmethod cdef ServerCallContext wrap(const CServerCallContext& context): @@ -2528,6 +2537,22 @@ cdef class ServerMiddlewareFactory(_Weakrefable): """ +cdef class TracingServerMiddlewareFactory(ServerMiddlewareFactory): + """A factory for tracing middleware instances. + + This enables OpenTelemetry support in Arrow (if Arrow was compiled + with OpenTelemetry support enabled). A new span will be started on + each RPC call. The TracingServerMiddleware instance can then be + retrieved within an RPC handler to get the propagated context, + which can be used to start a new span on the Python side. + + Because the Python/C++ OpenTelemetry libraries do not + interoperate, spans on the C++ side are not directly visible to + the Python side and vice versa. + + """ + + cdef class ServerMiddleware(_Weakrefable): """Server-side middleware for a call, instantiated per RPC. @@ -2574,6 +2599,13 @@ cdef class ServerMiddleware(_Weakrefable): c_instance[0].reset(new CPyServerMiddleware(py_middleware, vtable)) +class TracingServerMiddleware(ServerMiddleware): + __slots__ = ["trace_context"] + + def __init__(self, trace_context): + self.trace_context = trace_context + + cdef class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory): """Wrapper to bundle server middleware into a single C++ one.""" @@ -2739,7 +2771,27 @@ cdef class FlightServerBase(_Weakrefable): c_options.get().tls_certificates.push_back(c_cert) if middleware: - py_middleware = _ServerMiddlewareFactoryWrapper(middleware) + non_tracing_middleware = {} + enable_tracing = None + for key, factory in middleware.items(): + if isinstance(factory, TracingServerMiddlewareFactory): + if enable_tracing is not None: + raise ValueError( + "Can only provide " + "TracingServerMiddlewareFactory once") + if tobytes(key) == CPyServerMiddlewareName: + raise ValueError(f"Middleware key cannot be {key}") + enable_tracing = key + else: + non_tracing_middleware[key] = factory + + if enable_tracing: + c_middleware.first = tobytes(enable_tracing) + c_middleware.second = MakeTracingServerMiddlewareFactory() + c_options.get().middleware.push_back(c_middleware) + + py_middleware = _ServerMiddlewareFactoryWrapper( + non_tracing_middleware) c_middleware.first = CPyServerMiddlewareName c_middleware.second.reset(new CPyServerMiddlewareFactory( py_middleware, diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py index 0664ff2c99276..8f9fa6fa7c98e 100644 --- a/python/pyarrow/flight.py +++ b/python/pyarrow/flight.py @@ -60,4 +60,5 @@ ServerMiddleware, ServerMiddlewareFactory, Ticket, + TracingServerMiddlewareFactory, ) diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 6377459404c3e..3301c1b6360b2 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -22,8 +22,8 @@ from pyarrow.includes.libarrow cimport * cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: - cdef char* CPyServerMiddlewareName\ - " arrow::py::flight::kPyServerMiddlewareName" + cdef char* CTracingServerMiddlewareName\ + " arrow::flight::TracingServerMiddleware::kMiddlewareName" cdef cppclass CActionType" arrow::flight::ActionType": c_string type @@ -322,6 +322,20 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: " arrow::flight::ClientMiddlewareFactory": pass + cpdef cppclass CTracingServerMiddlewareTraceKey\ + " arrow::flight::TracingServerMiddleware::TraceKey": + CTracingServerMiddlewareTraceKey() + c_string key + c_string value + + cdef cppclass CTracingServerMiddleware\ + " arrow::flight::TracingServerMiddleware"(CServerMiddleware): + vector[CTracingServerMiddlewareTraceKey] GetTraceContext() + + cdef shared_ptr[CServerMiddlewareFactory] \ + MakeTracingServerMiddlewareFactory\ + " arrow::flight::MakeTracingServerMiddlewareFactory"() + cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions": CFlightServerOptions(const CLocation& location) CLocation location @@ -472,6 +486,9 @@ ctypedef CStatus cb_client_middleware_start_call( unique_ptr[CClientMiddleware]*) cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: + cdef char* CPyServerMiddlewareName\ + " arrow::py::flight::kPyServerMiddlewareName" + cdef cppclass PyFlightServerVtable: PyFlightServerVtable() function[cb_list_flights] list_flights diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 72d1fa5ec3359..69318a5535b67 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -2196,3 +2196,33 @@ def test_interpreter_shutdown(): See https://issues.apache.org/jira/browse/ARROW-16597. """ util.invoke_script("arrow_16597.py") + + +class TracingFlightServer(FlightServerBase): + """A server that echoes back trace context values.""" + + def do_action(self, context, action): + trace_context = context.get_middleware("tracing").trace_context + # Don't turn this method into a generator since then + # trace_context will be evaluated after we've exited the scope + # of the OTel span (and so the value we want won't be present) + return ((f"{key}: {value}").encode("utf-8") + for (key, value) in trace_context.items()) + + +def test_tracing(): + with TracingFlightServer(middleware={ + "tracing": flight.TracingServerMiddlewareFactory(), + }) as server, \ + FlightClient(('localhost', server.port)) as client: + # We can't tell if Arrow was built with OpenTelemetry support, + # so we can't count on any particular values being there; we + # can only ensure things don't blow up either way. + options = flight.FlightCallOptions(headers=[ + # Pretend we have an OTel implementation + (b"traceparent", b"00-000ff00f00f0ff000f0f00ff0f00fff0-" + b"000f0000f0f00000-00"), + (b"tracestate", b""), + ]) + for value in client.do_action((b"", b""), options=options): + pass