From 6938db17724f51548c23d719885e89a6b8b3fa89 Mon Sep 17 00:00:00 2001 From: Jose Almeida <53087160+jcralmeida@users.noreply.github.com> Date: Tue, 12 Oct 2021 14:31:14 -0300 Subject: [PATCH] [CPP] Implements GetPrimaryKeys on flight sql server (#162) * Add Schema template for the primary keys * Implement GetPrimaryKeys on server * Add an integrated test for GetPrimaryKeys * Fix checkstyle * Add a comment to the query on primary keys query * Use GetFlightInfoForCommand helper method on GetFlightInfoPrimaryKeys Co-authored-by: Rafael Telles --- .../flight-sql/example/sqlite_server.cc | 45 +++++++++++++++++++ .../flight/flight-sql/example/sqlite_server.h | 11 ++++- .../arrow/flight/flight-sql/sql_server.cpp | 6 +++ cpp/src/arrow/flight/flight-sql/sql_server.h | 5 +++ .../flight/flight-sql/sql_server_test.cc | 26 +++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/flight-sql/example/sqlite_server.cc b/cpp/src/arrow/flight/flight-sql/example/sqlite_server.cc index 12e767c995424..e9882bea21552 100644 --- a/cpp/src/arrow/flight/flight-sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/flight-sql/example/sqlite_server.cc @@ -311,6 +311,51 @@ Status SQLiteFlightSqlServer::DoGetTableTypes(const ServerCallContext& context, return Status::OK(); } +Status SQLiteFlightSqlServer::GetFlightInfoPrimaryKeys( + const pb::sql::CommandGetPrimaryKeys& command, const ServerCallContext& context, + const FlightDescriptor& descriptor, std::unique_ptr* info) { + return GetFlightInfoForCommand(descriptor, info, command, + SqlSchema::GetPrimaryKeysSchema()); +} + +Status +SQLiteFlightSqlServer::DoGetPrimaryKeys(const pb::sql::CommandGetPrimaryKeys &command, + const ServerCallContext &context, + std::unique_ptr *result) { + std::stringstream table_query; + + // The field key_name can not be recovered by the sqlite, so it is being set + // to null following the same pattern for catalog_name and schema_name. + table_query << "SELECT null as catalog_name, null as schema_name, table_name, " + "name as column_name, pk as key_sequence, null as key_name\n" + "FROM pragma_table_info(table_name)\n" + " JOIN (SELECT null as catalog_name, null as schema_name, name as " + "table_name, type as table_type\n" + "FROM sqlite_master) where 1=1 and pk != 0"; + + if (command.has_catalog()) { + table_query << " and catalog_name LIKE '" << command.catalog() << "'"; + } + + if (command.has_schema()) { + table_query << " and schema_name LIKE '" << command.schema() << "'"; + } + + table_query << " and table_name LIKE '" << command.table() << "'"; + + std::shared_ptr statement; + ARROW_RETURN_NOT_OK(SqliteStatement::Create(db_, table_query.str(), &statement)); + + std::shared_ptr reader; + ARROW_RETURN_NOT_OK(SqliteStatementBatchReader::Create( + statement, SqlSchema::GetPrimaryKeysSchema(), &reader)); + + *result = std::unique_ptr( + new RecordBatchStream(reader)); + + return Status::OK(); +} + } // namespace example } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/flight-sql/example/sqlite_server.h b/cpp/src/arrow/flight/flight-sql/example/sqlite_server.h index af5b46db6cbf9..ea8f167d367e7 100644 --- a/cpp/src/arrow/flight/flight-sql/example/sqlite_server.h +++ b/cpp/src/arrow/flight/flight-sql/example/sqlite_server.h @@ -80,7 +80,16 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase { Status DoGetTableTypes(const ServerCallContext &context, std::unique_ptr *result) override; - private: + Status GetFlightInfoPrimaryKeys(const pb::sql::CommandGetPrimaryKeys &command, + const ServerCallContext &context, + const FlightDescriptor &descriptor, + std::unique_ptr *info) override; + + Status DoGetPrimaryKeys(const pb::sql::CommandGetPrimaryKeys &command, + const ServerCallContext &context, + std::unique_ptr *result) override; + +private: sqlite3* db_; }; diff --git a/cpp/src/arrow/flight/flight-sql/sql_server.cpp b/cpp/src/arrow/flight/flight-sql/sql_server.cpp index 221edee390410..87a22b5806f7e 100644 --- a/cpp/src/arrow/flight/flight-sql/sql_server.cpp +++ b/cpp/src/arrow/flight/flight-sql/sql_server.cpp @@ -294,6 +294,12 @@ std::shared_ptr SqlSchema::GetTableTypesSchema() { return arrow::schema({field("table_type", utf8())}); } +std::shared_ptr SqlSchema::GetPrimaryKeysSchema() { +return arrow::schema({field("catalog_name", utf8()), field("schema_name", utf8()), + field("table_name", utf8()), field("column_name", utf8()), + field("key_sequence", int64()), field("key_name", utf8())}); +} + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/flight-sql/sql_server.h b/cpp/src/arrow/flight/flight-sql/sql_server.h index 25af7eb64ce61..515b36eba5c10 100644 --- a/cpp/src/arrow/flight/flight-sql/sql_server.h +++ b/cpp/src/arrow/flight/flight-sql/sql_server.h @@ -298,6 +298,11 @@ class SqlSchema { /// \brief Gets the Schema used on CommandGetTableTypes response. /// \return The default schema template. static std::shared_ptr GetTableTypesSchema(); + + /// \brief Gets the Schema used on CommandGetPrimaryKeys response when included schema + /// flags is set to true. + /// \return The default schema template. + static std::shared_ptr GetPrimaryKeysSchema(); }; } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/flight-sql/sql_server_test.cc b/cpp/src/arrow/flight/flight-sql/sql_server_test.cc index 2ffbf96c80681..239238135c8dd 100644 --- a/cpp/src/arrow/flight/flight-sql/sql_server_test.cc +++ b/cpp/src/arrow/flight/flight-sql/sql_server_test.cc @@ -323,6 +323,32 @@ TEST(TestFlightSqlServer, TestCommandStatementUpdate) { ASSERT_EQ(3, result); } +TEST(TestFlightSqlServer, TestCommandGetPrimaryKeys) { + std::unique_ptr flight_info; + std::vector table_types; + ASSERT_OK(sql_client->GetPrimaryKeys({}, nullptr, nullptr, "int%", + &flight_info)); + + std::unique_ptr stream; + ASSERT_OK(sql_client->DoGet({}, flight_info->endpoints()[0].ticket, &stream)); + + std::shared_ptr table; + ASSERT_OK(stream->ReadAll(&table)); + + DECLARE_NULL_ARRAY(catalog_name, String, 1); + DECLARE_NULL_ARRAY(schema_name, String, 1); + DECLARE_ARRAY(table_name, String, ({"intTable"})); + DECLARE_ARRAY(column_name, String, ({"id"})); + DECLARE_ARRAY(key_sequence, Int64, ({1})); + DECLARE_NULL_ARRAY(key_name, String, 1); + + const std::shared_ptr
& expected_table = Table::Make( + SqlSchema::GetPrimaryKeysSchema(), + {catalog_name, schema_name, table_name, column_name, key_sequence, key_name}); + + ASSERT_TRUE(expected_table->Equals(*table)); +} + auto env = ::testing::AddGlobalTestEnvironment(new TestFlightSqlServer);