diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index 628b02b9d2811..f25ff53b54142 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -41,7 +41,8 @@ set(ARROW_FLIGHT_SQL_SRCS sql_info_internal.cc column_metadata.cc client.cc - protocol_internal.cc) + protocol_internal.cc + server_session_middleware.cc) add_arrow_lib(arrow_flight_sql CMAKE_PACKAGE_NAME diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index 25bf8e384ef06..e4d138f524ff7 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -31,6 +31,13 @@ #include "arrow/result.h" #include "arrow/util/logging.h" +// Lambda helper & CTAD +template +struct overloaded : Ts... { using Ts::operator()...; }; +template // CTAD will not be needed for >=C++20 +overloaded(Ts...) -> overloaded; + +namespace pb = arrow::flight::protocol; namespace flight_sql_pb = arrow::flight::protocol::sql; namespace arrow { @@ -802,6 +809,149 @@ ::arrow::Result FlightSqlClient::CancelQuery( return Status::IOError("Server returned unknown result ", result.result()); } +::arrow::Result> +FlightSqlClient::SetSessionOptions( + const FlightCallOptions& options, + const std::map& session_options) { + pb::ActionSetSessionOptionsRequest request; + auto* options_map = request.mutable_session_options(); + + for (const auto & [name, opt_value] : session_options) { + pb::SessionOptionValue pb_opt_value; + + if (opt_value.index() == std::variant_npos) + return Status::Invalid("Undefined SessionOptionValue type "); + + std::visit(overloaded{ + // TODO move this somewhere common that can have Proto-involved code + [&](std::string v) { pb_opt_value.set_string_value(v); }, + [&](bool v) { pb_opt_value.set_bool_value(v); }, + [&](int32_t v) { pb_opt_value.set_int32_value(v); }, + [&](int64_t v) { pb_opt_value.set_int64_value(v); }, + [&](float v) { pb_opt_value.set_float_value(v); }, + [&](double v) { pb_opt_value.set_double_value(v); }, + [&](std::vector v) { + auto* string_list_value = pb_opt_value.mutable_string_list_value(); + for (const std::string& s : v) + string_list_value->add_values(s); + } + }, opt_value); + (*options_map)[name] = std::move(pb_opt_value); + } + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("SetSessionOptions", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + pb::ActionSetSessionOptionsResult pb_result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &pb_result)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + std::map result; + for (const auto & [result_key, result_value] : pb_result.results()) { + switch (result_value) { + case pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_UNSPECIFIED: + result[result_key] = SetSessionOptionResult::kUnspecified; + break; + case pb::ActionSetSessionOptionsResult + ::SET_SESSION_OPTION_RESULT_OK: + result[result_key] = SetSessionOptionResult::kOk; + break; + case pb::ActionSetSessionOptionsResult + ::SET_SESSION_OPTION_RESULT_INVALID_VALUE: + result[result_key] = SetSessionOptionResult::kInvalidResult; + break; + case pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_ERROR: + result[result_key] = SetSessionOptionResult::kError; + break; + default: + return Status::IOError("Invalid SetSessionOptionResult value for key " + + result_key); + } + } + + return result; +} + +::arrow::Result> +FlightSqlClient::GetSessionOptions ( + const FlightCallOptions& options) { + pb::ActionGetSessionOptionsRequest request; + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("GetSessionOptions", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + pb::ActionGetSessionOptionsResult pb_result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &pb_result)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + + std::map result; + if (pb_result.session_options_size() > 0) { + for (auto& [pb_opt_name, pb_opt_val] : pb_result.session_options()) { + SessionOptionValue val; + switch (pb_opt_val.option_value_case()) { + case pb::SessionOptionValue::OPTION_VALUE_NOT_SET: + return Status::Invalid("Unset option_value for name '" + pb_opt_name + "'"); + case pb::SessionOptionValue::kStringValue: + val = pb_opt_val.string_value(); + break; + case pb::SessionOptionValue::kBoolValue: + val = pb_opt_val.bool_value(); + break; + case pb::SessionOptionValue::kInt32Value: + val = pb_opt_val.int32_value(); + break; + case pb::SessionOptionValue::kInt64Value: + val = pb_opt_val.int64_value(); + break; + case pb::SessionOptionValue::kFloatValue: + val = pb_opt_val.float_value(); + break; + case pb::SessionOptionValue::kDoubleValue: + val = pb_opt_val.double_value(); + break; + case pb::SessionOptionValue::kStringListValue: + val.emplace>(); + std::get>(val) + .reserve(pb_opt_val.string_list_value().values_size()); + for (const std::string& s : pb_opt_val.string_list_value().values()) + std::get>(val).push_back(s); + break; + } + result[pb_opt_name] = std::move(val); + } + } + + return result; +} + +::arrow::Result FlightSqlClient::CloseSession( + const FlightCallOptions& options) { + pb::ActionCloseSessionRequest request; + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CloseSession", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + pb::ActionCloseSessionResult result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &result)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + switch (result.result()) { + case pb::ActionCloseSessionResult::CLOSE_RESULT_UNSPECIFIED: + return CloseSessionResult::kUnspecified; + case pb::ActionCloseSessionResult::CLOSE_RESULT_CLOSED: + return CloseSessionResult::kClosed; + case pb::ActionCloseSessionResult::CLOSE_RESULT_CLOSING: + return CloseSessionResult::kClosing; + case pb::ActionCloseSessionResult::CLOSE_RESULT_NOT_CLOSEABLE: + return CloseSessionResult::kNotClosable; + default: + break; + } + + return Status::IOError("Server returned unknown result ", result.result()); +} + Status FlightSqlClient::Close() { return impl_->Close(); } std::ostream& operator<<(std::ostream& os, CancelResult result) { diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 648f71563e9c7..46475625e2ff8 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "arrow/flight/client.h" @@ -329,6 +330,25 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { /// \param[in] info The FlightInfo of the query to cancel. ::arrow::Result CancelQuery(const FlightCallOptions& options, const FlightInfo& info); + + /// \brief Sets session options. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] session_options The session options to set. + ::arrow::Result> SetSessionOptions( + const FlightCallOptions& options, + const std::map& session_options); + + /// \brief Gets current session options. + /// + /// \param[in] options RPC-layer hints for this call. + ::arrow::Result> GetSessionOptions( + const FlightCallOptions& options); + + /// \brief Explicitly closes the session if applicable. + /// + /// \param[in] options RPC-layer hints for this call. + ::arrow::Result CloseSession(const FlightCallOptions& options); /// \brief Explicitly shut down and clean up the client. Status Close(); diff --git a/cpp/src/arrow/flight/sql/protocol_internal.h b/cpp/src/arrow/flight/sql/protocol_internal.h index ce50ad2f61b1e..71e82eb83bb4f 100644 --- a/cpp/src/arrow/flight/sql/protocol_internal.h +++ b/cpp/src/arrow/flight/sql/protocol_internal.h @@ -24,3 +24,4 @@ #include "arrow/flight/sql/visibility.h" #include "arrow/flight/sql/FlightSql.pb.h" // IWYU pragma: export +#include "arrow/flight/Flight.pb.h" diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 7f6d9b75a88f7..0f670f9d8cfde 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -32,6 +32,12 @@ #include "arrow/type.h" #include "arrow/util/checked_cast.h" +// Lambda helper & CTAD +template +struct overloaded : Ts... { using Ts::operator()...; }; +template // CTAD will not be needed for >=C++20 +overloaded(Ts...) -> overloaded; + #define PROPERTY_TO_OPTIONAL(COMMAND, PROPERTY) \ COMMAND.has_##PROPERTY() ? std::make_optional(COMMAND.PROPERTY()) : std::nullopt @@ -269,6 +275,17 @@ arrow::Result ParseActionCancelQueryRequest( return result; } +arrow::Result ParseActionCloseSessionRequest( + const google::protobuf::Any& any) { + pb::ActionCloseSessionRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionCloseSessionRequest"); + } + + ActionCloseSessionRequest result; + return result; +} + arrow::Result ParseActionCreatePreparedStatementRequest(const google::protobuf::Any& any) { pb::sql::ActionCreatePreparedStatementRequest command; @@ -359,6 +376,64 @@ arrow::Result ParseActionEndTransactionRequest( return result; } +arrow::Result ParseActionSetSessionOptionsRequest( + const google::protobuf::Any& any) { + pb::ActionSetSessionOptionsRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionSetSessionOptionsRequest"); + } + + ActionSetSessionOptionsRequest result; + if (command.session_options_size() > 0) { + for (const auto & [name, pb_val] : command.session_options()) { + SessionOptionValue val; + switch (pb_val.option_value_case()) { + case pb::SessionOptionValue::OPTION_VALUE_NOT_SET: + return Status::Invalid("Unset SessionOptionValue for name '" + name + "'"); + case pb::SessionOptionValue::kStringValue: + val = pb_val.string_value(); + break; + case pb::SessionOptionValue::kBoolValue: + val = pb_val.bool_value(); + break; + case pb::SessionOptionValue::kInt32Value: + val = pb_val.int32_value(); + break; + case pb::SessionOptionValue::kInt64Value: + val = pb_val.int64_value(); + break; + case pb::SessionOptionValue::kFloatValue: + val = pb_val.float_value(); + break; + case pb::SessionOptionValue::kDoubleValue: + val = pb_val.double_value(); + break; + case pb::SessionOptionValue::kStringListValue: + val.emplace>(); + std::get>(val) + .reserve(pb_val.string_list_value().values_size()); + for (const std::string& s : pb_val.string_list_value().values()) + std::get>(val).push_back(s); + break; + } + result.session_options[name] = std::move(val); + } + } + + return result; +} + +arrow::Result ParseActionGetSessionOptionsRequest( + const google::protobuf::Any& any) { + pb::ActionGetSessionOptionsRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionGetSessionOptionsRequest"); + } + + ActionGetSessionOptionsRequest result; + return result; +} + arrow::Result PackActionResult(const google::protobuf::Message& message) { google::protobuf::Any any; if (!any.PackFrom(message)) { @@ -423,6 +498,78 @@ arrow::Result PackActionResult(ActionCreatePreparedStatementResult resul return PackActionResult(pb_result); } +arrow::Result PackActionResult(ActionSetSessionOptionsResult result) { + pb::ActionSetSessionOptionsResult pb_result; + auto* pb_results_map = pb_result.mutable_results(); + for (const auto& [opt_name, res] : result.results) { + pb::ActionSetSessionOptionsResult_SetSessionOptionResult val; + switch (res) { + case SetSessionOptionResult::kUnspecified: + val = pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_UNSPECIFIED; + break; + case SetSessionOptionResult::kOk: + val = pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_OK; + break; + case SetSessionOptionResult::kInvalidResult: + val = pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_INVALID_VALUE; + break; + case SetSessionOptionResult::kError: + val = pb::ActionSetSessionOptionsResult::SET_SESSION_OPTION_RESULT_ERROR; + break; + } + (*pb_results_map)[opt_name] = val; + } + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(ActionGetSessionOptionsResult result) { + pb::ActionGetSessionOptionsResult pb_result; + auto* pb_results = pb_result.mutable_session_options(); + for (const auto& [name, opt_value] : result.session_options) { + pb::SessionOptionValue pb_opt_value; + + if (opt_value.index() == std::variant_npos) + return Status::Invalid("Undefined SessionOptionValue type"); + + std::visit(overloaded{ + // TODO move this somewhere common that can have Proto-involved code + [&](std::string v) { pb_opt_value.set_string_value(v); }, + [&](bool v) { pb_opt_value.set_bool_value(v); }, + [&](int32_t v) { pb_opt_value.set_int32_value(v); }, + [&](int64_t v) { pb_opt_value.set_int64_value(v); }, + [&](float v) { pb_opt_value.set_float_value(v); }, + [&](double v) { pb_opt_value.set_double_value(v); }, + [&](std::vector v) { + auto* string_list_value = pb_opt_value.mutable_string_list_value(); + for (const std::string& s : v) + string_list_value->add_values(s); + } + }, opt_value); + (*pb_results)[name] = std::move(pb_opt_value); + } + + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(CloseSessionResult result) { + pb::ActionCloseSessionResult pb_result; + switch (result) { + case CloseSessionResult::kUnspecified: + pb_result.set_result(pb::ActionCloseSessionResult::CLOSE_RESULT_UNSPECIFIED); + break; + case CloseSessionResult::kClosed: + pb_result.set_result(pb::ActionCloseSessionResult::CLOSE_RESULT_CLOSED); + break; + case CloseSessionResult::kClosing: + pb_result.set_result(pb::ActionCloseSessionResult::CLOSE_RESULT_CLOSING); + break; + case CloseSessionResult::kNotClosable: + pb_result.set_result(pb::ActionCloseSessionResult::CLOSE_RESULT_NOT_CLOSEABLE); + break; + } + return PackActionResult(pb_result); +} + } // namespace arrow::Result StatementQueryTicket::Deserialize( @@ -747,8 +894,11 @@ Status FlightSqlServerBase::ListActions(const ServerCallContext& context, FlightSqlServerBase::kCreatePreparedStatementActionType, FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType, FlightSqlServerBase::kClosePreparedStatementActionType, + FlightSqlServerBase::kCloseSessionActionType, FlightSqlServerBase::kEndSavepointActionType, FlightSqlServerBase::kEndTransactionActionType, + FlightSqlServerBase::kSetSessionOptionsActionType, + FlightSqlServerBase::kGetSessionOptionsActionType }; return Status::OK(); } @@ -784,6 +934,13 @@ Status FlightSqlServerBase::DoAction(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(CancelResult result, CancelQuery(context, internal_command)); ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(result)); + results.push_back(std::move(packed_result)); + } else if (action.type == FlightSqlServerBase::kCloseSessionActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionCloseSessionRequest internal_command, + ParseActionCloseSessionRequest(any)); + ARROW_ASSIGN_OR_RAISE(CloseSessionResult result, CloseSession(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + results.push_back(std::move(packed_result)); } else if (action.type == FlightSqlServerBase::kCreatePreparedStatementActionType.type) { @@ -815,6 +972,22 @@ Status FlightSqlServerBase::DoAction(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(ActionEndTransactionRequest internal_command, ParseActionEndTransactionRequest(any)); ARROW_RETURN_NOT_OK(EndTransaction(context, internal_command)); + } else if (action.type == FlightSqlServerBase::kSetSessionOptionsActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionSetSessionOptionsRequest internal_command, + ParseActionSetSessionOptionsRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionSetSessionOptionsResult result, + SetSessionOptions(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == FlightSqlServerBase::kGetSessionOptionsActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionGetSessionOptionsRequest internal_command, + ParseActionGetSessionOptionsRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionGetSessionOptionsResult result, + GetSessionOptions(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); } else { return Status::NotImplemented("Action not implemented: ", action.type); } @@ -1042,6 +1215,12 @@ arrow::Result FlightSqlServerBase::CancelQuery( return Status::NotImplemented("CancelQuery not implemented"); } +arrow::Result FlightSqlServerBase::CloseSession( + const ServerCallContext& context, + const ActionCloseSessionRequest& request) { + return Status::NotImplemented("CloseSession not implemented"); +} + arrow::Result FlightSqlServerBase::CreatePreparedStatement( const ServerCallContext& context, @@ -1072,6 +1251,18 @@ Status FlightSqlServerBase::EndTransaction(const ServerCallContext& context, return Status::NotImplemented("EndTransaction not implemented"); } +arrow::Result FlightSqlServerBase::SetSessionOptions ( + const ServerCallContext& context, + const ActionSetSessionOptionsRequest& request) { + return Status::NotImplemented("SetSessionOptions not implemented"); +} + +arrow::Result FlightSqlServerBase::GetSessionOptions ( + const ServerCallContext& context, + const ActionGetSessionOptionsRequest& request) { + return Status::NotImplemented("GetSessionOptions not implemented"); +} + Status FlightSqlServerBase::DoPutPreparedStatementQuery( const ServerCallContext& context, const PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) { diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 65f6670171dfd..6a18900355043 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "arrow/flight/server.h" @@ -225,6 +226,27 @@ struct ARROW_FLIGHT_SQL_EXPORT ActionCreatePreparedStatementResult { std::string prepared_statement_handle; }; +/// \brief A request to close the open client session. +struct ARROW_FLIGHT_SQL_EXPORT ActionCloseSessionRequest {}; + +/// \brief A request to set a set of session options by key/value. +struct ARROW_FLIGHT_SQL_EXPORT ActionSetSessionOptionsRequest { + std::map session_options; +}; + +/// \brief The result(s) of setting session option(s). +struct ARROW_FLIGHT_SQL_EXPORT ActionSetSessionOptionsResult { + std::map results; +}; + +/// \brief A request to get current session options. +struct ARROW_FLIGHT_SQL_EXPORT ActionGetSessionOptionsRequest {}; + +/// \brief The current session options. +struct ARROW_FLIGHT_SQL_EXPORT ActionGetSessionOptionsResult { + std::map session_options; +}; + /// @} /// \brief A utility function to create a ticket (a opaque binary @@ -594,6 +616,27 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { virtual Status EndTransaction(const ServerCallContext& context, const ActionEndTransactionRequest& request); + /// \brief Set server session option(s). + /// \param[in] context The call context. + /// \param[in] request The session options to set. + virtual arrow::Result SetSessionOptions( + const ServerCallContext& context, + const ActionSetSessionOptionsRequest& request); + + /// \brief Get server session option(s). + /// \param[in] context The call context. + /// \param[in] request Request object. + virtual arrow::Result GetSessionOptions( + const ServerCallContext& context, + const ActionGetSessionOptionsRequest& request); + + /// \brief Close/invalidate the session. + /// \param[in] context The call context. + /// \param[in] request Request object. + virtual arrow::Result CloseSession( + const ServerCallContext& context, + const ActionCloseSessionRequest& request); + /// \brief Attempt to explicitly cancel a query. /// \param[in] context The call context. /// \param[in] request The query to cancel. @@ -661,6 +704,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { "Closes a reusable prepared statement resource on the server.\n" "Request Message: ActionClosePreparedStatementRequest\n" "Response Message: N/A"}; + const ActionType kCloseSessionActionType = + ActionType{"CloseSession", + "Explicitly close an open session.\n" + "Request Message: ActionCloseSessionRequest\n" + "Response Message: ActionCloseSessionResult"}; const ActionType kEndSavepointActionType = ActionType{"EndSavepoint", "End a savepoint.\n" @@ -671,6 +719,16 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { "End a savepoint.\n" "Request Message: ActionEndTransactionRequest\n" "Response Message: N/A"}; + const ActionType kSetSessionOptionsActionType = + ActionType{"SetSessionOptions", + "Set a series of session options.\n" + "Request Message: ActionSetSessionOptionsRequest\n" + "Response Message: ActionSetSessionOptionsResult"}; + const ActionType kGetSessionOptionsActionType = + ActionType{"GetSessionOption", + "Get a series of session options.\n" + "Request Message: ActionGetSessionOptionRequest\n" + "Response Message: ActionGetSessionOptionResult"}; Status ListActions(const ServerCallContext& context, std::vector* actions) final; diff --git a/cpp/src/arrow/flight/sql/server_session_middleware.cc b/cpp/src/arrow/flight/sql/server_session_middleware.cc new file mode 100644 index 0000000000000..4e80cad3e8db6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server_session_middleware.cc @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include "arrow/flight/sql/server_session_middleware.h" +#include +#include +#include +#include + +namespace arrow { +namespace flight { +namespace sql { + +/// \brief A factory for ServerSessionMiddleware, itself storing session data. +class ServerSessionMiddlewareFactory : public ServerMiddlewareFactory { + protected: + std::map> session_store_; + std::shared_mutex session_store_lock_; + boost::uuids::random_generator uuid_generator_; + + std::vector> ParseCookieString( + const std::string_view& s) { + const std::string list_sep = "; "; + const std::string pair_sep = "="; + const size_t pair_sep_len = pair_sep.length(); + + std::vector> result; + + size_t cur = 0; + while (cur < s.length()) { + const size_t end = s.find(list_sep, cur); + size_t len; + if (end == std::string::npos) { + // No (further) list delimiters + len = std::string::npos; + cur = s.length(); + } else { + len = end - cur; + cur = end; + } + const std::string_view tok = s.substr(cur, len); + + const size_t val_pos = tok.find(pair_sep); + result.emplace_back( + tok.substr(0, val_pos), + tok.substr(val_pos + pair_sep_len, std::string::npos) + ); + } + + return result; + } + + public: + Status StartCall(const CallInfo &, const CallHeaders &incoming_headers, + std::shared_ptr *middleware) { + std::string session_id; + + const std::pair& + headers_it_pr = incoming_headers.equal_range("cookie"); + for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) { + const std::string_view& cookie_header = itr->second; + const std::vector> cookies = + ParseCookieString(cookie_header); + for (const std::pair& cookie : cookies) { + if (cookie.first == kSessionCookieName) { + session_id = cookie.second; + if (!session_id.length()) + return Status::Invalid( + "Empty " + static_cast(kSessionCookieName) + + " cookie value."); + } + } + if (session_id.length()) break; + } + + if (!session_id.length()) { + // No cookie was found + *middleware = std::shared_ptr( + new ServerSessionMiddleware(this, incoming_headers)); + } else { + try { + const std::shared_lock l(session_store_lock_); + auto session = session_store_.at(session_id); + *middleware = std::shared_ptr( + new ServerSessionMiddleware(this, incoming_headers, + session, session_id)); + } catch (std::out_of_range& e) { + return Status::Invalid( + "Invalid or expired " + + static_cast(kSessionCookieName) + " cookie."); + } + } + + return Status::OK(); + } + + /// \brief Get a new, empty session option map and its id key. + std::shared_ptr GetNewSession(std::string* session_id) { + std::string new_id = boost::lexical_cast(uuid_generator_()); + *session_id = new_id; + auto session = std::make_shared(); + + const std::unique_lock l(session_store_lock_); + session_store_[new_id] = session; + + return session; + } +}; + +ServerSessionMiddleware::ServerSessionMiddleware(ServerSessionMiddlewareFactory* factory, + const CallHeaders& headers) + : factory_(factory), headers_(headers), existing_session(false) {} + +ServerSessionMiddleware::ServerSessionMiddleware( + ServerSessionMiddlewareFactory* factory, const CallHeaders& headers, + std::shared_ptr session, + std::string session_id) + : factory_(factory), headers_(headers), session_(session), existing_session(true) {} + +void ServerSessionMiddleware::SendingHeaders(AddCallHeaders* addCallHeaders) { + if (!existing_session && session_) { + addCallHeaders->AddHeader( + "set-cookie", + static_cast(kSessionCookieName) + "=" + session_id_); + } +} + +void ServerSessionMiddleware::CallCompleted(const Status&) {} + +bool ServerSessionMiddleware::HasSession() const { + return static_cast(session_); +} +std::shared_ptr ServerSessionMiddleware::GetSession() { + if (!session_) + session_ = factory_->GetNewSession(&session_id_); + return session_; +} +const CallHeaders& ServerSessionMiddleware::GetCallHeaders() const { + return headers_; +} + + + +std::shared_ptr MakeServerSessionMiddlewareFactory() { + return std::shared_ptr( + new ServerSessionMiddlewareFactory()); +} + +SessionOptionValue FlightSqlSession::GetSessionOption(const std::string& k) { + const std::shared_lock l(map_lock_); + return map_.at(k); +} +void FlightSqlSession::SetSessionOption(const std::string& k, const SessionOptionValue& v) { + const std::unique_lock l(map_lock_); + map_[k] = v; +} +void FlightSqlSession::EraseSessionOption(const std::string& k) { + const std::unique_lock l(map_lock_); + map_.erase(k); +} + +} // namespace sql +} // namespace flight +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/server_session_middleware.h b/cpp/src/arrow/flight/sql/server_session_middleware.h new file mode 100644 index 0000000000000..ca44e36ace910 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server_session_middleware.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Middleware for handling Flight SQL Sessions including session cookie handling. +// Currently experimental. + +#pragma once + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/types.h" + +namespace arrow { +namespace flight { +namespace sql { + +class ServerSessionMiddlewareFactory; + +static constexpr char const kSessionCookieName[] = + "flight_sql_session_id"; + +class FlightSqlSession { + protected: + std::map map_; + std::shared_mutex map_lock_; + public: + /// \brief Get session option by key + SessionOptionValue GetSessionOption(const std::string&); + /// \brief Set session option by key to given value + void SetSessionOption(const std::string&, const SessionOptionValue&); + /// \brief Idempotently remove key from this call's Session, if Session & key exist + void EraseSessionOption(const std::string&); +}; + +/// \brief A middleware to handle Session option persistence and related *Cookie headers. +class ARROW_FLIGHT_SQL_EXPORT ServerSessionMiddleware + : public ServerMiddleware { + public: + static constexpr char const kMiddlewareName[] = + "arrow::flight::sql::ServerSessionMiddleware"; + + std::string name() const override { return kMiddlewareName; } + void SendingHeaders(AddCallHeaders*) override; + void CallCompleted(const Status&) override; + + /// \brief Is there an existing session (either existing or new) + bool HasSession() const; + /// \brief Get existing or new call-associated session + std::shared_ptr GetSession(); + /// \brief Get request headers, in lieu of a provided or created session. + const CallHeaders& GetCallHeaders() const; + + protected: + friend class ServerSessionMiddlewareFactory; + ServerSessionMiddlewareFactory* factory_; + const CallHeaders& headers_; + std::shared_ptr session_; + std::string session_id_; + const bool existing_session; + + ServerSessionMiddleware(ServerSessionMiddlewareFactory*, + const CallHeaders&); + ServerSessionMiddleware(ServerSessionMiddlewareFactory*, + const CallHeaders&, + std::shared_ptr, + std::string session_id); +}; + +/// \brief Returns a ServerMiddlewareFactory that handles Session option storage. +ARROW_FLIGHT_SQL_EXPORT std::shared_ptr +MakeServerSessionMiddlewareFactory(); + +} // namespace sql +} // namespace flight +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h index 293b1d5579ec0..0249ba01de05c 100644 --- a/cpp/src/arrow/flight/sql/types.h +++ b/cpp/src/arrow/flight/sql/types.h @@ -44,6 +44,10 @@ using SqlInfoResult = /// \brief Map SQL info identifier to its value. using SqlInfoResultMap = std::unordered_map; +/// \brief Variant supporting all possible types for {Set,Get}SessionOptions +using SessionOptionValue = + std::variant>; + /// \brief Options to be set in the SqlInfo. struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions { /// \brief Predefined info values for GetSqlInfo. @@ -920,6 +924,22 @@ enum class CancelResult : int8_t { kNotCancellable, }; +/// \brief The result of setting a session option. +enum class SetSessionOptionResult : int8_t { + kUnspecified, + kOk, + kInvalidResult, + kError +}; + +/// \brief The result of closing a session. +enum class CloseSessionResult : int8_t { + kUnspecified, + kClosed, + kClosing, + kNotClosable +}; + ARROW_FLIGHT_SQL_EXPORT std::ostream& operator<<(std::ostream& os, CancelResult result); diff --git a/format/Flight.proto b/format/Flight.proto index 635b1793d2bab..261e9d751736e 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -17,6 +17,7 @@ */ syntax = "proto3"; +import "google/protobuf/descriptor.proto"; option java_package = "org.apache.arrow.flight.impl"; option go_package = "github.com/apache/arrow/go/arrow/flight/internal/flight"; @@ -360,3 +361,96 @@ message FlightData { message PutResult { bytes app_metadata = 1; } + +/* + * Request message for the "Close Session" action. + * + * The exiting session is referenced via a cookie header. + */ +message ActionCloseSessionRequest { + option (experimental) = true; +} + +/* + * The result of closing a session. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionCloseSessionResult { + option (experimental) = true; + + enum CloseSessionResult { + // The session close status is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + CLOSE_RESULT_UNSPECIFIED = 0; + // The session close request is complete. Subsequent requests with + // a NOT_FOUND error. + CLOSE_RESULT_CLOSED = 1; + // The session close request is in progress. The client may retry + // the close request. + CLOSE_RESULT_CLOSING = 2; + // The session is not closeable. The client should not retry the + // close request. + CLOSE_RESULT_NOT_CLOSEABLE = 3; + } + + CloseSessionResult result = 1; +} + +message SessionOptionValue { + option (experimental) = true; + + message StringListValue { + repeated string values = 1; + } + + oneof option_value { + string string_value = 1; + bool bool_value = 2; + sfixed32 int32_value = 3; + sfixed64 int64_value = 4; + float float_value = 5; + double double_value = 6; + StringListValue string_list_value = 7; + } +} + +message ActionSetSessionOptionsRequest { + option (experimental) = true; + + map session_options = 1; +} + +message ActionSetSessionOptionsResult { + option (experimental) = true; + + enum SetSessionOptionResult { + // The status of setting the option is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + SET_SESSION_OPTION_RESULT_UNSPECIFIED = 0; + // The session option setting completed successfully. + SET_SESSION_OPTION_RESULT_OK = 1; + // The session cannot be set to the given value. + SET_SESSION_OPTION_RESULT_INVALID_VALUE = 2; + // The session cannot be set. + SET_SESSION_OPTION_RESULT_ERROR = 3; + } + + map results = 1; +} + +message ActionGetSessionOptionsRequest { + option (experimental) = true; +} + +message ActionGetSessionOptionsResult { + option (experimental) = true; + + map session_options = 1; +} + +extend google.protobuf.MessageOptions { + bool experimental = 1000; +} diff --git a/format/FlightSql.proto b/format/FlightSql.proto index d8a6cb5bfdb07..2de618b587a60 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -21,6 +21,7 @@ import "google/protobuf/descriptor.proto"; option java_package = "org.apache.arrow.flight.sql.impl"; option go_package = "github.com/apache/arrow/go/arrow/flight/internal/flight"; + package arrow.flight.protocol.sql; /* diff --git a/java/adapter/avro/pom.xml b/java/adapter/avro/pom.xml index ae7f4f41a6ff7..b67f1267670db 100644 --- a/java/adapter/avro/pom.xml +++ b/java/adapter/avro/pom.xml @@ -16,7 +16,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid ../../pom.xml diff --git a/java/adapter/jdbc/pom.xml b/java/adapter/jdbc/pom.xml index f376436769eba..a9bc032ab582d 100644 --- a/java/adapter/jdbc/pom.xml +++ b/java/adapter/jdbc/pom.xml @@ -16,7 +16,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid ../../pom.xml diff --git a/java/adapter/orc/pom.xml b/java/adapter/orc/pom.xml index 9217ef1a27211..9d97a74c275c8 100644 --- a/java/adapter/orc/pom.xml +++ b/java/adapter/orc/pom.xml @@ -114,7 +114,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid ../../pom.xml diff --git a/java/algorithm/pom.xml b/java/algorithm/pom.xml index b52009e8b6174..832ba84dda5c4 100644 --- a/java/algorithm/pom.xml +++ b/java/algorithm/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid arrow-algorithm Arrow Algorithms diff --git a/java/c/pom.xml b/java/c/pom.xml index 6d9db7c1f5447..6b8b53e39522a 100644 --- a/java/c/pom.xml +++ b/java/c/pom.xml @@ -13,7 +13,7 @@ arrow-java-root org.apache.arrow - 12.0.1 + 12.0.1-projectid 4.0.0 diff --git a/java/compression/pom.xml b/java/compression/pom.xml index a6e76f575a282..e0e8cb9275fd4 100644 --- a/java/compression/pom.xml +++ b/java/compression/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid arrow-compression Arrow Compression diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index 6ef3dc05a6dd3..ea4b60e80c5e0 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -15,7 +15,7 @@ arrow-java-root org.apache.arrow - 12.0.1 + 12.0.1-projectid 4.0.0 diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index b35e9a0a93db3..caca2ec436c5e 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -14,7 +14,7 @@ arrow-flight org.apache.arrow - 12.0.1 + 12.0.1-projectid ../pom.xml diff --git a/java/flight/flight-grpc/pom.xml b/java/flight/flight-grpc/pom.xml index 3401c32d1b7d8..f43df37819475 100644 --- a/java/flight/flight-grpc/pom.xml +++ b/java/flight/flight-grpc/pom.xml @@ -13,7 +13,7 @@ arrow-flight org.apache.arrow - 12.0.1 + 12.0.1-projectid ../pom.xml 4.0.0 diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml index b8bdc64da29d1..df8f1992fe08b 100644 --- a/java/flight/flight-integration-tests/pom.xml +++ b/java/flight/flight-integration-tests/pom.xml @@ -15,7 +15,7 @@ arrow-flight org.apache.arrow - 12.0.1 + 12.0.1-projectid ../pom.xml diff --git a/java/flight/flight-sql-jdbc-core/pom.xml b/java/flight/flight-sql-jdbc-core/pom.xml index 12a028f254a19..4ed85c6e35b95 100644 --- a/java/flight/flight-sql-jdbc-core/pom.xml +++ b/java/flight/flight-sql-jdbc-core/pom.xml @@ -16,7 +16,7 @@ arrow-flight org.apache.arrow - 12.0.1 + 12.0.1-projectid ../pom.xml 4.0.0 diff --git a/java/flight/flight-sql-jdbc-driver/pom.xml b/java/flight/flight-sql-jdbc-driver/pom.xml index 15c7f737aee42..f08cada2540ac 100644 --- a/java/flight/flight-sql-jdbc-driver/pom.xml +++ b/java/flight/flight-sql-jdbc-driver/pom.xml @@ -16,7 +16,7 @@ arrow-flight org.apache.arrow - 12.0.1 + 12.0.1-projectid ../pom.xml 4.0.0 diff --git a/java/flight/flight-sql/pom.xml b/java/flight/flight-sql/pom.xml index d8371e4355564..893f1100024b0 100644 --- a/java/flight/flight-sql/pom.xml +++ b/java/flight/flight-sql/pom.xml @@ -14,7 +14,7 @@ arrow-flight org.apache.arrow - 12.0.1 + 12.0.1-projectid ../pom.xml diff --git a/java/flight/pom.xml b/java/flight/pom.xml index ec972a67fb926..036391d9148f3 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -15,7 +15,7 @@ arrow-java-root org.apache.arrow - 12.0.1 + 12.0.1-projectid 4.0.0 diff --git a/java/format/pom.xml b/java/format/pom.xml index 98a53fbfd7a24..8ca291fbbfed3 100644 --- a/java/format/pom.xml +++ b/java/format/pom.xml @@ -15,7 +15,7 @@ arrow-java-root org.apache.arrow - 12.0.1 + 12.0.1-projectid arrow-format diff --git a/java/gandiva/pom.xml b/java/gandiva/pom.xml index bed66b427e625..4ab85c24c31b6 100644 --- a/java/gandiva/pom.xml +++ b/java/gandiva/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid org.apache.arrow.gandiva diff --git a/java/memory/memory-core/pom.xml b/java/memory/memory-core/pom.xml index 7acd474c6c8b6..da2274d78b206 100644 --- a/java/memory/memory-core/pom.xml +++ b/java/memory/memory-core/pom.xml @@ -13,7 +13,7 @@ arrow-memory org.apache.arrow - 12.0.1 + 12.0.1-projectid 4.0.0 diff --git a/java/memory/memory-netty/pom.xml b/java/memory/memory-netty/pom.xml index b101e0fe3877d..12a7627d1d45b 100644 --- a/java/memory/memory-netty/pom.xml +++ b/java/memory/memory-netty/pom.xml @@ -13,7 +13,7 @@ arrow-memory org.apache.arrow - 12.0.1 + 12.0.1-projectid 4.0.0 diff --git a/java/memory/memory-unsafe/pom.xml b/java/memory/memory-unsafe/pom.xml index 0c927e332d2a0..34589921014c8 100644 --- a/java/memory/memory-unsafe/pom.xml +++ b/java/memory/memory-unsafe/pom.xml @@ -13,7 +13,7 @@ arrow-memory org.apache.arrow - 12.0.1 + 12.0.1-projectid 4.0.0 diff --git a/java/memory/pom.xml b/java/memory/pom.xml index 15896ecc0112b..551cff4b9064d 100644 --- a/java/memory/pom.xml +++ b/java/memory/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid arrow-memory Arrow Memory diff --git a/java/performance/pom.xml b/java/performance/pom.xml index 4e2d66dd3ebc4..0823a5f364107 100644 --- a/java/performance/pom.xml +++ b/java/performance/pom.xml @@ -14,7 +14,7 @@ arrow-java-root org.apache.arrow - 12.0.1 + 12.0.1-projectid arrow-performance jar @@ -74,7 +74,7 @@ org.apache.arrow arrow-algorithm - 12.0.1 + 12.0.1-projectid test diff --git a/java/pom.xml b/java/pom.xml index 747320d2f8a40..64327819d6986 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -20,7 +20,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid pom Apache Arrow Java Root POM diff --git a/java/tools/pom.xml b/java/tools/pom.xml index d81a61b5fdcca..3c972d267fa54 100644 --- a/java/tools/pom.xml +++ b/java/tools/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid arrow-tools Arrow Tools diff --git a/java/vector/pom.xml b/java/vector/pom.xml index bfedd7d3c181d..222aa624cdc8d 100644 --- a/java/vector/pom.xml +++ b/java/vector/pom.xml @@ -14,7 +14,7 @@ org.apache.arrow arrow-java-root - 12.0.1 + 12.0.1-projectid arrow-vector Arrow Vectors