From 4937adc9a2d22069bd24f5c86d6f6359efcf347b Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 18 Aug 2022 12:04:42 -0400 Subject: [PATCH] ARROW-17254: [C++][Go][Java][FlightRPC] Implement and test Flight SQL GetSchema (#13898) Consistently implements and tests the GetSchema method in Flight SQL. Builds on #13897. Authored-by: David Li Signed-off-by: David Li --- .../flight/integration_tests/CMakeLists.txt | 11 + .../flight_integration_test.cc | 60 +++ .../integration_tests/test_integration.cc | 351 +++++++++++------- cpp/src/arrow/flight/sql/client.cc | 208 ++++++++--- cpp/src/arrow/flight/sql/client.h | 83 ++++- cpp/src/arrow/flight/sql/server.cc | 79 ++++ cpp/src/arrow/flight/sql/server.h | 23 ++ cpp/src/arrow/flight/types.cc | 5 +- cpp/src/arrow/flight/types.h | 2 +- cpp/src/arrow/python/flight.cc | 5 +- go/arrow/flight/flightsql/client.go | 88 +++++ go/arrow/flight/flightsql/server.go | 60 +++ .../internal/flight_integration/scenario.go | 131 ++++++- .../integration/tests/FlightSqlScenario.java | 36 +- .../tests/FlightSqlScenarioProducer.java | 9 + .../integration/tests/IntegrationTest.java | 65 ++++ .../arrow/flight/sql/FlightSqlClient.java | 135 +++++++ .../arrow/flight/sql/FlightSqlProducer.java | 45 ++- .../arrow/flight/sql/FlightSqlUtils.java | 2 +- 19 files changed, 1173 insertions(+), 225 deletions(-) create mode 100644 cpp/src/arrow/flight/integration_tests/flight_integration_test.cc create mode 100644 java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt index 66a021b4b5975..1bbd923160642 100644 --- a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt @@ -40,3 +40,14 @@ target_link_libraries(flight-test-integration-client add_dependencies(arrow-integration flight-test-integration-client flight-test-integration-server) + +if(ARROW_BUILD_TESTS) + add_arrow_test(flight_integration_test + SOURCES + flight_integration_test.cc + test_integration.cc + STATIC_LINK_LIBS + ${ARROW_FLIGHT_INTEGRATION_TEST_LINK_LIBS} + LABELS + "arrow_flight") +endif() diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc new file mode 100644 index 0000000000000..706ac3b7d931b --- /dev/null +++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Run the integration test scenarios in-process. + +#include +#include + +#include "arrow/flight/integration_tests/test_integration.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace flight { +namespace integration_tests { + +Status RunScenario(const std::string& scenario_name) { + std::shared_ptr scenario; + ARROW_RETURN_NOT_OK(GetScenario(scenario_name, &scenario)); + + std::unique_ptr server; + ARROW_ASSIGN_OR_RAISE(Location bind_location, + arrow::flight::Location::ForGrpcTcp("0.0.0.0", 0)); + FlightServerOptions server_options(bind_location); + ARROW_RETURN_NOT_OK(scenario->MakeServer(&server, &server_options)); + ARROW_RETURN_NOT_OK(server->Init(server_options)); + + ARROW_ASSIGN_OR_RAISE(Location location, + arrow::flight::Location::ForGrpcTcp("0.0.0.0", server->port())); + auto client_options = arrow::flight::FlightClientOptions::Defaults(); + ARROW_RETURN_NOT_OK(scenario->MakeClient(&client_options)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr client, + FlightClient::Connect(location, client_options)); + ARROW_RETURN_NOT_OK(scenario->RunClient(std::move(client))); + return Status::OK(); +} + +TEST(FlightIntegration, AuthBasicProto) { ASSERT_OK(RunScenario("auth:basic_proto")); } + +TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); } + +TEST(FlightIntegration, FlightSql) { ASSERT_OK(RunScenario("flight_sql")); } + +} // namespace integration_tests +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 7bdd27da79981..b228f9cceba06 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -117,7 +117,7 @@ class AuthBasicProtoScenario : public Scenario { /// regardless of what gRPC does. class TestServerMiddleware : public ServerMiddleware { public: - explicit TestServerMiddleware(std::string received) : received_(received) {} + explicit TestServerMiddleware(std::string received) : received_(std::move(received)) {} void SendingHeaders(AddCallHeaders* outgoing_headers) override { outgoing_headers->AddHeader("x-middleware", received_); @@ -154,11 +154,11 @@ class TestClientMiddleware : public ClientMiddleware { explicit TestClientMiddleware(std::string* received_header) : received_header_(received_header) {} - void SendingHeaders(AddCallHeaders* outgoing_headers) { + void SendingHeaders(AddCallHeaders* outgoing_headers) override { outgoing_headers->AddHeader("x-middleware", "expected value"); } - void ReceivedHeaders(const CallHeaders& incoming_headers) { + void ReceivedHeaders(const CallHeaders& incoming_headers) override { // We expect the server to always send this header. gRPC/Java may // send it in trailers instead of headers, so we expect Flight to // account for this. @@ -170,7 +170,7 @@ class TestClientMiddleware : public ClientMiddleware { } } - void CallCompleted(const Status& status) {} + void CallCompleted(const Status& status) override {} private: std::string* received_header_; @@ -178,7 +178,8 @@ class TestClientMiddleware : public ClientMiddleware { class TestClientMiddlewareFactory : public ClientMiddlewareFactory { public: - void StartCall(const CallInfo& info, std::unique_ptr* middleware) { + void StartCall(const CallInfo& info, + std::unique_ptr* middleware) override { *middleware = std::unique_ptr(new TestClientMiddleware(&received_header_)); } @@ -218,8 +219,8 @@ class MiddlewareServer : public FlightServerBase { class MiddlewareScenario : public Scenario { Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) override { - options->middleware.push_back( - {"grpc_trailers", std::make_shared()}); + options->middleware.emplace_back("grpc_trailers", + std::make_shared()); server->reset(new MiddlewareServer()); return Status::OK(); } @@ -284,11 +285,13 @@ std::shared_ptr GetQuerySchema() { constexpr int64_t kUpdateStatementExpectedRows = 10000L; constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L; +constexpr char kSelectStatement[] = "SELECT STATEMENT"; template -arrow::Status AssertEq(const T& expected, const T& actual) { +arrow::Status AssertEq(const T& expected, const T& actual, const std::string& message) { if (expected != actual) { - return Status::Invalid("Expected \"", expected, "\", got \'", actual, "\""); + return Status::Invalid(message, ": expected \"", expected, "\", got \"", actual, + "\""); } return Status::OK(); } @@ -301,7 +304,9 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoStatement( const ServerCallContext& context, const sql::StatementQuery& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq("SELECT STATEMENT", command.query)); + ARROW_RETURN_NOT_OK( + AssertEq(kSelectStatement, command.query, + "Unexpected statement in GetFlightInfoStatement")); ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE")); @@ -313,6 +318,14 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { return std::unique_ptr(new FlightInfo(result)); } + arrow::Result> GetSchemaStatement( + const ServerCallContext& context, const sql::StatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq( + kSelectStatement, command.query, "Unexpected statement in GetSchemaStatement")); + return SchemaResult::Make(*GetQuerySchema()); + } + arrow::Result> DoGetStatement( const ServerCallContext& context, const sql::StatementQueryTicket& command) override { @@ -323,11 +336,21 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const ServerCallContext& context, const sql::PreparedStatementQuery& command, const FlightDescriptor& descriptor) override { ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", - command.prepared_statement_handle)); + command.prepared_statement_handle, + "Unexpected prepared statement handle")); return GetFlightInfoForCommand(descriptor, GetQuerySchema()); } + arrow::Result> GetSchemaPreparedStatement( + const ServerCallContext& context, const sql::PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle, + "Unexpected prepared statement handle")); + return SchemaResult::Make(*GetQuerySchema()); + } + arrow::Result> DoGetPreparedStatement( const ServerCallContext& context, const sql::PreparedStatementQuery& command) override { @@ -358,11 +381,14 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoSqlInfo( const ServerCallContext& context, const sql::GetSqlInfo& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size())); - ARROW_RETURN_NOT_OK(AssertEq( - sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, command.info[0])); - ARROW_RETURN_NOT_OK(AssertEq( - sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, command.info[1])); + ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size(), + "Wrong number of SqlInfo values passed")); + ARROW_RETURN_NOT_OK( + AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + command.info[0], "Unexpected SqlInfo passed")); + ARROW_RETURN_NOT_OK( + AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, + command.info[1], "Unexpected SqlInfo passed")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); } @@ -375,9 +401,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoSchemas( const ServerCallContext& context, const sql::GetDbSchemas& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value(), + "Wrong catalog passed")); ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", - command.db_schema_filter_pattern.value())); + command.db_schema_filter_pattern.value(), + "Wrong db_schema_filter_pattern passed")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetDbSchemasSchema()); } @@ -390,15 +418,22 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoTables( const ServerCallContext& context, const sql::GetTables& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value(), + "Wrong catalog passed")); ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", - command.db_schema_filter_pattern.value())); + command.db_schema_filter_pattern.value(), + "Wrong db_schema_filter_pattern passed")); ARROW_RETURN_NOT_OK(AssertEq("table_filter_pattern", - command.table_name_filter_pattern.value())); - ARROW_RETURN_NOT_OK(AssertEq(2, command.table_types.size())); - ARROW_RETURN_NOT_OK(AssertEq("table", command.table_types[0])); - ARROW_RETURN_NOT_OK(AssertEq("view", command.table_types[1])); - ARROW_RETURN_NOT_OK(AssertEq(true, command.include_schema)); + command.table_name_filter_pattern.value(), + "Wrong table_filter_pattern passed")); + ARROW_RETURN_NOT_OK(AssertEq(2, command.table_types.size(), + "Wrong number of table types passed")); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_types[0], + "Wrong table type passed")); + ARROW_RETURN_NOT_OK( + AssertEq("view", command.table_types[1], "Wrong table type passed")); + ARROW_RETURN_NOT_OK( + AssertEq(true, command.include_schema, "include_schema should be true")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); @@ -422,11 +457,12 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoPrimaryKeys( const ServerCallContext& context, const sql::GetPrimaryKeys& command, const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq( + "catalog", command.table_ref.catalog.value(), "Wrong catalog passed")); + ARROW_RETURN_NOT_OK(AssertEq( + "db_schema", command.table_ref.db_schema.value(), "Wrong db_schema passed")); ARROW_RETURN_NOT_OK( - AssertEq("catalog", command.table_ref.catalog.value())); - ARROW_RETURN_NOT_OK( - AssertEq("db_schema", command.table_ref.db_schema.value())); - ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + AssertEq("table", command.table_ref.table, "Wrong table passed")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetPrimaryKeysSchema()); } @@ -439,11 +475,12 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoExportedKeys( const ServerCallContext& context, const sql::GetExportedKeys& command, const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq( + "catalog", command.table_ref.catalog.value(), "Wrong catalog passed")); + ARROW_RETURN_NOT_OK(AssertEq( + "db_schema", command.table_ref.db_schema.value(), "Wrong db_schema passed")); ARROW_RETURN_NOT_OK( - AssertEq("catalog", command.table_ref.catalog.value())); - ARROW_RETURN_NOT_OK( - AssertEq("db_schema", command.table_ref.db_schema.value())); - ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + AssertEq("table", command.table_ref.table, "Wrong table passed")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetExportedKeysSchema()); } @@ -456,11 +493,12 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoImportedKeys( const ServerCallContext& context, const sql::GetImportedKeys& command, const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq( + "catalog", command.table_ref.catalog.value(), "Wrong catalog passed")); + ARROW_RETURN_NOT_OK(AssertEq( + "db_schema", command.table_ref.db_schema.value(), "Wrong db_schema passed")); ARROW_RETURN_NOT_OK( - AssertEq("catalog", command.table_ref.catalog.value())); - ARROW_RETURN_NOT_OK( - AssertEq("db_schema", command.table_ref.db_schema.value())); - ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + AssertEq("table", command.table_ref.table, "Wrong table passed")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetImportedKeysSchema()); } @@ -473,16 +511,20 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoCrossReference( const ServerCallContext& context, const sql::GetCrossReference& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK( - AssertEq("pk_catalog", command.pk_table_ref.catalog.value())); - ARROW_RETURN_NOT_OK( - AssertEq("pk_db_schema", command.pk_table_ref.db_schema.value())); - ARROW_RETURN_NOT_OK(AssertEq("pk_table", command.pk_table_ref.table)); - ARROW_RETURN_NOT_OK( - AssertEq("fk_catalog", command.fk_table_ref.catalog.value())); - ARROW_RETURN_NOT_OK( - AssertEq("fk_db_schema", command.fk_table_ref.db_schema.value())); - ARROW_RETURN_NOT_OK(AssertEq("fk_table", command.fk_table_ref.table)); + ARROW_RETURN_NOT_OK(AssertEq( + "pk_catalog", command.pk_table_ref.catalog.value(), "Wrong pk catalog passed")); + ARROW_RETURN_NOT_OK(AssertEq("pk_db_schema", + command.pk_table_ref.db_schema.value(), + "Wrong pk db_schema passed")); + ARROW_RETURN_NOT_OK(AssertEq("pk_table", command.pk_table_ref.table, + "Wrong pk table passed")); + ARROW_RETURN_NOT_OK(AssertEq( + "fk_catalog", command.fk_table_ref.catalog.value(), "Wrong fk catalog passed")); + ARROW_RETURN_NOT_OK(AssertEq("fk_db_schema", + command.fk_table_ref.db_schema.value(), + "Wrong fk db_schema passed")); + ARROW_RETURN_NOT_OK(AssertEq("fk_table", command.fk_table_ref.table, + "Wrong fk table passed")); return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); } @@ -494,7 +536,9 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result DoPutCommandStatementUpdate( const ServerCallContext& context, const sql::StatementUpdate& command) override { - ARROW_RETURN_NOT_OK(AssertEq("UPDATE STATEMENT", command.query)); + ARROW_RETURN_NOT_OK( + AssertEq("UPDATE STATEMENT", command.query, + "Wrong query for DoPutCommandStatementUpdate")); return kUpdateStatementExpectedRows; } @@ -502,9 +546,10 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result CreatePreparedStatement( const ServerCallContext& context, const sql::ActionCreatePreparedStatementRequest& request) override { - ARROW_RETURN_NOT_OK( - AssertEq(true, request.query == "SELECT PREPARED STATEMENT" || - request.query == "UPDATE PREPARED STATEMENT")); + if (request.query != "SELECT PREPARED STATEMENT" && + request.query != "UPDATE PREPARED STATEMENT") { + return Status::Invalid("Unexpected query: ", request.query); + } sql::ActionCreatePreparedStatementResult result; result.prepared_statement_handle = request.query + " HANDLE"; @@ -515,6 +560,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { Status ClosePreparedStatement( const ServerCallContext& context, const sql::ActionClosePreparedStatementRequest& request) override { + if (request.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" && + request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE") { + return Status::Invalid("Invalid handle for ClosePreparedStatement: ", + request.prepared_statement_handle); + } return Status::OK(); } @@ -522,11 +572,14 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const sql::PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) override { - ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", - command.prepared_statement_handle)); + if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE") { + return Status::Invalid("Invalid handle for DoPutPreparedStatementQuery: ", + command.prepared_statement_handle); + } ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); - ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema)); + ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema, + "Wrong schema for DoPutPreparedStatementQuery")); return Status::OK(); } @@ -534,10 +587,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result DoPutPreparedStatementUpdate( const ServerCallContext& context, const sql::PreparedStatementUpdate& command, FlightMessageReader* reader) override { - ARROW_RETURN_NOT_OK(AssertEq("UPDATE PREPARED STATEMENT HANDLE", - command.prepared_statement_handle)); - - return kUpdatePreparedStatementExpectedRows; + if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE") { + return kUpdatePreparedStatementExpectedRows; + } + return Status::Invalid("Invalid handle for DoPutPreparedStatementUpdate: ", + command.prepared_statement_handle); } private: @@ -569,19 +623,27 @@ class FlightSqlScenario : public Scenario { Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } - Status Validate(std::shared_ptr expected_schema, - arrow::Result> flight_info_result, - sql::FlightSqlClient* sql_client) { + Status Validate(const std::shared_ptr& expected_schema, + const FlightInfo& flight_info, sql::FlightSqlClient* sql_client) { FlightCallOptions call_options; - - ARROW_ASSIGN_OR_RAISE(auto flight_info, flight_info_result); ARROW_ASSIGN_OR_RAISE( - auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); - + std::unique_ptr reader, + sql_client->DoGet(call_options, flight_info.endpoints()[0].ticket)); ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + if (!expected_schema->Equals(*actual_schema, /*check_metadata=*/true)) { + return Status::Invalid("Schemas did not match. Expected:\n", *expected_schema, + "\nActual:\n", *actual_schema); + } + ARROW_RETURN_NOT_OK(reader->ToTable()); + return Status::OK(); + } - if (!actual_schema->Equals(*expected_schema, /*check_metadata=*/true)) { - return Status::Invalid("Schemas do not match. Expected:\n", *expected_schema, + Status ValidateSchema(const std::shared_ptr& expected_schema, + const SchemaResult& result) { + ipc::DictionaryMemo memo; + ARROW_ASSIGN_OR_RAISE(auto actual_schema, result.GetSchema(&memo)); + if (!expected_schema->Equals(*actual_schema, /*check_metadata=*/true)) { + return Status::Invalid("Schemas did not match. Expected:\n", *expected_schema, "\nActual:\n", *actual_schema); } return Status::OK(); @@ -589,13 +651,9 @@ class FlightSqlScenario : public Scenario { Status RunClient(std::unique_ptr client) override { sql::FlightSqlClient sql_client(std::move(client)); - ARROW_RETURN_NOT_OK(ValidateMetadataRetrieval(&sql_client)); - ARROW_RETURN_NOT_OK(ValidateStatementExecution(&sql_client)); - ARROW_RETURN_NOT_OK(ValidatePreparedStatementExecution(&sql_client)); - return Status::OK(); } @@ -613,82 +671,119 @@ class FlightSqlScenario : public Scenario { sql::TableRef pk_table_ref = {"pk_catalog", "pk_db_schema", "pk_table"}; sql::TableRef fk_table_ref = {"fk_catalog", "fk_db_schema", "fk_table"}; - ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(), - sql_client->GetCatalogs(options), sql_client)); + std::unique_ptr info; + std::unique_ptr schema; + + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetCatalogs(options)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetCatalogsSchema(options)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetCatalogsSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE( + info, sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetDbSchemasSchema(options)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetDbSchemasSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetDbSchemasSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE( + info, sql_client->GetTables(options, &catalog, &db_schema_filter_pattern, + &table_filter_pattern, true, &table_types)); + ARROW_ASSIGN_OR_RAISE(schema, + sql_client->GetTablesSchema(options, /*include_schema=*/true)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK( + ValidateSchema(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(schema, + sql_client->GetTablesSchema(options, /*include_schema=*/false)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetTablesSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetTableTypes(options)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetTableTypesSchema(options)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetTableTypesSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetTableTypesSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetPrimaryKeys(options, table_ref)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetPrimaryKeysSchema(options)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetPrimaryKeysSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetPrimaryKeysSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetExportedKeys(options, table_ref)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetExportedKeysSchema(options)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetExportedKeysSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetExportedKeysSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetImportedKeys(options, table_ref)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetImportedKeysSchema(options)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetImportedKeysSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetImportedKeysSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE( + info, sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetCrossReferenceSchema(options)); ARROW_RETURN_NOT_OK( - Validate(sql::SqlSchema::GetDbSchemasSchema(), - sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern), - sql_client)); + Validate(sql::SqlSchema::GetCrossReferenceSchema(), *info, sql_client)); ARROW_RETURN_NOT_OK( - Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), - sql_client->GetTables(options, &catalog, &db_schema_filter_pattern, - &table_filter_pattern, true, &table_types), - sql_client)); - ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetTableTypesSchema(), - sql_client->GetTableTypes(options), sql_client)); - ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetPrimaryKeysSchema(), - sql_client->GetPrimaryKeys(options, table_ref), - sql_client)); - ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetExportedKeysSchema(), - sql_client->GetExportedKeys(options, table_ref), - sql_client)); - ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetImportedKeysSchema(), - sql_client->GetImportedKeys(options, table_ref), - sql_client)); - ARROW_RETURN_NOT_OK(Validate( - sql::SqlSchema::GetCrossReferenceSchema(), - sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref), sql_client)); - ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetXdbcTypeInfoSchema(), - sql_client->GetXdbcTypeInfo(options), sql_client)); - ARROW_RETURN_NOT_OK(Validate( - sql::SqlSchema::GetSqlInfoSchema(), - sql_client->GetSqlInfo( - options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, - sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY}), - sql_client)); + ValidateSchema(sql::SqlSchema::GetCrossReferenceSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetXdbcTypeInfo(options)); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetXdbcTypeInfoSchema(options)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetXdbcTypeInfoSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetXdbcTypeInfoSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE( + info, sql_client->GetSqlInfo( + options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY})); + ARROW_ASSIGN_OR_RAISE(schema, sql_client->GetSqlInfoSchema(options)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetSqlInfoSchema(), *info, sql_client)); + ARROW_RETURN_NOT_OK(ValidateSchema(sql::SqlSchema::GetSqlInfoSchema(), *schema)); return Status::OK(); } Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) { - FlightCallOptions options; + ARROW_ASSIGN_OR_RAISE(auto info, sql_client->Execute({}, kSelectStatement)); + ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client)); - ARROW_RETURN_NOT_OK(Validate( - GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); - ARROW_ASSIGN_OR_RAISE(auto update_statement_result, - sql_client->ExecuteUpdate(options, "UPDATE STATEMENT")); - if (update_statement_result != kUpdateStatementExpectedRows) { - return Status::Invalid("Expected 'UPDATE STATEMENT' return ", - kUpdateStatementExpectedRows, ", got ", - update_statement_result); - } + ARROW_ASSIGN_OR_RAISE(auto schema, + sql_client->GetExecuteSchema({}, kSelectStatement)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(auto updated_rows, + sql_client->ExecuteUpdate({}, "UPDATE STATEMENT")); + ARROW_RETURN_NOT_OK(AssertEq(kUpdateStatementExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteUpdate")); return Status::OK(); } Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) { - FlightCallOptions options; - - ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, - sql_client->Prepare(options, "SELECT PREPARED STATEMENT")); - auto parameters = RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); - ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); - ARROW_RETURN_NOT_OK( - Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); + ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, + sql_client->Prepare({}, "SELECT PREPARED STATEMENT")); + ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(auto info, select_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(auto schema, select_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema)); ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement, - sql_client->Prepare(options, "UPDATE PREPARED STATEMENT")); - ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement_result, - update_prepared_statement->ExecuteUpdate()); - if (update_prepared_statement_result != kUpdatePreparedStatementExpectedRows) { - return Status::Invalid("Expected 'UPDATE STATEMENT' return ", - kUpdatePreparedStatementExpectedRows, ", got ", - update_prepared_statement_result); - } + sql_client->Prepare({}, "UPDATE PREPARED STATEMENT")); + ARROW_ASSIGN_OR_RAISE(auto updated_rows, update_prepared_statement->ExecuteUpdate()); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows, + "Wrong number of updated rows for prepared statement ExecuteUpdate")); ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); return Status::OK(); diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index 10ff1eea6f4cb..e299b7ceb11d4 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -36,15 +36,45 @@ namespace arrow { namespace flight { namespace sql { +namespace { +arrow::Result GetFlightDescriptorForCommand( + const google::protobuf::Message& command) { + google::protobuf::Any any; + if (!any.PackFrom(command)) { + return Status::SerializationError("Failed to pack ", command.GetTypeName()); + } + + std::string buf; + if (!any.SerializeToString(&buf)) { + return Status::SerializationError("Failed to serialize ", command.GetTypeName()); + } + return FlightDescriptor::Command(buf); +} + +arrow::Result> GetFlightInfoForCommand( + FlightSqlClient* client, const FlightCallOptions& options, + const google::protobuf::Message& command) { + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); + return client->GetFlightInfo(options, descriptor); +} + +arrow::Result> GetSchemaForCommand( + FlightSqlClient* client, const FlightCallOptions& options, + const google::protobuf::Message& command) { + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); + return client->GetSchema(options, descriptor); +} +} // namespace + FlightSqlClient::FlightSqlClient(std::shared_ptr client) : impl_(std::move(client)) {} PreparedStatement::PreparedStatement(FlightSqlClient* client, std::string handle, std::shared_ptr dataset_schema, - std::shared_ptr parameter_schema, - FlightCallOptions options) + std::shared_ptr parameter_schema) : client_(client), - options_(std::move(options)), handle_(std::move(handle)), dataset_schema_(std::move(dataset_schema)), parameter_schema_(std::move(parameter_schema)), @@ -59,30 +89,20 @@ PreparedStatement::~PreparedStatement() { } } -inline FlightDescriptor GetFlightDescriptorForCommand( - const google::protobuf::Message& command) { - google::protobuf::Any any; - any.PackFrom(command); - - const std::string& string = any.SerializeAsString(); - return FlightDescriptor::Command(string); -} - -arrow::Result> GetFlightInfoForCommand( - FlightSqlClient& client, const FlightCallOptions& options, - const google::protobuf::Message& command) { - const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command); +arrow::Result> FlightSqlClient::Execute( + const FlightCallOptions& options, const std::string& query) { + flight_sql_pb::CommandStatementQuery command; + command.set_query(query); - ARROW_ASSIGN_OR_RAISE(auto flight_info, client.GetFlightInfo(options, descriptor)); - return std::move(flight_info); + return GetFlightInfoForCommand(this, options, command); } -arrow::Result> FlightSqlClient::Execute( +arrow::Result> FlightSqlClient::GetExecuteSchema( const FlightCallOptions& options, const std::string& query) { flight_sql_pb::CommandStatementQuery command; command.set_query(query); - return GetFlightInfoForCommand(*this, options, command); + return GetSchemaForCommand(this, options, command); } arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options, @@ -90,7 +110,8 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o flight_sql_pb::CommandStatementUpdate command; command.set_query(query); - const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command); + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); std::unique_ptr writer; std::unique_ptr reader; @@ -114,8 +135,13 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o arrow::Result> FlightSqlClient::GetCatalogs( const FlightCallOptions& options) { flight_sql_pb::CommandGetCatalogs command; + return GetFlightInfoForCommand(this, options, command); +} - return GetFlightInfoForCommand(*this, options, command); +arrow::Result> FlightSqlClient::GetCatalogsSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetCatalogs command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetDbSchemas( @@ -129,7 +155,13 @@ arrow::Result> FlightSqlClient::GetDbSchemas( command.set_db_schema_filter_pattern(*db_schema_filter_pattern); } - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetDbSchemasSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetDbSchemas command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetTables( @@ -158,7 +190,14 @@ arrow::Result> FlightSqlClient::GetTables( } } - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetTablesSchema( + const FlightCallOptions& options, bool include_schema) { + flight_sql_pb::CommandGetTables command; + command.set_include_schema(include_schema); + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetPrimaryKeys( @@ -175,7 +214,13 @@ arrow::Result> FlightSqlClient::GetPrimaryKeys( command.set_table(table_ref.table); - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetPrimaryKeysSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetPrimaryKeys command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetExportedKeys( @@ -192,7 +237,13 @@ arrow::Result> FlightSqlClient::GetExportedKeys( command.set_table(table_ref.table); - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetExportedKeysSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetExportedKeys command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetImportedKeys( @@ -209,7 +260,13 @@ arrow::Result> FlightSqlClient::GetImportedKeys( command.set_table(table_ref.table); - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetImportedKeysSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetImportedKeys command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetCrossReference( @@ -233,21 +290,33 @@ arrow::Result> FlightSqlClient::GetCrossReference( } command.set_fk_table(fk_table_ref.table); - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetCrossReferenceSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetCrossReference command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetTableTypes( const FlightCallOptions& options) { flight_sql_pb::CommandGetTableTypes command; - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetTableTypesSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetTableTypes command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetXdbcTypeInfo( const FlightCallOptions& options) { flight_sql_pb::CommandGetXdbcTypeInfo command; - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetXdbcTypeInfo( @@ -256,7 +325,27 @@ arrow::Result> FlightSqlClient::GetXdbcTypeInfo( command.set_data_type(data_type); - return GetFlightInfoForCommand(*this, options, command); + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetXdbcTypeInfoSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetXdbcTypeInfo command; + return GetSchemaForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetSqlInfo( + const FlightCallOptions& options, const std::vector& sql_info) { + flight_sql_pb::CommandGetSqlInfo command; + for (const int& info : sql_info) command.add_info(info); + + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetSqlInfoSchema( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetSqlInfo command; + return GetSchemaForCommand(this, options, command); } arrow::Result> FlightSqlClient::DoGet( @@ -319,28 +408,24 @@ arrow::Result> FlightSqlClient::Prepare( auto handle = prepared_statement_result.prepared_statement_handle(); return std::make_shared(this, handle, dataset_schema, - parameter_schema, options); + parameter_schema); } -arrow::Result> PreparedStatement::Execute() { +arrow::Result> PreparedStatement::Execute( + const FlightCallOptions& options) { if (is_closed_) { return Status::Invalid("Statement already closed."); } - flight_sql_pb::CommandPreparedStatementQuery execute_query_command; - - execute_query_command.set_prepared_statement_handle(handle_); - - google::protobuf::Any any; - any.PackFrom(execute_query_command); - - const std::string& string = any.SerializeAsString(); - const FlightDescriptor descriptor = FlightDescriptor::Command(string); + flight_sql_pb::CommandPreparedStatementQuery command; + command.set_prepared_statement_handle(handle_); + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); if (parameter_binding_ && parameter_binding_->num_rows() > 0) { std::unique_ptr writer; std::unique_ptr reader; - ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, parameter_binding_->schema(), + ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(), &writer, &reader)); ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_)); @@ -350,28 +435,30 @@ arrow::Result> PreparedStatement::Execute() { ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer)); } - ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options_, descriptor)); + ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor)); return std::move(flight_info); } -arrow::Result PreparedStatement::ExecuteUpdate() { +arrow::Result PreparedStatement::ExecuteUpdate( + const FlightCallOptions& options) { if (is_closed_) { return Status::Invalid("Statement already closed."); } flight_sql_pb::CommandPreparedStatementUpdate command; command.set_prepared_statement_handle(handle_); - const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command); + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); std::unique_ptr writer; std::unique_ptr reader; if (parameter_binding_ && parameter_binding_->num_rows() > 0) { - ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, parameter_binding_->schema(), + ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(), &writer, &reader)); ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_)); } else { const std::shared_ptr schema = arrow::schema({}); - ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, schema, &writer, &reader)); + ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, &reader)); const ArrayVector columns; const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns); ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch)); @@ -406,7 +493,20 @@ std::shared_ptr PreparedStatement::parameter_schema() const { return parameter_schema_; } -Status PreparedStatement::Close() { +arrow::Result> PreparedStatement::GetSchema( + const FlightCallOptions& options) { + if (is_closed_) { + return Status::Invalid("Statement already closed"); + } + + flight_sql_pb::CommandPreparedStatementQuery command; + command.set_prepared_statement_handle(handle_); + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); + return client_->GetSchema(options, descriptor); +} + +Status PreparedStatement::Close(const FlightCallOptions& options) { if (is_closed_) { return Status::Invalid("Statement already closed."); } @@ -422,7 +522,7 @@ Status PreparedStatement::Close() { std::unique_ptr results; - ARROW_RETURN_NOT_OK(client_->DoAction(options_, action, &results)); + ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results)); is_closed_ = true; @@ -431,14 +531,6 @@ Status PreparedStatement::Close() { Status FlightSqlClient::Close() { return impl_->Close(); } -arrow::Result> FlightSqlClient::GetSqlInfo( - const FlightCallOptions& options, const std::vector& sql_info) { - flight_sql_pb::CommandGetSqlInfo command; - for (const int& info : sql_info) command.add_info(info); - - return GetFlightInfoForCommand(*this, options, command); -} - } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 7c8cb640e8d13..26315e0d234fe 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -54,6 +54,10 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> Execute(const FlightCallOptions& options, const std::string& query); + /// \brief Get the result set schema from the server. + arrow::Result> GetExecuteSchema( + const FlightCallOptions& options, const std::string& query); + /// \brief Execute an update query on the server. /// \param[in] options RPC-layer hints for this call. /// \param[in] query The query to be executed in the UTF-8 format. @@ -67,6 +71,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> GetCatalogs( const FlightCallOptions& options); + /// \brief Get the catalogs schema from the server (should be + /// identical to SqlSchema::GetCatalogsSchema). + arrow::Result> GetCatalogsSchema( + const FlightCallOptions& options); + /// \brief Request a list of database schemas. /// \param[in] options RPC-layer hints for this call. /// \param[in] catalog The catalog. @@ -76,6 +85,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { const FlightCallOptions& options, const std::string* catalog, const std::string* db_schema_filter_pattern); + /// \brief Get the database schemas schema from the server (should be + /// identical to SqlSchema::GetDbSchemasSchema). + arrow::Result> GetDbSchemasSchema( + const FlightCallOptions& options); + /// \brief Given a flight ticket and schema, request to be sent the /// stream. Returns record batch stream reader /// \param[in] options Per-RPC options @@ -99,6 +113,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { const std::string* table_filter_pattern, bool include_schema, const std::vector* table_types); + /// \brief Get the tables schema from the server (should be + /// identical to SqlSchema::GetTablesSchema). + arrow::Result> GetTablesSchema( + const FlightCallOptions& options, bool include_schema); + /// \brief Request the primary keys for a table. /// \param[in] options RPC-layer hints for this call. /// \param[in] table_ref The table reference. @@ -106,6 +125,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> GetPrimaryKeys( const FlightCallOptions& options, const TableRef& table_ref); + /// \brief Get the primary keys schema from the server (should be + /// identical to SqlSchema::GetPrimaryKeysSchema). + arrow::Result> GetPrimaryKeysSchema( + const FlightCallOptions& options); + /// \brief Retrieves a description about the foreign key columns that reference the /// primary key columns of the given table. /// \param[in] options RPC-layer hints for this call. @@ -114,6 +138,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> GetExportedKeys( const FlightCallOptions& options, const TableRef& table_ref); + /// \brief Get the exported keys schema from the server (should be + /// identical to SqlSchema::GetExportedKeysSchema). + arrow::Result> GetExportedKeysSchema( + const FlightCallOptions& options); + /// \brief Retrieves the foreign key columns for the given table. /// \param[in] options RPC-layer hints for this call. /// \param[in] table_ref The table reference. @@ -121,6 +150,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> GetImportedKeys( const FlightCallOptions& options, const TableRef& table_ref); + /// \brief Get the imported keys schema from the server (should be + /// identical to SqlSchema::GetImportedKeysSchema). + arrow::Result> GetImportedKeysSchema( + const FlightCallOptions& options); + /// \brief Retrieves a description of the foreign key columns in the given foreign key /// table that reference the primary key or the columns representing a unique /// constraint of the parent table (could be the same or a different table). @@ -132,12 +166,22 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { const FlightCallOptions& options, const TableRef& pk_table_ref, const TableRef& fk_table_ref); + /// \brief Get the cross reference schema from the server (should be + /// identical to SqlSchema::GetCrossReferenceSchema). + arrow::Result> GetCrossReferenceSchema( + const FlightCallOptions& options); + /// \brief Request a list of table types. /// \param[in] options RPC-layer hints for this call. /// \return The FlightInfo describing where to access the dataset. arrow::Result> GetTableTypes( const FlightCallOptions& options); + /// \brief Get the table types schema from the server (should be + /// identical to SqlSchema::GetTableTypesSchema). + arrow::Result> GetTableTypesSchema( + const FlightCallOptions& options); + /// \brief Request the information about all the data types supported. /// \param[in] options RPC-layer hints for this call. /// \return The FlightInfo describing where to access the dataset. @@ -151,6 +195,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> GetXdbcTypeInfo( const FlightCallOptions& options, int data_type); + /// \brief Get the type info schema from the server (should be + /// identical to SqlSchema::GetXdbcTypeInfoSchema). + arrow::Result> GetXdbcTypeInfoSchema( + const FlightCallOptions& options); + /// \brief Request a list of SQL information. /// \param[in] options RPC-layer hints for this call. /// \param[in] sql_info the SQL info required. @@ -158,6 +207,11 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> GetSqlInfo(const FlightCallOptions& options, const std::vector& sql_info); + /// \brief Get the SQL information schema from the server (should be + /// identical to SqlSchema::GetSqlInfoSchema). + arrow::Result> GetSqlInfoSchema( + const FlightCallOptions& options); + /// \brief Create a prepared statement object. /// \param[in] options RPC-layer hints for this call. /// \param[in] query The query that will be executed. @@ -165,17 +219,18 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { arrow::Result> Prepare( const FlightCallOptions& options, const std::string& query); - /// \brief Retrieve the FlightInfo. - /// \param[in] options RPC-layer hints for this call. - /// \param[in] descriptor The flight descriptor. - /// \return The flight info with the metadata. - // NOTE: This is public because it is been used by the anonymous - // function GetFlightInfoForCommand. + /// \brief Call the underlying Flight client's GetFlightInfo. virtual arrow::Result> GetFlightInfo( const FlightCallOptions& options, const FlightDescriptor& descriptor) { return impl_->GetFlightInfo(options, descriptor); } + /// \brief Call the underlying Flight client's GetSchema. + virtual arrow::Result> GetSchema( + const FlightCallOptions& options, const FlightDescriptor& descriptor) { + return impl_->GetSchema(options, descriptor); + } + /// \brief Explicitly shut down and clean up the client. Status Close(); @@ -212,10 +267,9 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// \param[in] handle Handle for this prepared statement. /// \param[in] dataset_schema Schema of the resulting dataset. /// \param[in] parameter_schema Schema of the parameters (if any). - /// \param[in] options RPC-layer hints for this call. PreparedStatement(FlightSqlClient* client, std::string handle, std::shared_ptr dataset_schema, - std::shared_ptr parameter_schema, FlightCallOptions options); + std::shared_ptr parameter_schema); /// \brief Default destructor for the PreparedStatement class. /// The destructor will call the Close method from the class in order, @@ -226,11 +280,12 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// \brief Executes the prepared statement query on the server. /// \return A FlightInfo object representing the stream(s) to fetch. - arrow::Result> Execute(); + arrow::Result> Execute( + const FlightCallOptions& options = {}); /// \brief Executes the prepared statement update query on the server. /// \return The number of rows affected. - arrow::Result ExecuteUpdate(); + arrow::Result ExecuteUpdate(const FlightCallOptions& options = {}); /// \brief Retrieve the parameter schema from the query. /// \return The parameter schema from the query. @@ -245,10 +300,15 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// \return Status. Status SetParameters(std::shared_ptr parameter_binding); + /// \brief Re-request the result set schema from the server (should + /// be identical to dataset_schema). + arrow::Result> GetSchema( + const FlightCallOptions& options = {}); + /// \brief Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. /// \return Status. - Status Close(); + Status Close(const FlightCallOptions& options = {}); /// \brief Check if the prepared statement is closed. /// \return The state of the prepared statement. @@ -256,7 +316,6 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { private: FlightSqlClient* client_; - FlightCallOptions options_; std::string handle_; std::shared_ptr dataset_schema_; std::shared_ptr parameter_schema_; diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 0ebe647ba1490..78fbff0c33a4e 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -344,6 +344,72 @@ Status FlightSqlServerBase::GetFlightInfo(const ServerCallContext& context, return Status::Invalid("The defined request is invalid."); } +Status FlightSqlServerBase::GetSchema(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* schema) { + google::protobuf::Any any; + if (!any.ParseFromArray(request.cmd.data(), static_cast(request.cmd.size()))) { + return Status::Invalid("Unable to parse command"); + } + + if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementQuery internal_command, + ParseCommandStatementQuery(any)); + ARROW_ASSIGN_OR_RAISE(*schema, + GetSchemaStatement(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, + ParseCommandPreparedStatementQuery(any)); + ARROW_ASSIGN_OR_RAISE(*schema, + GetSchemaPreparedStatement(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetCatalogsSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, + SchemaResult::Make(*SqlSchema::GetCrossReferenceSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetDbSchemasSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, + SchemaResult::Make(*SqlSchema::GetExportedKeysSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, + SchemaResult::Make(*SqlSchema::GetImportedKeysSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, + SchemaResult::Make(*SqlSchema::GetPrimaryKeysSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetSqlInfoSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetTables command, ParseCommandGetTables(any)); + if (command.include_schema) { + ARROW_ASSIGN_OR_RAISE( + *schema, SchemaResult::Make(*SqlSchema::GetTablesSchemaWithIncludedSchema())); + } else { + ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetTablesSchema())); + } + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, SchemaResult::Make(*SqlSchema::GetTableTypesSchema())); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*schema, + SchemaResult::Make(*SqlSchema::GetXdbcTypeInfoSchema())); + return Status::OK(); + } + + return Status::NotImplemented("Command not recognized: ", any.type_url()); +} + Status FlightSqlServerBase::DoGet(const ServerCallContext& context, const Ticket& request, std::unique_ptr* stream) { google::protobuf::Any any; @@ -531,6 +597,12 @@ arrow::Result> FlightSqlServerBase::GetFlightInfoSta return Status::NotImplemented("GetFlightInfoStatement not implemented"); } +arrow::Result> FlightSqlServerBase::GetSchemaStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetSchemaStatement not implemented"); +} + arrow::Result> FlightSqlServerBase::DoGetStatement( const ServerCallContext& context, const StatementQueryTicket& command) { return Status::NotImplemented("DoGetStatement not implemented"); @@ -543,6 +615,13 @@ FlightSqlServerBase::GetFlightInfoPreparedStatement(const ServerCallContext& con return Status::NotImplemented("GetFlightInfoPreparedStatement not implemented"); } +arrow::Result> +FlightSqlServerBase::GetSchemaPreparedStatement(const ServerCallContext& context, + const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetSchemaPreparedStatement not implemented"); +} + arrow::Result> FlightSqlServerBase::DoGetPreparedStatement(const ServerCallContext& context, const PreparedStatementQuery& command) { diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index f077c5d5d5d1f..49e239a0cddd4 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -28,6 +28,7 @@ #include "arrow/flight/sql/server.h" #include "arrow/flight/sql/types.h" #include "arrow/flight/sql/visibility.h" +#include "arrow/flight/types.h" #include "arrow/util/optional.h" namespace arrow { @@ -221,6 +222,25 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { virtual arrow::Result> GetFlightInfoCatalogs( const ServerCallContext& context, const FlightDescriptor& descriptor); + /// \brief Get the schema of the result set of a query. + /// \param[in] context Per-call context. + /// \param[in] command The StatementQuery containing the SQL query. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The schema of the result set. + virtual arrow::Result> GetSchemaStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor); + + /// \brief Get the schema of the result set of a prepared statement. + /// \param[in] context Per-call context. + /// \param[in] command The PreparedStatementQuery containing the + /// prepared statement handle. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The schema of the result set. + virtual arrow::Result> GetSchemaPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor); + /// \brief Get a FlightDataStream containing the list of catalogs. /// \param[in] context Per-call context. /// \return An interface for sending data back to the client. @@ -462,6 +482,9 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, std::unique_ptr* info) final; + Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* schema) override; + Status DoGet(const ServerCallContext& context, const Ticket& request, std::unique_ptr* stream) final; diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index ddb8a036fbc42..6e80f40cfbf38 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -28,6 +28,7 @@ #include "arrow/ipc/reader.h" #include "arrow/status.h" #include "arrow/table.h" +#include "arrow/util/make_unique.h" #include "arrow/util/string_view.h" #include "arrow/util/uri.h" @@ -150,10 +151,10 @@ arrow::Result> SchemaResult::GetSchema( return ipc::ReadSchema(&schema_reader, dictionary_memo); } -arrow::Result SchemaResult::Make(const Schema& schema) { +arrow::Result> SchemaResult::Make(const Schema& schema) { std::string schema_in; RETURN_NOT_OK(internal::SchemaToString(schema, &schema_in)); - return SchemaResult(std::move(schema_in)); + return arrow::internal::make_unique(std::move(schema_in)); } Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index a061f33afec0b..2ec24ff586851 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -397,7 +397,7 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {} /// \brief Factory method to construct a SchemaResult. - static arrow::Result Make(const Schema& schema); + static arrow::Result> Make(const Schema& schema); /// \brief return schema /// \param[in,out] dictionary_memo for dictionary bookkeeping, will diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index 9077bbe4acb7d..bf7af27ac726e 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -380,10 +380,7 @@ Status CreateFlightInfo(const std::shared_ptr& schema, Status CreateSchemaResult(const std::shared_ptr& schema, std::unique_ptr* out) { - ARROW_ASSIGN_OR_RAISE(auto result, arrow::flight::SchemaResult::Make(*schema)); - *out = std::unique_ptr( - new arrow::flight::SchemaResult(std::move(result))); - return Status::OK(); + return arrow::flight::SchemaResult::Make(*schema).Value(out); } } // namespace flight diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go index 5f7f693d2b2f7..b8ee01cfdeaab 100644 --- a/go/arrow/flight/flightsql/client.go +++ b/go/arrow/flight/flightsql/client.go @@ -77,6 +77,14 @@ func flightInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, op return cl.getFlightInfo(ctx, desc, opts...) } +func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + desc, err := descForCommand(cmd) + if err != nil { + return nil, err + } + return cl.getSchema(ctx, desc, opts...) +} + // Execute executes the desired query on the server and returns a FlightInfo // object describing where to retrieve the results. func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) { @@ -84,6 +92,13 @@ func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOpt return flightInfoForCommand(ctx, c, &cmd, opts...) } +// GetExecuteSchema gets the schema of the result set of a query without +// executing the query itself. +func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + cmd := pb.CommandStatementQuery{Query: query} + return schemaForCommand(ctx, c, &cmd, opts...) +} + // ExecuteUpdate is for executing an update query and only returns the number of affected rows. func (c *Client) ExecuteUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { var ( @@ -128,12 +143,22 @@ func (c *Client) GetCatalogs(ctx context.Context, opts ...grpc.CallOption) (*fli return flightInfoForCommand(ctx, c, &pb.CommandGetCatalogs{}, opts...) } +// GetCatalogsSchema requests the schema of GetCatalogs from the server +func (c *Client) GetCatalogsSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetCatalogs{}, opts...) +} + // GetDBSchemas requests the list of schemas from the database and // returns a FlightInfo object where the response can be retrieved func (c *Client) GetDBSchemas(ctx context.Context, cmdOpts *GetDBSchemasOpts, opts ...grpc.CallOption) (*flight.FlightInfo, error) { return flightInfoForCommand(ctx, c, (*pb.CommandGetDbSchemas)(cmdOpts), opts...) } +// GetDBSchemasSchema requests the schema of GetDBSchemas from the server +func (c *Client) GetDBSchemasSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetDbSchemas{}, opts...) +} + // DoGet uses the provided flight ticket to request the stream of data. // It returns a recordbatch reader to stream the results. Release // should be called on the reader when done. @@ -154,6 +179,11 @@ func (c *Client) GetTables(ctx context.Context, reqOptions *GetTablesOpts, opts return flightInfoForCommand(ctx, c, (*pb.CommandGetTables)(reqOptions), opts...) } +// GetTablesSchema requests the schema of GetTables from the server. +func (c *Client) GetTablesSchema(ctx context.Context, reqOptions *GetTablesOpts, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, (*pb.CommandGetTables)(reqOptions), opts...) +} + // GetPrimaryKeys requests the primary keys for a specific table from the // server, specified using a TableRef. Returns a FlightInfo object where // the response can be retrieved. @@ -166,6 +196,11 @@ func (c *Client) GetPrimaryKeys(ctx context.Context, ref TableRef, opts ...grpc. return flightInfoForCommand(ctx, c, &cmd, opts...) } +// GetPrimaryKeysSchema requests the schema of GetPrimaryKeys from the server. +func (c *Client) GetPrimaryKeysSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetPrimaryKeys{}, opts...) +} + // GetExportedKeys retrieves a description about the foreign key columns // that reference the primary key columns of the specified table. Returns // a FlightInfo object where the response can be retrieved. @@ -178,6 +213,11 @@ func (c *Client) GetExportedKeys(ctx context.Context, ref TableRef, opts ...grpc return flightInfoForCommand(ctx, c, &cmd, opts...) } +// GetExportedKeysSchema requests the schema of GetExportedKeys from the server. +func (c *Client) GetExportedKeysSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetExportedKeys{}, opts...) +} + // GetImportedKeys returns the foreign key columns for the specified table. // Returns a FlightInfo object indicating where the response can be retrieved. func (c *Client) GetImportedKeys(ctx context.Context, ref TableRef, opts ...grpc.CallOption) (*flight.FlightInfo, error) { @@ -189,6 +229,11 @@ func (c *Client) GetImportedKeys(ctx context.Context, ref TableRef, opts ...grpc return flightInfoForCommand(ctx, c, &cmd, opts...) } +// GetImportedKeysSchema requests the schema of GetImportedKeys from the server. +func (c *Client) GetImportedKeysSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetImportedKeys{}, opts...) +} + // GetCrossReference retrieves a description of the foreign key columns // in the specified ForeignKey table that reference the primary key or // columns representing a restraint of the parent table (could be the same @@ -206,6 +251,11 @@ func (c *Client) GetCrossReference(ctx context.Context, pkTable, fkTable TableRe return flightInfoForCommand(ctx, c, &cmd, opts...) } +// GetCrossReferenceSchema requests the schema of GetCrossReference from the server. +func (c *Client) GetCrossReferenceSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetCrossReference{}, opts...) +} + // GetTableTypes requests a list of the types of tables available on this // server. Returns a FlightInfo object indicating where the response can // be retrieved. @@ -213,6 +263,11 @@ func (c *Client) GetTableTypes(ctx context.Context, opts ...grpc.CallOption) (*f return flightInfoForCommand(ctx, c, &pb.CommandGetTableTypes{}, opts...) } +// GetTableTypesSchema requests the schema of GetTableTypes from the server. +func (c *Client) GetTableTypesSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetTableTypes{}, opts...) +} + // GetXdbcTypeInfo requests the information about all the data types supported // (dataType == nil) or a specific data type. Returns a FlightInfo object // indicating where the response can be retrieved. @@ -220,6 +275,11 @@ func (c *Client) GetXdbcTypeInfo(ctx context.Context, dataType *int32, opts ...g return flightInfoForCommand(ctx, c, &pb.CommandGetXdbcTypeInfo{DataType: dataType}, opts...) } +// GetXdbcTypeInfoSchema requests the schema of GetXdbcTypeInfo from the server. +func (c *Client) GetXdbcTypeInfoSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetXdbcTypeInfo{}, opts...) +} + // GetSqlInfo returns a list of the requested SQL information corresponding // to the values in the info slice. Returns a FlightInfo object indicating // where the response can be retrieved. @@ -232,6 +292,11 @@ func (c *Client) GetSqlInfo(ctx context.Context, info []SqlInfo, opts ...grpc.Ca return flightInfoForCommand(ctx, c, cmd, opts...) } +// GetSqlInfoSchema requests the schema of GetSqlInfo from the server. +func (c *Client) GetSqlInfoSchema(ctx context.Context, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return schemaForCommand(ctx, c, &pb.CommandGetSqlInfo{}, opts...) +} + // Prepare creates a PreparedStatement object for the specified query. // The resulting PreparedStatement object should be Closed when no longer // needed. It will maintain a reference to this Client for use to execute @@ -302,6 +367,10 @@ func (c *Client) getFlightInfo(ctx context.Context, desc *flight.FlightDescripto return c.Client.GetFlightInfo(ctx, desc, opts...) } +func (c *Client) getSchema(ctx context.Context, desc *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.SchemaResult, error) { + return c.Client.GetSchema(ctx, desc, opts...) +} + // Close will close the underlying flight Client in use by this flightsql.Client func (c *Client) Close() error { return c.Client.Close() } @@ -430,6 +499,25 @@ func (p *PreparedStatement) DatasetSchema() *arrow.Schema { return p.datasetSche // the prepared statement. func (p *PreparedStatement) ParameterSchema() *arrow.Schema { return p.paramSchema } +// GetSchema re-requests the schema of the result set of the prepared +// statement from the server. It should otherwise be identical to DatasetSchema. +// +// Will error if already closed. +func (p *PreparedStatement) GetSchema(ctx context.Context) (*flight.SchemaResult, error) { + if p.closed { + return nil, errors.New("arrow/flightsql: prepared statement already closed") + } + + cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle} + + desc, err := descForCommand(cmd) + if err != nil { + return nil, err + } + + return p.client.getSchema(ctx, desc, p.opts...) +} + // SetParameters takes a record batch to send as the parameter bindings when // executing. It should match the schema from ParameterSchema. // diff --git a/go/arrow/flight/flightsql/server.go b/go/arrow/flight/flightsql/server.go index 17bc9e188aa9c..8080df9e4bded 100644 --- a/go/arrow/flight/flightsql/server.go +++ b/go/arrow/flight/flightsql/server.go @@ -181,6 +181,10 @@ func (BaseServer) GetFlightInfoStatement(context.Context, StatementQuery, *fligh return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoStatement not implemented") } +func (BaseServer) GetSchemaStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) { + return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") +} + func (BaseServer) DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetStatement not implemented") } @@ -189,6 +193,10 @@ func (BaseServer) GetFlightInfoPreparedStatement(context.Context, PreparedStatem return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoPreparedStatement not implemented") } +func (BaseServer) GetSchemaPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) { + return nil, status.Errorf(codes.Unimplemented, "GetSchemaPreparedStatement not implemented") +} + func (BaseServer) DoGetPreparedStatement(context.Context, PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) { return nil, nil, status.Errorf(codes.Unimplemented, "DoGetPreparedStatement not implemented") } @@ -367,12 +375,17 @@ func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatemen type Server interface { // GetFlightInfoStatement returns a FlightInfo for executing the requested sql query GetFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error) + // GetFlightInfoStatement returns the schema of the result set of the requested sql query + GetSchemaStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) // DoGetStatement returns a stream containing the query results for the // requested statement handle that was populated by GetFlightInfoStatement DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) // GetFlightInfoPreparedStatement returns a FlightInfo for executing an already // prepared statement with the provided statement handle. GetFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error) + // GetSchemaPreparedStatement returns the schema of the result set of executing an already + // prepared statement with the provided statement handle. + GetSchemaPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.SchemaResult, error) // DoGetPreparedStatement returns a stream containing the results from executing // a prepared statement query with the provided statement handle. DoGetPreparedStatement(context.Context, PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) @@ -519,6 +532,53 @@ func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.Fli return nil, status.Error(codes.InvalidArgument, "requested command is invalid") } +func (f *flightSqlServer) GetSchema(ctx context.Context, request *flight.FlightDescriptor) (*flight.SchemaResult, error) { + var ( + anycmd anypb.Any + cmd proto.Message + err error + ) + if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "unable to parse command: %s", err.Error()) + } + + if cmd, err = anycmd.UnmarshalNew(); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "could not unmarshal Any to a command type: %s", err.Error()) + } + + switch cmd := cmd.(type) { + case *pb.CommandStatementQuery: + return f.srv.GetSchemaStatement(ctx, cmd, request) + case *pb.CommandPreparedStatementQuery: + return f.srv.GetSchemaPreparedStatement(ctx, cmd, request) + case *pb.CommandGetCatalogs: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.Catalogs, f.mem)}, nil + case *pb.CommandGetDbSchemas: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.DBSchemas, f.mem)}, nil + case *pb.CommandGetTables: + if cmd.GetIncludeSchema() { + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.TablesWithIncludedSchema, f.mem)}, nil + } + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.Tables, f.mem)}, nil + case *pb.CommandGetTableTypes: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.TableTypes, f.mem)}, nil + case *pb.CommandGetXdbcTypeInfo: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.XdbcTypeInfo, f.mem)}, nil + case *pb.CommandGetSqlInfo: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.SqlInfo, f.mem)}, nil + case *pb.CommandGetPrimaryKeys: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.PrimaryKeys, f.mem)}, nil + case *pb.CommandGetExportedKeys: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.ExportedKeys, f.mem)}, nil + case *pb.CommandGetImportedKeys: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.ImportedKeys, f.mem)}, nil + case *pb.CommandGetCrossReference: + return &flight.SchemaResult{Schema: flight.SerializeSchema(schema_ref.CrossReference, f.mem)}, nil + } + + return nil, status.Errorf(codes.InvalidArgument, "requested command is invalid: %s", anycmd.GetTypeUrl()) +} + func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightService_DoGetServer) (err error) { var ( anycmd anypb.Any diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go index c89334002d1da..4e96d7100abb6 100644 --- a/go/arrow/internal/flight_integration/scenario.go +++ b/go/arrow/internal/flight_integration/scenario.go @@ -599,6 +599,22 @@ func (m *flightSqlScenarioTester) validate(expected *arrow.Schema, result *fligh if !expected.Equal(rdr.Schema()) { return fmt.Errorf("expected: %s, got: %s", expected, rdr.Schema()) } + for { + _, err := rdr.Read() + if err == io.EOF { break } + if err != nil { return err } + } + return nil +} + +func (m *flightSqlScenarioTester) validateSchema(expected *arrow.Schema, result *flight.SchemaResult) error { + schema, err := flight.DeserializeSchema(result.GetSchema(), memory.DefaultAllocator) + if err != nil { + return err + } + if !expected.Equal(schema) { + return fmt.Errorf("expected: %s, got: %s", expected, schema) + } return nil } @@ -626,6 +642,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err := client.GetCatalogsSchema(ctx) + if err != nil { + return err + } + if err := m.validateSchema(schema_ref.Catalogs, schema); err != nil { + return err + } + info, err = client.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{Catalog: &catalog, DbSchemaFilterPattern: &dbSchemaFilterPattern}) if err != nil { return err @@ -634,6 +658,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetDBSchemasSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.DBSchemas, schema); err != nil { + return err + } + info, err = client.GetTables(ctx, &flightsql.GetTablesOpts{Catalog: &catalog, DbSchemaFilterPattern: &dbSchemaFilterPattern, TableNameFilterPattern: &tableFilterPattern, IncludeSchema: true, TableTypes: tableTypes}) if err != nil { return err @@ -642,6 +674,22 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetTablesSchema(ctx, &flightsql.GetTablesOpts{IncludeSchema: true}) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.TablesWithIncludedSchema, schema); err != nil { + return err + } + + schema, err = client.GetTablesSchema(ctx, &flightsql.GetTablesOpts{IncludeSchema: false}) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.Tables, schema); err != nil { + return err + } + info, err = client.GetTableTypes(ctx) if err != nil { return err @@ -650,6 +698,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetTableTypesSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.TableTypes, schema); err != nil { + return err + } + info, err = client.GetPrimaryKeys(ctx, ref) if err != nil { return err @@ -658,6 +714,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetPrimaryKeysSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.PrimaryKeys, schema); err != nil { + return err + } + info, err = client.GetExportedKeys(ctx, ref) if err != nil { return err @@ -666,6 +730,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetExportedKeysSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.ExportedKeys, schema); err != nil { + return err + } + info, err = client.GetImportedKeys(ctx, ref) if err != nil { return err @@ -674,6 +746,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetImportedKeysSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.ImportedKeys, schema); err != nil { + return err + } + info, err = client.GetCrossReference(ctx, pkRef, fkRef) if err != nil { return err @@ -682,6 +762,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetCrossReferenceSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.CrossReference, schema); err != nil { + return err + } + info, err = client.GetXdbcTypeInfo(ctx, nil) if err != nil { return err @@ -690,6 +778,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetXdbcTypeInfoSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.XdbcTypeInfo, schema); err != nil { + return err + } + info, err = client.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerName, flightsql.SqlInfoFlightSqlServerReadOnly}) if err != nil { return err @@ -698,6 +794,14 @@ func (m *flightSqlScenarioTester) ValidateMetadataRetrieval(client *flightsql.Cl return err } + schema, err = client.GetSqlInfoSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(schema_ref.SqlInfo, schema); err != nil { + return err + } + return nil } @@ -711,6 +815,14 @@ func (m *flightSqlScenarioTester) ValidateStatementExecution(client *flightsql.C return err } + schema, err := client.GetExecuteSchema(ctx, "SELECT STATEMENT") + if err != nil { + return err + } + if err = m.validateSchema(QuerySchema, schema); err != nil { + return err + } + updateResult, err := client.ExecuteUpdate(ctx, "UPDATE STATEMENT") if err != nil { return err @@ -740,6 +852,13 @@ func (m *flightSqlScenarioTester) ValidatePreparedStatementExecution(client *fli if err = m.validate(QuerySchema, info, client); err != nil { return err } + schema, err := prepared.GetSchema(ctx) + if err != nil { + return err + } + if err = m.validateSchema(QuerySchema, schema); err != nil { + return err + } if err = prepared.Close(ctx); err != nil { return err @@ -762,9 +881,7 @@ func (m *flightSqlScenarioTester) ValidatePreparedStatementExecution(client *fli func (m *flightSqlScenarioTester) doGetForTestCase(schema *arrow.Schema) chan flight.StreamChunk { ch := make(chan flight.StreamChunk) - go func() { - ch <- flight.StreamChunk{Data: array.NewRecord(schema, []arrow.Array{}, 0)} - }() + close(ch) return ch } @@ -789,6 +906,10 @@ func (m *flightSqlScenarioTester) GetFlightInfoStatement(ctx context.Context, cm }, nil } +func (m *flightSqlScenarioTester) GetSchemaStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) { + return &flight.SchemaResult{Schema: flight.SerializeSchema(QuerySchema, memory.DefaultAllocator)}, nil +} + func (m *flightSqlScenarioTester) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { return QuerySchema, m.doGetForTestCase(QuerySchema), nil } @@ -801,6 +922,10 @@ func (m *flightSqlScenarioTester) GetFlightInfoPreparedStatement(_ context.Conte return m.flightInfoForCommand(desc, QuerySchema), nil } +func (m *flightSqlScenarioTester) GetSchemaPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.SchemaResult, error) { + return &flight.SchemaResult{Schema: flight.SerializeSchema(QuerySchema, memory.DefaultAllocator)}, nil +} + func (m *flightSqlScenarioTester) DoGetPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) { return QuerySchema, m.doGetForTestCase(QuerySchema), nil } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java index cf17349064cb2..19c1378cfe6c5 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -26,6 +26,7 @@ import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.SchemaResult; import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlProducer; @@ -72,32 +73,52 @@ private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Excepti validate(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogs(options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogsSchema(options)); + validate(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA, sqlClient.getSchemas("catalog", "db_schema_filter_pattern", options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA, sqlClient.getSchemasSchema()); + validate(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA, sqlClient.getTables("catalog", "db_schema_filter_pattern", "table_filter_pattern", Arrays.asList("table", "view"), true, options), sqlClient); - validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options), - sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA, + sqlClient.getTablesSchema(/*includeSchema*/true, options)); + validateSchema(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, + sqlClient.getTablesSchema(/*includeSchema*/false, options)); + + validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypesSchema(options)); + validate(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA, sqlClient.getPrimaryKeys(TableRef.of("catalog", "db_schema", "table"), options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA, sqlClient.getPrimaryKeysSchema(options)); + validate(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA, sqlClient.getExportedKeys(TableRef.of("catalog", "db_schema", "table"), options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA, sqlClient.getExportedKeysSchema(options)); + validate(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA, sqlClient.getImportedKeys(TableRef.of("catalog", "db_schema", "table"), options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA, sqlClient.getImportedKeysSchema(options)); + validate(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA, sqlClient.getCrossReference(TableRef.of("pk_catalog", "pk_db_schema", "pk_table"), TableRef.of("fk_catalog", "fk_db_schema", "fk_table"), options), sqlClient); - validate(FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, - sqlClient.getXdbcTypeInfo(options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA, sqlClient.getCrossReferenceSchema(options)); + + validate(FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, sqlClient.getXdbcTypeInfo(options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, sqlClient.getXdbcTypeInfoSchema(options)); + validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlClient.getSqlInfo(new FlightSql.SqlInfo[] {FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY}, options), sqlClient); + validateSchema(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlClient.getSqlInfoSchema(options)); } private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { @@ -105,6 +126,8 @@ private void validateStatementExecution(FlightSqlClient sqlClient) throws Except validate(FlightSqlScenarioProducer.getQuerySchema(), sqlClient.execute("SELECT STATEMENT", options), sqlClient); + validateSchema(FlightSqlScenarioProducer.getQuerySchema(), + sqlClient.getExecuteSchema("SELECT STATEMENT", options)); IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), UPDATE_STATEMENT_EXPECTED_ROWS); @@ -122,6 +145,7 @@ private void validatePreparedStatementExecution(FlightSqlClient sqlClient, validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), sqlClient); + validateSchema(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.fetchSchema()); } try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( @@ -139,4 +163,8 @@ private void validate(Schema expectedSchema, FlightInfo flightInfo, IntegrationAssertions.assertEquals(expectedSchema, actualSchema); } } + + private void validateSchema(Schema expected, SchemaResult actual) { + IntegrationAssertions.assertEquals(expected, actual.getSchema()); + } } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java index 7db99187c466e..33d62b650e176 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -125,9 +125,18 @@ public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatem return getFlightInfoForSchema(command, descriptor, getQuerySchema()); } + @Override + public SchemaResult getSchemaPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "SELECT PREPARED STATEMENT HANDLE"); + return new SchemaResult(getQuerySchema()); + } + @Override public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); return new SchemaResult(getQuerySchema()); } diff --git a/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java new file mode 100644 index 0000000000000..dfb9a810857ba --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.Test; + +/** + * Run the integration test scenarios in-process. + */ +class IntegrationTest { + @Test + void authBasicProto() throws Exception { + testScenario("auth:basic_proto"); + } + + @Test + void middleware() throws Exception { + testScenario("middleware"); + } + + @Test + void flightSql() throws Exception { + testScenario("flight_sql"); + } + + void testScenario(String scenarioName) throws Exception { + try (final BufferAllocator allocator = new RootAllocator()) { + final FlightServer.Builder builder = FlightServer.builder() + .allocator(allocator) + .location(Location.forGrpcInsecure("0.0.0.0", 0)); + final Scenario scenario = Scenarios.getScenario(scenarioName); + scenario.buildServer(builder); + builder.producer(scenario.producer(allocator, Location.forGrpcInsecure("0.0.0.0", 0))); + + try (final FlightServer server = builder.build()) { + server.start(); + + final Location location = Location.forGrpcInsecure("localhost", server.getPort()); + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + scenario.client(allocator, location, client); + } + } + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index dd9480f40041b..f1f07a1588f57 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -97,6 +97,16 @@ public FlightInfo execute(final String query, final CallOption... options) { return client.getInfo(descriptor, options); } + /** + * Get the schema of the result set of a query. + */ + public SchemaResult getExecuteSchema(final String query, final CallOption... options) { + final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder(); + builder.setQuery(query); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Execute an update query on the server. * @@ -137,6 +147,17 @@ public FlightInfo getCatalogs(final CallOption... options) { return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getCatalogs(CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_CATALOGS_SCHEMA}. + */ + public SchemaResult getCatalogsSchema(final CallOption... options) { + final CommandGetCatalogs command = CommandGetCatalogs.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Request a list of schemas. * @@ -160,6 +181,17 @@ public FlightInfo getSchemas(final String catalog, final String dbSchemaFilterPa return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getSchemas(String, String, CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_SCHEMAS_SCHEMA}. + */ + public SchemaResult getSchemasSchema(final CallOption... options) { + final CommandGetDbSchemas command = CommandGetDbSchemas.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Get schema for a stream. * @@ -231,6 +263,17 @@ public FlightInfo getSqlInfo(final Iterable info, final CallOption... o return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getSqlInfo(SqlInfo...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_SQL_INFO_SCHEMA}. + */ + public SchemaResult getSqlInfoSchema(final CallOption... options) { + final CommandGetSqlInfo command = CommandGetSqlInfo.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Request the information about the data types supported related to * a filter data type. @@ -261,6 +304,17 @@ public FlightInfo getXdbcTypeInfo(final CallOption... options) { return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getXdbcTypeInfo(CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_TYPE_INFO_SCHEMA}. + */ + public SchemaResult getXdbcTypeInfoSchema(final CallOption... options) { + final CommandGetXdbcTypeInfo command = CommandGetXdbcTypeInfo.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Request a list of tables. * @@ -298,6 +352,18 @@ public FlightInfo getTables(final String catalog, final String dbSchemaFilterPat return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getTables(String, String, String, List, boolean, CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_TABLES_SCHEMA} or + * {@link FlightSqlProducer.Schemas#GET_TABLES_SCHEMA_NO_SCHEMA}. + */ + public SchemaResult getTablesSchema(boolean includeSchema, final CallOption... options) { + final CommandGetTables command = CommandGetTables.newBuilder().setIncludeSchema(includeSchema).build(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Request the primary keys for a table. * @@ -323,6 +389,17 @@ public FlightInfo getPrimaryKeys(final TableRef tableRef, final CallOption... op return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getPrimaryKeys(TableRef, CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_PRIMARY_KEYS_SCHEMA}. + */ + public SchemaResult getPrimaryKeysSchema(final CallOption... options) { + final CommandGetPrimaryKeys command = CommandGetPrimaryKeys.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Retrieves a description about the foreign key columns that reference the primary key columns of the given table. * @@ -350,6 +427,17 @@ public FlightInfo getExportedKeys(final TableRef tableRef, final CallOption... o return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getExportedKeys(TableRef, CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_EXPORTED_KEYS_SCHEMA}. + */ + public SchemaResult getExportedKeysSchema(final CallOption... options) { + final CommandGetExportedKeys command = CommandGetExportedKeys.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Retrieves the foreign key columns for the given table. * @@ -378,6 +466,17 @@ public FlightInfo getImportedKeys(final TableRef tableRef, return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getImportedKeys(TableRef, CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_IMPORTED_KEYS_SCHEMA}. + */ + public SchemaResult getImportedKeysSchema(final CallOption... options) { + final CommandGetImportedKeys command = CommandGetImportedKeys.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Retrieves a description of the foreign key columns that reference the given table's * primary key columns (the foreign keys exported by a table). @@ -417,6 +516,17 @@ public FlightInfo getCrossReference(final TableRef pkTableRef, return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getCrossReference(TableRef, TableRef, CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_CROSS_REFERENCE_SCHEMA}. + */ + public SchemaResult getCrossReferenceSchema(final CallOption... options) { + final CommandGetCrossReference command = CommandGetCrossReference.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Request a list of table types. * @@ -429,6 +539,17 @@ public FlightInfo getTableTypes(final CallOption... options) { return client.getInfo(descriptor, options); } + /** + * Get the schema of {@link #getTableTypes(CallOption...)} from the server. + * + *

Should be identical to {@link FlightSqlProducer.Schemas#GET_TABLE_TYPES_SCHEMA}. + */ + public SchemaResult getTableTypesSchema(final CallOption... options) { + final CommandGetTableTypes command = CommandGetTableTypes.getDefaultInstance(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray()); + return client.getSchema(descriptor, options); + } + /** * Create a prepared statement on the server. * @@ -534,6 +655,20 @@ public Schema getParameterSchema() { return parameterSchema; } + /** + * Get the schema of the result set (should be identical to {@link #getResultSetSchema()}). + */ + public SchemaResult fetchSchema(CallOption... options) { + checkOpen(); + + final FlightDescriptor descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementQuery.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + return client.getSchema(descriptor, options); + } + private Schema deserializeSchema(final ByteString bytes) { try { return bytes.isEmpty() ? diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index c617c6a03eec9..4226ec9e228cf 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -147,26 +147,32 @@ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) if (command.is(CommandStatementQuery.class)) { return getSchemaStatement( FlightSqlUtils.unpackOrThrow(command, CommandStatementQuery.class), context, descriptor); + } else if (command.is(CommandPreparedStatementQuery.class)) { + return getSchemaPreparedStatement( + FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor); } else if (command.is(CommandGetCatalogs.class)) { return new SchemaResult(Schemas.GET_CATALOGS_SCHEMA); + } else if (command.is(CommandGetCrossReference.class)) { + return new SchemaResult(Schemas.GET_CROSS_REFERENCE_SCHEMA); } else if (command.is(CommandGetDbSchemas.class)) { return new SchemaResult(Schemas.GET_SCHEMAS_SCHEMA); + } else if (command.is(CommandGetExportedKeys.class)) { + return new SchemaResult(Schemas.GET_EXPORTED_KEYS_SCHEMA); + } else if (command.is(CommandGetImportedKeys.class)) { + return new SchemaResult(Schemas.GET_IMPORTED_KEYS_SCHEMA); + } else if (command.is(CommandGetPrimaryKeys.class)) { + return new SchemaResult(Schemas.GET_PRIMARY_KEYS_SCHEMA); } else if (command.is(CommandGetTables.class)) { - return new SchemaResult(Schemas.GET_TABLES_SCHEMA); + if (FlightSqlUtils.unpackOrThrow(command, CommandGetTables.class).getIncludeSchema()) { + return new SchemaResult(Schemas.GET_TABLES_SCHEMA); + } + return new SchemaResult(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA); } else if (command.is(CommandGetTableTypes.class)) { return new SchemaResult(Schemas.GET_TABLE_TYPES_SCHEMA); } else if (command.is(CommandGetSqlInfo.class)) { return new SchemaResult(Schemas.GET_SQL_INFO_SCHEMA); } else if (command.is(CommandGetXdbcTypeInfo.class)) { return new SchemaResult(Schemas.GET_TYPE_INFO_SCHEMA); - } else if (command.is(CommandGetPrimaryKeys.class)) { - return new SchemaResult(Schemas.GET_PRIMARY_KEYS_SCHEMA); - } else if (command.is(CommandGetImportedKeys.class)) { - return new SchemaResult(Schemas.GET_IMPORTED_KEYS_SCHEMA); - } else if (command.is(CommandGetExportedKeys.class)) { - return new SchemaResult(Schemas.GET_EXPORTED_KEYS_SCHEMA); - } else if (command.is(CommandGetCrossReference.class)) { - return new SchemaResult(Schemas.GET_CROSS_REFERENCE_SCHEMA); } throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid command provided.").toRuntimeException(); @@ -336,16 +342,31 @@ FlightInfo getFlightInfoPreparedStatement(CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor); /** - * Gets schema about a particular SQL query based data stream. + * Get the schema of the result set of a query. * - * @param command The sql command to generate the data stream. + * @param command The SQL query. * @param context Per-call context. * @param descriptor The descriptor identifying the data stream. - * @return Schema for the stream. + * @return the schema of the result set. */ SchemaResult getSchemaStatement(CommandStatementQuery command, CallContext context, FlightDescriptor descriptor); + /** + * Get the schema of the result set of a prepared statement. + * + * @param command The prepared statement handle. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return the schema of the result set. + */ + default SchemaResult getSchemaPreparedStatement(CommandPreparedStatementQuery command, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED + .withDescription("GetSchema with CommandPreparedStatementQuery is not implemented") + .toRuntimeException(); + } + /** * Returns data for a SQL query based data stream. * @param ticket Ticket message containing the statement handle. diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java index 25affa8f08aaa..e461515c40ecd 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -76,7 +76,7 @@ public static T unpackOrThrow(Any source, Class as) { return source.unpack(as); } catch (final InvalidProtocolBufferException e) { throw CallStatus.INVALID_ARGUMENT - .withDescription("Provided message cannot be unpacked as desired type.") + .withDescription("Provided message cannot be unpacked as " + as.getName() + ": " + e) .withCause(e) .toRuntimeException(); }