From d907f5fec864fe33083f2ba4c88b3ed4c9e636e3 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Mon, 1 May 2023 11:40:35 +0900 Subject: [PATCH] GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` --- cpp/src/arrow/flight/middleware.h | 8 +------- cpp/src/arrow/flight/server.h | 2 ++ .../flight/transport/grpc/grpc_server.cc | 19 +++++++++++-------- .../arrow/flight/transport/ucx/ucx_server.cc | 2 ++ cpp/src/arrow/flight/types.h | 6 ++++++ 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h index dc1ad24bc5c89..e936b9f02020c 100644 --- a/cpp/src/arrow/flight/middleware.h +++ b/cpp/src/arrow/flight/middleware.h @@ -20,23 +20,17 @@ #pragma once -#include #include #include #include #include -#include "arrow/flight/visibility.h" // IWYU pragma: keep +#include "arrow/flight/types.h" #include "arrow/status.h" namespace arrow { namespace flight { -/// \brief Headers sent from the client or server. -/// -/// Header values are ordered. -using CallHeaders = std::multimap; - /// \brief A write-only wrapper around headers for an RPC call. class ARROW_FLIGHT_EXPORT AddCallHeaders { public: diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 1d1b1a50f37b6..6fb8ab1213117 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -137,6 +137,8 @@ class ARROW_FLIGHT_EXPORT ServerCallContext { /// \brief Check if the current RPC has been cancelled (by the client, by /// a network error, etc.). virtual bool is_cancelled() const = 0; + /// \brief The headers sent by the client for this call. + virtual const CallHeaders& incoming_headers() const = 0; }; class ARROW_FLIGHT_EXPORT FlightServerOptions { diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index a643111e3b2b0..acf80462f1a92 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -117,11 +117,18 @@ class GrpcServerAuthSender : public ServerAuthSender { class GrpcServerCallContext : public ServerCallContext { explicit GrpcServerCallContext(::grpc::ServerContext* context) - : context_(context), peer_(context_->peer()) {} + : context_(context), peer_(context_->peer()) { + for (const auto& entry : context->client_metadata()) { + incoming_headers_.insert( + {std::string_view(entry.first.data(), entry.first.length()), + std::string_view(entry.second.data(), entry.second.length())}); + } + } const std::string& peer_identity() const override { return peer_identity_; } const std::string& peer() const override { return peer_; } bool is_cancelled() const override { return context_->IsCancelled(); } + const CallHeaders& incoming_headers() const override { return incoming_headers_; } // Helper method that runs interceptors given the result of an RPC, // then returns the final gRPC status to send to the client @@ -156,6 +163,7 @@ class GrpcServerCallContext : public ServerCallContext { std::string peer_identity_; std::vector> middleware_; std::unordered_map> middleware_map_; + CallHeaders incoming_headers_; }; class GrpcAddServerHeaders : public AddCallHeaders { @@ -310,17 +318,12 @@ class GrpcServiceHandler final : public FlightService::Service { GrpcServerCallContext& flight_context) { // Run server middleware const CallInfo info{method}; - CallHeaders incoming_headers; - for (const auto& entry : context->client_metadata()) { - incoming_headers.insert( - {std::string_view(entry.first.data(), entry.first.length()), - std::string_view(entry.second.data(), entry.second.length())}); - } GrpcAddServerHeaders outgoing_headers(context); for (const auto& factory : middleware_) { std::shared_ptr instance; - Status result = factory.second->StartCall(info, incoming_headers, &instance); + Status result = + factory.second->StartCall(info, flight_context.incoming_headers(), &instance); if (!result.ok()) { // Interceptor rejected call, end the request on all existing // interceptors diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index 946b29383bf25..4a573d742929a 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -76,9 +76,11 @@ class UcxServerCallContext : public flight::ServerCallContext { return nullptr; } bool is_cancelled() const override { return false; } + const CallHeaders& incoming_headers() const override { return incoming_headers_; } private: std::string peer_; + CallHeaders incoming_headers_; }; class UcxServerStream : public internal::ServerDataStream { diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 39353bcb9977a..9d92f0be95538 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -123,6 +124,11 @@ ARROW_FLIGHT_EXPORT Status MakeFlightError(FlightStatusCode code, std::string message, std::string extra_info = {}); +/// \brief Headers sent from the client or server. +/// +/// Header values are ordered. +using CallHeaders = std::multimap; + /// \brief A TLS certificate plus key. struct ARROW_FLIGHT_EXPORT CertKeyPair { /// \brief The certificate in PEM format.