Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-36155: [C++][Go][Java][FlightRPC] Add support for long-running queries #36946

Merged
merged 10 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,15 @@ set(FLIGHT_GENERATED_PROTO_FILES

set(PROTO_DEPENDS ${FLIGHT_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF} gRPC::grpc_cpp_plugin)

set(FLIGHT_PROTOC_COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_PROTO_PATH}")
if(Protobuf_VERSION VERSION_LESS 3.15)
list(APPEND FLIGHT_PROTOC_COMMAND "--experimental_allow_proto3_optional")
endif()
add_custom_command(OUTPUT ${FLIGHT_GENERATED_PROTO_FILES}
COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_PROTO_PATH}"
"--cpp_out=${CMAKE_CURRENT_BINARY_DIR}" "${FLIGHT_PROTO}"
DEPENDS ${PROTO_DEPENDS} ARGS
COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_PROTO_PATH}"
COMMAND ${FLIGHT_PROTOC_COMMAND}
"--cpp_out=${CMAKE_CURRENT_BINARY_DIR}" "${FLIGHT_PROTO}"
COMMAND ${FLIGHT_PROTOC_COMMAND}
"--grpc_out=${CMAKE_CURRENT_BINARY_DIR}"
"--plugin=protoc-gen-grpc=$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
"${FLIGHT_PROTO}")
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ arrow::Future<FlightInfo> FlightClient::GetFlightInfoAsync(
return future;
}

arrow::Result<std::unique_ptr<PollInfo>> FlightClient::PollFlightInfo(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
std::unique_ptr<PollInfo> info;
RETURN_NOT_OK(CheckOpen());
RETURN_NOT_OK(transport_->PollFlightInfo(options, descriptor, &info));
return info;
}

arrow::Result<std::unique_ptr<SchemaResult>> FlightClient::GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
Expand Down
13 changes: 13 additions & 0 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,19 @@ class ARROW_FLIGHT_EXPORT FlightClient {
return GetFlightInfoAsync({}, descriptor);
}

/// \brief Request and poll a long running query
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request or a descriptor returned by a
/// prioir PollFlightInfo call
/// \return Arrow result with the PollInfo describing the status of
/// the requested query
arrow::Result<std::unique_ptr<PollInfo>> PollFlightInfo(
const FlightCallOptions& options, const FlightDescriptor& descriptor);
arrow::Result<std::unique_ptr<PollInfo>> PollFlightInfo(
const FlightDescriptor& descriptor) {
return PollFlightInfo({}, descriptor);
}

/// \brief Request schema for a single flight, which may be an existing
/// dataset or a command to be executed
/// \param[in] options Per-RPC options
Expand Down
37 changes: 36 additions & 1 deletion cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ void TestRoundtrip(const std::vector<FlightType>& values,

ASSERT_OK_AND_ASSIGN(std::string serialized, values[i].SerializeToString());
ASSERT_OK_AND_ASSIGN(auto deserialized, FlightType::Deserialize(serialized));
if constexpr (std::is_same_v<FlightType, FlightInfo>) {
if constexpr (std::is_same_v<FlightType, FlightInfo> ||
std::is_same_v<FlightType, PollInfo>) {
ARROW_SCOPED_TRACE("Deserialized = ", deserialized->ToString());
EXPECT_EQ(values[i], *deserialized);
} else {
ARROW_SCOPED_TRACE("Deserialized = ", deserialized.ToString());
EXPECT_EQ(values[i], deserialized);
}

Expand Down Expand Up @@ -255,6 +258,38 @@ TEST(FlightTypes, FlightInfo) {
ASSERT_NO_FATAL_FAILURE(TestRoundtrip<pb::FlightInfo>(values, reprs));
}

TEST(FlightTypes, PollInfo) {
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 1234));
Schema schema({field("ints", int64())});
auto desc = FlightDescriptor::Command("foo");
auto endpoint = FlightEndpoint{Ticket{"foo"}, {}, std::nullopt};
auto info = MakeFlightInfo(schema, desc, {endpoint}, -1, 42, true);
// 2023-06-19 03:14:06.004330100
// We must use microsecond resolution here for portability.
// std::chrono::system_clock::time_point may not provide nanosecond
// resolution on some platforms such as Windows.
const auto expiration_time_duration =
std::chrono::seconds{1687144446} + std::chrono::nanoseconds{4339000};
Timestamp expiration_time(
std::chrono::duration_cast<Timestamp::duration>(expiration_time_duration));
std::vector<PollInfo> values = {
PollInfo{std::make_unique<FlightInfo>(info), std::nullopt, std::nullopt,
std::nullopt},
PollInfo{std::make_unique<FlightInfo>(info), FlightDescriptor::Command("poll"), 0.1,
expiration_time},
};
std::vector<std::string> reprs = {
"<PollInfo info=" + info.ToString() +
" descriptor=null "
"progress=null expiration_time=null>",
"<PollInfo info=" + info.ToString() +
" descriptor=<FlightDescriptor cmd='poll'> "
"progress=0.1 expiration_time=2023-06-19 03:14:06.004339000>",
};

ASSERT_NO_FATAL_FAILURE(TestRoundtrip<pb::PollInfo>(values, reprs));
}

TEST(FlightTypes, Result) {
std::vector<Result> values = {
{Buffer::FromString("")},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ TEST(FlightIntegration, ExpirationTimeRenewFlightEndpoint) {
ASSERT_OK(RunScenario("expiration_time:renew_flight_endpoint"));
}

TEST(FlightIntegration, PollFlightInfo) { ASSERT_OK(RunScenario("poll_flight_info")); }

TEST(FlightIntegration, FlightSql) { ASSERT_OK(RunScenario("flight_sql")); }

TEST(FlightIntegration, FlightSqlExtension) {
Expand Down
78 changes: 75 additions & 3 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,9 +708,7 @@ class ExpirationTimeCancelFlightInfoScenario : public Scenario {

/// \brief The expiration time scenario - RenewFlightEndpoint.
///
/// This tests that the client can renew a FlightEndpoint and read
/// data in renewed expiration time even when the original
/// expiration time is over.
/// This tests that the client can renew a FlightEndpoint.
class ExpirationTimeRenewFlightEndpointScenario : public Scenario {
Status MakeServer(std::unique_ptr<FlightServerBase>* server,
FlightServerOptions* options) override {
Expand Down Expand Up @@ -746,6 +744,77 @@ class ExpirationTimeRenewFlightEndpointScenario : public Scenario {
}
};

/// \brief The server used for testing PollFlightInfo().
class PollFlightInfoServer : public FlightServerBase {
public:
PollFlightInfoServer() : FlightServerBase() {}

Status PollFlightInfo(const ServerCallContext& context,
const FlightDescriptor& descriptor,
std::unique_ptr<PollInfo>* result) override {
auto schema = arrow::schema({arrow::field("number", arrow::uint32(), false)});
std::vector<FlightEndpoint> endpoints = {
FlightEndpoint{{"long-running query"}, {}, std::nullopt}};
ARROW_ASSIGN_OR_RAISE(
auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false));
if (descriptor == FlightDescriptor::Command("poll")) {
*result = std::make_unique<PollInfo>(std::make_unique<FlightInfo>(std::move(info)),
std::nullopt, 1.0, std::nullopt);
} else {
*result =
std::make_unique<PollInfo>(std::make_unique<FlightInfo>(std::move(info)),
FlightDescriptor::Command("poll"), 0.1,
Timestamp::clock::now() + std::chrono::seconds{10});
}
return Status::OK();
}
};

/// \brief The PollFlightInfo scenario.
///
/// This tests that the client can poll a long-running query.
class PollFlightInfoScenario : public Scenario {
Status MakeServer(std::unique_ptr<FlightServerBase>* server,
FlightServerOptions* options) override {
*server = std::make_unique<PollFlightInfoServer>();
return Status::OK();
}

Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }

Status RunClient(std::unique_ptr<FlightClient> client) override {
ARROW_ASSIGN_OR_RAISE(
auto info, client->PollFlightInfo(FlightDescriptor::Command("heavy query")));
if (!info->descriptor.has_value()) {
return Status::Invalid("Description is missing: ", info->ToString());
}
if (!info->progress.has_value()) {
return Status::Invalid("Progress is missing: ", info->ToString());
}
if (!(0.0 <= *info->progress && *info->progress <= 1.0)) {
return Status::Invalid("Invalid progress: ", info->ToString());
}
if (!info->expiration_time.has_value()) {
return Status::Invalid("Expiration time is missing: ", info->ToString());
}
ARROW_ASSIGN_OR_RAISE(info, client->PollFlightInfo(*info->descriptor));
if (info->descriptor.has_value()) {
return Status::Invalid("Retried but not finished yet: ", info->ToString());
}
if (!info->progress.has_value()) {
return Status::Invalid("Progress is missing in finished query: ", info->ToString());
}
if (fabs(*info->progress - 1.0) > arrow::kDefaultAbsoluteTolerance) {
return Status::Invalid("Progress for finished query isn't 1.0: ", info->ToString());
}
if (info->expiration_time.has_value()) {
return Status::Invalid("Expiration time must not be set for finished query: ",
info->ToString());
}
return Status::OK();
}
};

/// \brief Schema to be returned for mocking the statement/prepared statement results.
///
/// Must be the same across all languages.
Expand Down Expand Up @@ -1825,6 +1894,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>*
} else if (scenario_name == "expiration_time:renew_flight_endpoint") {
*out = std::make_shared<ExpirationTimeRenewFlightEndpointScenario>();
return Status::OK();
} else if (scenario_name == "poll_flight_info") {
*out = std::make_shared<PollFlightInfoScenario>();
return Status::OK();
} else if (scenario_name == "flight_sql") {
*out = std::make_shared<FlightSqlScenario>();
return Status::OK();
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/middleware.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ enum class FlightMethod : char {
DoAction = 7,
ListActions = 8,
DoExchange = 9,
PollFlightInfo = 10,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to update this in Python (I can put up a PR for that later)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll do this in another PR. Issue for this: #36954

};

/// \brief Get a human-readable name for a Flight method.
Expand Down
80 changes: 66 additions & 14 deletions cpp/src/arrow/flight/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ namespace arrow {
namespace flight {
namespace internal {

// Timestamp

Status FromProto(const google::protobuf::Timestamp& pb_timestamp, Timestamp* timestamp) {
const auto seconds = std::chrono::seconds{pb_timestamp.seconds()};
const auto nanoseconds = std::chrono::nanoseconds{pb_timestamp.nanos()};
const auto duration =
std::chrono::duration_cast<Timestamp::duration>(seconds + nanoseconds);
*timestamp = Timestamp(duration);
return Status::OK();
}

Status ToProto(const Timestamp& timestamp, google::protobuf::Timestamp* pb_timestamp) {
const auto since_epoch = timestamp.time_since_epoch();
const auto since_epoch_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch).count();
pb_timestamp->set_seconds(since_epoch_ns / std::nano::den);
pb_timestamp->set_nanos(since_epoch_ns % std::nano::den);
return Status::OK();
}

// ActionType

Status FromProto(const pb::ActionType& pb_type, ActionType* type) {
Expand Down Expand Up @@ -153,13 +173,9 @@ Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint
RETURN_NOT_OK(FromProto(pb_endpoint.location(i), &endpoint->locations[i]));
}
if (pb_endpoint.has_expiration_time()) {
const auto& pb_expiration_time = pb_endpoint.expiration_time();
const auto seconds = std::chrono::seconds{pb_expiration_time.seconds()};
const auto nanoseconds = std::chrono::nanoseconds{pb_expiration_time.nanos()};
const auto duration =
std::chrono::duration_cast<Timestamp::duration>(seconds + nanoseconds);
const Timestamp expiration_time(duration);
endpoint->expiration_time = expiration_time;
Timestamp expiration_time;
RETURN_NOT_OK(FromProto(pb_endpoint.expiration_time(), &expiration_time));
endpoint->expiration_time = std::move(expiration_time);
}
return Status::OK();
}
Expand All @@ -171,13 +187,8 @@ Status ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint)
RETURN_NOT_OK(ToProto(location, pb_endpoint->add_location()));
}
if (endpoint.expiration_time) {
const auto expiration_time = endpoint.expiration_time.value();
const auto since_epoch = expiration_time.time_since_epoch();
const auto since_epoch_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch).count();
auto pb_expiration_time = pb_endpoint->mutable_expiration_time();
pb_expiration_time->set_seconds(since_epoch_ns / std::nano::den);
pb_expiration_time->set_nanos(since_epoch_ns % std::nano::den);
RETURN_NOT_OK(ToProto(endpoint.expiration_time.value(),
pb_endpoint->mutable_expiration_time()));
}
return Status::OK();
}
Expand Down Expand Up @@ -288,6 +299,47 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) {
return Status::OK();
}

// PollInfo

Status FromProto(const pb::PollInfo& pb_info, PollInfo* info) {
ARROW_ASSIGN_OR_RAISE(auto flight_info, FromProto(pb_info.info()));
info->info = std::make_unique<FlightInfo>(std::move(flight_info));
if (pb_info.has_flight_descriptor()) {
FlightDescriptor descriptor;
RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &descriptor));
info->descriptor = std::move(descriptor);
} else {
info->descriptor = std::nullopt;
}
if (pb_info.has_progress()) {
info->progress = pb_info.progress();
} else {
info->progress = std::nullopt;
}
if (pb_info.has_expiration_time()) {
Timestamp expiration_time;
RETURN_NOT_OK(FromProto(pb_info.expiration_time(), &expiration_time));
info->expiration_time = std::move(expiration_time);
} else {
info->expiration_time = std::nullopt;
}
return Status::OK();
}

Status ToProto(const PollInfo& info, pb::PollInfo* pb_info) {
RETURN_NOT_OK(ToProto(*info.info, pb_info->mutable_info()));
if (info.descriptor) {
RETURN_NOT_OK(ToProto(*info.descriptor, pb_info->mutable_flight_descriptor()));
}
if (info.progress) {
pb_info->set_progress(info.progress.value());
}
if (info.expiration_time) {
RETURN_NOT_OK(ToProto(*info.expiration_time, pb_info->mutable_expiration_time()));
}
return Status::OK();
}

// CancelFlightInfoRequest

Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/serialization_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Status SchemaToString(const Schema& schema, std::string* out);

// These functions depend on protobuf types which are not exported in the Flight DLL.

Status FromProto(const google::protobuf::Timestamp& pb_timestamp, Timestamp* timestamp);
Status FromProto(const pb::ActionType& pb_type, ActionType* type);
Status FromProto(const pb::Action& pb_action, Action* action);
Status FromProto(const pb::Result& pb_result, Result* result);
Expand All @@ -60,16 +61,19 @@ Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint
Status FromProto(const pb::RenewFlightEndpointRequest& pb_request,
RenewFlightEndpointRequest* request);
arrow::Result<FlightInfo> FromProto(const pb::FlightInfo& pb_info);
Status FromProto(const pb::PollInfo& pb_info, PollInfo* info);
Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request);
Status FromProto(const pb::SchemaResult& pb_result, std::string* result);
Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info);

Status ToProto(const Timestamp& timestamp, google::protobuf::Timestamp* pb_timestamp);
Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr);
Status ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint);
Status ToProto(const RenewFlightEndpointRequest& request,
pb::RenewFlightEndpointRequest* pb_request);
Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info);
Status ToProto(const PollInfo& info, pb::PollInfo* pb_info);
Status ToProto(const CancelFlightInfoRequest& request,
pb::CancelFlightInfoRequest* pb_request);
Status ToProto(const ActionType& type, pb::ActionType* pb_type);
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ Status FlightServerBase::GetFlightInfo(const ServerCallContext& context,
return Status::NotImplemented("NYI");
}

Status FlightServerBase::PollFlightInfo(const ServerCallContext& context,
const FlightDescriptor& request,
std::unique_ptr<PollInfo>* info) {
return Status::NotImplemented("NYI");
}

Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& request,
std::unique_ptr<FlightDataStream>* data_stream) {
return Status::NotImplemented("NYI");
Expand Down
Loading